mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-06 08:05:49 +00:00
Compare commits
19 Commits
nikg/std-e
...
arg_packet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
037540225c | ||
|
|
1be71fd2af | ||
|
|
8bef7ab8fb | ||
|
|
8ca7c1af5a | ||
|
|
089fc6ed3e | ||
|
|
a25079ac23 | ||
|
|
00dca7c3ec | ||
|
|
a8c7b322cb | ||
|
|
e549944bce | ||
|
|
723637f379 | ||
|
|
e9913876c0 | ||
|
|
20948e2ea3 | ||
|
|
5dcbc91643 | ||
|
|
c8e874df49 | ||
|
|
b1a6a08eed | ||
|
|
ec3e571a7f | ||
|
|
b03a0f8cac | ||
|
|
8163ca704a | ||
|
|
b7abf3991a |
15
.vscode/launch.json
vendored
15
.vscode/launch.json
vendored
@@ -485,6 +485,21 @@
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart OpenSearch Container",
|
||||
// Generic debugger type, required arg but has no bearing on bash.
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Eval CLI",
|
||||
"type": "debugpy",
|
||||
|
||||
@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.server.analytics.api import router as analytics_router
|
||||
from ee.onyx.server.auth_check import check_ee_router_auth
|
||||
from ee.onyx.server.billing.api import router as billing_router
|
||||
@@ -152,9 +153,12 @@ def get_application() -> FastAPI:
|
||||
# License management
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
# Unified billing API - always registered in EE.
|
||||
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
# Unified billing API - available when license system is enabled
|
||||
# Works for both self-hosted and cloud deployments
|
||||
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
|
||||
# primary billing API and /tenants/* billing endpoints can be removed
|
||||
if LICENSE_ENFORCEMENT_ENABLED:
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Tenant management
|
||||
|
||||
@@ -39,13 +39,9 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
|
||||
class SlimConnectorExtractionResult(BaseModel):
|
||||
"""Result of extracting document IDs and hierarchy nodes from a connector.
|
||||
"""Result of extracting document IDs and hierarchy nodes from a connector."""
|
||||
|
||||
raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None).
|
||||
Use raw_id_to_parent.keys() wherever the old set of IDs was needed.
|
||||
"""
|
||||
|
||||
raw_id_to_parent: dict[str, str | None]
|
||||
doc_ids: set[str]
|
||||
hierarchy_nodes: list[HierarchyNode]
|
||||
|
||||
|
||||
@@ -97,37 +93,30 @@ def _get_failure_id(failure: ConnectorFailure) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
class BatchResult(BaseModel):
|
||||
raw_id_to_parent: dict[str, str | None]
|
||||
hierarchy_nodes: list[HierarchyNode]
|
||||
|
||||
|
||||
def _extract_from_batch(
|
||||
doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
|
||||
) -> BatchResult:
|
||||
"""Separate a batch into document IDs (with parent mapping) and hierarchy nodes.
|
||||
) -> tuple[set[str], list[HierarchyNode]]:
|
||||
"""Separate a batch into document IDs and hierarchy nodes.
|
||||
|
||||
ConnectorFailure items have their failed document/entity IDs added to the
|
||||
ID dict so that failed-to-retrieve documents are not accidentally pruned.
|
||||
ID set so that failed-to-retrieve documents are not accidentally pruned.
|
||||
"""
|
||||
ids: dict[str, str | None] = {}
|
||||
ids: set[str] = set()
|
||||
hierarchy_nodes: list[HierarchyNode] = []
|
||||
for item in doc_list:
|
||||
if isinstance(item, HierarchyNode):
|
||||
hierarchy_nodes.append(item)
|
||||
if item.raw_node_id not in ids:
|
||||
ids[item.raw_node_id] = None
|
||||
ids.add(item.raw_node_id)
|
||||
elif isinstance(item, ConnectorFailure):
|
||||
failed_id = _get_failure_id(item)
|
||||
if failed_id:
|
||||
ids[failed_id] = None
|
||||
ids.add(failed_id)
|
||||
logger.warning(
|
||||
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
|
||||
)
|
||||
else:
|
||||
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
|
||||
ids[item.id] = parent_raw
|
||||
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
|
||||
ids.add(item.id)
|
||||
return ids, hierarchy_nodes
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
@@ -143,7 +132,7 @@ def extract_ids_from_runnable_connector(
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
"""
|
||||
all_raw_id_to_parent: dict[str, str | None] = {}
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
all_hierarchy_nodes: list[HierarchyNode] = []
|
||||
|
||||
# Sequence (covariant) lets all the specific list[...] iterator types unify here
|
||||
@@ -188,20 +177,15 @@ def extract_ids_from_runnable_connector(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
|
||||
batch_result = _extract_from_batch(doc_list)
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
for k, v in batch_ids.items():
|
||||
if v is not None or k not in all_raw_id_to_parent:
|
||||
all_raw_id_to_parent[k] = v
|
||||
batch_ids, batch_nodes = _extract_from_batch(doc_list)
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
|
||||
|
||||
return SlimConnectorExtractionResult(
|
||||
raw_id_to_parent=all_raw_id_to_parent,
|
||||
doc_ids=all_connector_doc_ids,
|
||||
hierarchy_nodes=all_hierarchy_nodes,
|
||||
)
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -48,8 +47,6 @@ from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
@@ -60,8 +57,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
@@ -118,38 +113,6 @@ class PruneCallback(IndexingCallbackBase):
|
||||
super().progress(tag, amount)
|
||||
|
||||
|
||||
def _resolve_and_update_document_parents(
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
source: DocumentSource,
|
||||
raw_id_to_parent: dict[str, str | None],
|
||||
) -> None:
|
||||
"""Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id for
|
||||
each document and bulk-update the DB. Mirrors the resolution logic in
|
||||
run_docfetching.py."""
|
||||
source_node_id = get_source_node_id_from_cache(redis_client, db_session, source)
|
||||
|
||||
resolved: dict[str, int | None] = {}
|
||||
for doc_id, raw_parent_id in raw_id_to_parent.items():
|
||||
if raw_parent_id is None:
|
||||
continue
|
||||
node_id, found = get_node_id_from_raw_id(redis_client, source, raw_parent_id)
|
||||
resolved[doc_id] = node_id if found else source_node_id
|
||||
|
||||
if not resolved:
|
||||
return
|
||||
|
||||
update_document_parent_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
doc_parent_map=resolved,
|
||||
commit=True,
|
||||
)
|
||||
task_logger.info(
|
||||
f"Pruning: resolved and updated parent hierarchy for "
|
||||
f"{len(resolved)} documents (source={source.value})"
|
||||
)
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off pruning tasks."""
|
||||
|
||||
|
||||
@@ -572,22 +535,22 @@ def connector_pruning_generator_task(
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
all_connector_doc_ids = extraction_result.doc_ids
|
||||
|
||||
# Process hierarchy nodes (same as docfetching):
|
||||
# upsert to Postgres and cache in Redis
|
||||
source = cc_pair.connector.source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
ensure_source_node_exists(
|
||||
redis_client, db_session, cc_pair.connector.source
|
||||
)
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
source=source,
|
||||
source=cc_pair.connector.source,
|
||||
commit=True,
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
@@ -598,7 +561,7 @@ def connector_pruning_generator_task(
|
||||
]
|
||||
cache_hierarchy_nodes_batch(
|
||||
redis_client=redis_client,
|
||||
source=source,
|
||||
source=cc_pair.connector.source,
|
||||
entries=cache_entries,
|
||||
)
|
||||
|
||||
@@ -607,26 +570,6 @@ def connector_pruning_generator_task(
|
||||
f"hierarchy nodes for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
|
||||
# and bulk-update documents, mirroring the docfetching resolution
|
||||
_resolve_and_update_document_parents(
|
||||
db_session=db_session,
|
||||
redis_client=redis_client,
|
||||
source=source,
|
||||
raw_id_to_parent=all_connector_doc_ids,
|
||||
)
|
||||
|
||||
# Link hierarchy nodes to documents for sources where pages can be
|
||||
# both hierarchy nodes AND documents (e.g. Notion, Confluence)
|
||||
all_doc_id_list = list(all_connector_doc_ids.keys())
|
||||
link_hierarchy_nodes_to_documents(
|
||||
db_session=db_session,
|
||||
document_ids=all_doc_id_list,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
@@ -638,9 +581,7 @@ def connector_pruning_generator_task(
|
||||
}
|
||||
|
||||
# generate list of docs to remove (no longer in the source)
|
||||
doc_ids_to_remove = list(
|
||||
all_indexed_document_ids - all_connector_doc_ids.keys()
|
||||
)
|
||||
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
|
||||
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -1018,6 +1019,7 @@ def run_llm_step_pkt_generator(
|
||||
)
|
||||
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
|
||||
arg_scan_offsets: dict[int, int] = {}
|
||||
reasoning_start = False
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
@@ -1224,7 +1226,14 @@ def run_llm_step_pkt_generator(
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
# maybe_emit depends and update being called first and attaching the delta
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
yield from maybe_emit_argument_delta(
|
||||
tool_calls_in_progress=id_to_tool_call_map,
|
||||
tool_call_delta=tool_call_delta,
|
||||
placement=_current_placement(),
|
||||
scan_offsets=arg_scan_offsets,
|
||||
)
|
||||
|
||||
# Flush any tail text buffered while checking for split "<function_calls" markers.
|
||||
filtered_content_tail = xml_tool_call_content_filter.flush()
|
||||
|
||||
236
backend/onyx/chat/tool_call_args_streaming.py
Normal file
236
backend/onyx/chat/tool_call_args_streaming.py
Normal file
@@ -0,0 +1,236 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import NamedTuple
|
||||
from typing import Type
|
||||
|
||||
from onyx.llm.model_response import ChatCompletionDeltaToolCall
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_tool_class(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
) -> Type[Tool] | None:
|
||||
"""Look up the Tool subclass for a streaming tool call delta."""
|
||||
tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name")
|
||||
if not tool_name:
|
||||
return None
|
||||
return TOOL_NAME_TO_CLASS.get(tool_name)
|
||||
|
||||
|
||||
class _Token(NamedTuple):
|
||||
"""A parsed JSON string with position info."""
|
||||
|
||||
value: str # raw content between the quotes
|
||||
start: int # index of first char inside the quotes
|
||||
end: int # index of closing quote, or len(text) if incomplete
|
||||
complete: bool # whether the closing quote was found
|
||||
|
||||
|
||||
def _parse_json_string(text: str, pos: int) -> _Token:
|
||||
"""Parse a JSON string starting at the opening quote at ``pos``."""
|
||||
i = pos + 1
|
||||
while i < len(text):
|
||||
if text[i] == "\\":
|
||||
i += 2
|
||||
elif text[i] == '"':
|
||||
return _Token(text[pos + 1 : i], pos + 1, i, complete=True)
|
||||
else:
|
||||
i += 1
|
||||
return _Token(text[pos + 1 :], pos + 1, len(text), complete=False)
|
||||
|
||||
|
||||
def _skip_json_value(text: str, pos: int) -> int:
|
||||
"""Skip past a non-string JSON value (number, bool, null, array, object).
|
||||
|
||||
Tracks ``[]`` / ``{}`` nesting depth and skips over embedded strings so
|
||||
that internal commas and braces don't terminate the scan early. Stops
|
||||
at the next top-level ``,`` or ``}`` (not consumed).
|
||||
"""
|
||||
depth = 0
|
||||
while pos < len(text):
|
||||
ch = text[pos]
|
||||
if ch == '"':
|
||||
tok = _parse_json_string(text, pos)
|
||||
pos = tok.end + 1 if tok.complete else tok.end
|
||||
continue
|
||||
if ch in ("{", "["):
|
||||
depth += 1
|
||||
elif ch in ("}", "]"):
|
||||
if depth == 0:
|
||||
break
|
||||
depth -= 1
|
||||
elif ch == "," and depth == 0:
|
||||
break
|
||||
pos += 1
|
||||
return pos
|
||||
|
||||
|
||||
def _skip(text: str, pos: int, chars: str = " \t\n\r,") -> int:
|
||||
"""Advance ``pos`` past any characters in ``chars``."""
|
||||
while pos < len(text) and text[pos] in chars:
|
||||
pos += 1
|
||||
return pos
|
||||
|
||||
|
||||
def _decode_partial_json_string(raw: str) -> str:
|
||||
"""Decode JSON escapes (``\\n`` → newline) from a possibly incomplete value.
|
||||
|
||||
Progressively trims up to 6 trailing chars to handle partial escape
|
||||
sequences (the longest JSON escape is ``\\uXXXX``).
|
||||
"""
|
||||
for trim in range(min(7, len(raw) + 1)):
|
||||
candidate = raw[: len(raw) - trim] if trim else raw
|
||||
try:
|
||||
result = json.loads('"' + candidate + '"')
|
||||
if trim > 0 and not result and raw:
|
||||
logger.warning(
|
||||
"Dropped %d chars from partial JSON string value (trim=%d)",
|
||||
len(raw),
|
||||
trim,
|
||||
)
|
||||
return result
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
logger.warning(
|
||||
"Failed to decode partial JSON string value; dropping %d chars", len(raw)
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_delta_args(
|
||||
pre: str, delta: str, scan_offset: int = 0
|
||||
) -> tuple[dict[str, str], int]:
|
||||
"""Extract decoded argument values contributed by ``delta``.
|
||||
|
||||
Walks ``pre + delta`` as a partial JSON object (``{"k": "v", ...}``),
|
||||
and for each string value returns only the decoded content that falls
|
||||
within the ``delta`` portion. Escape sequences that straddle the
|
||||
boundary are handled correctly.
|
||||
|
||||
Returns ``(argument_deltas, next_scan_offset)`` where
|
||||
``next_scan_offset`` should be passed to the next call to skip
|
||||
completed key-value pairs, reducing cost from O(accumulated) to
|
||||
O(delta) per call.
|
||||
"""
|
||||
full = pre + delta
|
||||
delta_start = len(pre)
|
||||
|
||||
result: dict[str, str] = {}
|
||||
|
||||
if scan_offset > 0:
|
||||
pos = scan_offset
|
||||
else:
|
||||
pos = full.find("{")
|
||||
if pos == -1:
|
||||
return result, 0
|
||||
pos += 1
|
||||
|
||||
resume = pos
|
||||
|
||||
while pos < len(full):
|
||||
pos = _skip(full, pos)
|
||||
if pos >= len(full) or full[pos] == "}":
|
||||
break
|
||||
|
||||
resume = pos # remember start of this key-value pair
|
||||
|
||||
# Key
|
||||
if full[pos] != '"':
|
||||
break
|
||||
key = _parse_json_string(full, pos)
|
||||
if not key.complete:
|
||||
break
|
||||
pos = key.end + 1
|
||||
|
||||
# Colon
|
||||
pos = _skip(full, pos, " \t\n\r")
|
||||
if pos >= len(full) or full[pos] != ":":
|
||||
break
|
||||
pos += 1
|
||||
|
||||
# Value
|
||||
pos = _skip(full, pos, " \t\n\r")
|
||||
if pos >= len(full):
|
||||
break
|
||||
if full[pos] != '"':
|
||||
# Skip non-string values (number, boolean, null, array, object).
|
||||
# They are available in the final tool-call kickoff packet;
|
||||
# emitting them here as strings would be ambiguous for consumers
|
||||
# (e.g. the number 30 vs the string "30").
|
||||
pos = _skip_json_value(full, pos)
|
||||
continue
|
||||
val = _parse_json_string(full, pos)
|
||||
|
||||
# Only include the portion of this value that overlaps with delta
|
||||
lo = max(val.start, delta_start)
|
||||
hi = val.end
|
||||
if lo < hi:
|
||||
# Decode from value start through both boundaries so escape
|
||||
# sequences straddling the delta edge are handled correctly.
|
||||
decoded_before = _decode_partial_json_string(full[val.start : lo])
|
||||
decoded_through = _decode_partial_json_string(full[val.start : hi])
|
||||
new_content = decoded_through[len(decoded_before) :]
|
||||
if new_content:
|
||||
result[key.value] = new_content
|
||||
|
||||
if not val.complete:
|
||||
break
|
||||
pos = val.end + 1
|
||||
|
||||
return result, resume
|
||||
|
||||
|
||||
def maybe_emit_argument_delta(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
placement: Placement,
|
||||
scan_offsets: dict[int, int],
|
||||
) -> Generator[Packet, None, None]:
|
||||
"""Emit decoded tool-call argument deltas to the frontend.
|
||||
|
||||
NOTE: Currently skips non-string arguments
|
||||
|
||||
``scan_offsets`` is a mutable dict keyed by tool-call index that allows
|
||||
each call to skip past already-processed key-value pairs, reducing
|
||||
per-call cost from O(accumulated) to O(delta).
|
||||
"""
|
||||
tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta)
|
||||
if not tool_cls or not tool_cls.do_emit_argument_deltas():
|
||||
return
|
||||
|
||||
fn = tool_call_delta.function
|
||||
delta_fragment = fn.arguments if fn else None
|
||||
if not delta_fragment:
|
||||
return
|
||||
|
||||
tc_data = tool_calls_in_progress[tool_call_delta.index]
|
||||
accumulated_args = tc_data["arguments"]
|
||||
prev_args = accumulated_args[: -len(delta_fragment)]
|
||||
|
||||
idx = tool_call_delta.index
|
||||
offset = scan_offsets.get(idx, 0)
|
||||
|
||||
argument_deltas, new_offset = _extract_delta_args(prev_args, delta_fragment, offset)
|
||||
scan_offsets[idx] = new_offset
|
||||
|
||||
if not argument_deltas:
|
||||
return
|
||||
|
||||
yield Packet(
|
||||
placement=placement,
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type=tc_data.get("name", ""),
|
||||
tool_id=tc_data.get("id", ""),
|
||||
argument_deltas=argument_deltas,
|
||||
),
|
||||
)
|
||||
@@ -943,9 +943,6 @@ class ConfluenceConnector(
|
||||
if include_permissions
|
||||
else None
|
||||
),
|
||||
parent_hierarchy_raw_node_id=self._get_parent_hierarchy_raw_id(
|
||||
page
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -995,7 +992,6 @@ class ConfluenceConnector(
|
||||
if include_permissions
|
||||
else None
|
||||
),
|
||||
parent_hierarchy_raw_node_id=page_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -781,5 +781,4 @@ def build_slim_document(
|
||||
return SlimDocument(
|
||||
id=onyx_document_id_from_drive_file(file),
|
||||
external_access=external_access,
|
||||
parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0],
|
||||
)
|
||||
|
||||
@@ -902,11 +902,6 @@ class JiraConnector(
|
||||
external_access=self._get_project_permissions(
|
||||
project_key, add_prefix=False
|
||||
),
|
||||
parent_hierarchy_raw_node_id=(
|
||||
self._get_parent_hierarchy_raw_node_id(issue, project_key)
|
||||
if project_key
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
|
||||
@@ -385,7 +385,6 @@ class IndexingDocument(Document):
|
||||
class SlimDocument(BaseModel):
|
||||
id: str
|
||||
external_access: ExternalAccess | None = None
|
||||
parent_hierarchy_raw_node_id: str | None = None
|
||||
|
||||
|
||||
class HierarchyNode(BaseModel):
|
||||
|
||||
@@ -772,7 +772,6 @@ def _convert_driveitem_to_slim_document(
|
||||
drive_name: str,
|
||||
ctx: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
) -> SlimDocument:
|
||||
if driveitem.id is None:
|
||||
raise ValueError("DriveItem ID is required")
|
||||
@@ -788,15 +787,11 @@ def _convert_driveitem_to_slim_document(
|
||||
return SlimDocument(
|
||||
id=driveitem.id,
|
||||
external_access=external_access,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
)
|
||||
|
||||
|
||||
def _convert_sitepage_to_slim_document(
|
||||
site_page: dict[str, Any],
|
||||
ctx: ClientContext | None,
|
||||
graph_client: GraphClient,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
site_page: dict[str, Any], ctx: ClientContext | None, graph_client: GraphClient
|
||||
) -> SlimDocument:
|
||||
"""Convert a SharePoint site page to a SlimDocument object."""
|
||||
if site_page.get("id") is None:
|
||||
@@ -813,7 +808,6 @@ def _convert_sitepage_to_slim_document(
|
||||
return SlimDocument(
|
||||
id=id,
|
||||
external_access=external_access,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -1600,22 +1594,12 @@ class SharepointConnector(
|
||||
)
|
||||
)
|
||||
|
||||
parent_hierarchy_url: str | None = None
|
||||
if drive_web_url:
|
||||
parent_hierarchy_url = self._get_parent_hierarchy_url(
|
||||
site_url, drive_web_url, drive_name, driveitem
|
||||
)
|
||||
|
||||
try:
|
||||
logger.debug(f"Processing: {driveitem.web_url}")
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
doc_batch.append(
|
||||
_convert_driveitem_to_slim_document(
|
||||
driveitem,
|
||||
drive_name,
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_url,
|
||||
driveitem, drive_name, ctx, self.graph_client
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1635,10 +1619,7 @@ class SharepointConnector(
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
doc_batch.append(
|
||||
_convert_sitepage_to_slim_document(
|
||||
site_page,
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=site_descriptor.url,
|
||||
site_page, ctx, self.graph_client
|
||||
)
|
||||
)
|
||||
if len(doc_batch) >= SLIM_BATCH_SIZE:
|
||||
|
||||
@@ -565,7 +565,6 @@ def _get_all_doc_ids(
|
||||
channel_id=channel_id, thread_ts=message["ts"]
|
||||
),
|
||||
external_access=external_access,
|
||||
parent_hierarchy_raw_node_id=channel_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""CRUD operations for HierarchyNode."""
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -527,53 +525,6 @@ def get_document_parent_hierarchy_node_ids(
|
||||
return {doc_id: parent_id for doc_id, parent_id in results}
|
||||
|
||||
|
||||
def update_document_parent_hierarchy_nodes(
|
||||
db_session: Session,
|
||||
doc_parent_map: dict[str, int | None],
|
||||
commit: bool = True,
|
||||
) -> int:
|
||||
"""Bulk-update Document.parent_hierarchy_node_id for multiple documents.
|
||||
|
||||
Only updates rows whose current value differs from the desired value to
|
||||
avoid unnecessary writes.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
doc_parent_map: Mapping of document_id → desired parent_hierarchy_node_id
|
||||
commit: Whether to commit the transaction
|
||||
|
||||
Returns:
|
||||
Number of documents actually updated
|
||||
"""
|
||||
if not doc_parent_map:
|
||||
return 0
|
||||
|
||||
doc_ids = list(doc_parent_map.keys())
|
||||
existing = get_document_parent_hierarchy_node_ids(db_session, doc_ids)
|
||||
|
||||
by_parent: dict[int | None, list[str]] = defaultdict(list)
|
||||
for doc_id, desired_parent_id in doc_parent_map.items():
|
||||
current = existing.get(doc_id)
|
||||
if current == desired_parent_id or doc_id not in existing:
|
||||
continue
|
||||
by_parent[desired_parent_id].append(doc_id)
|
||||
|
||||
updated = 0
|
||||
for desired_parent_id, ids in by_parent.items():
|
||||
db_session.query(Document).filter(Document.id.in_(ids)).update(
|
||||
{Document.parent_hierarchy_node_id: desired_parent_id},
|
||||
synchronize_session=False,
|
||||
)
|
||||
updated += len(ids)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
elif updated:
|
||||
db_session.flush()
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
def update_hierarchy_node_permissions(
|
||||
db_session: Session,
|
||||
raw_node_id: str,
|
||||
|
||||
@@ -129,7 +129,7 @@ def get_current_search_settings(db_session: Session) -> SearchSettings:
|
||||
latest_settings = result.scalars().first()
|
||||
|
||||
if not latest_settings:
|
||||
raise RuntimeError("No search settings specified; DB is not in a valid state.")
|
||||
raise RuntimeError("No search settings specified, DB is not in a valid state")
|
||||
return latest_settings
|
||||
|
||||
|
||||
|
||||
@@ -32,6 +32,9 @@ def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
|
||||
Determines whether to enable multipass and large chunks by examining
|
||||
the current search settings and the embedder configuration.
|
||||
"""
|
||||
if not search_settings:
|
||||
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
|
||||
|
||||
multipass = should_use_multipass(search_settings)
|
||||
enable_large_chunks = SearchSettings.can_use_large_chunks(
|
||||
multipass, search_settings.model_name, search_settings.provider_type
|
||||
|
||||
@@ -26,10 +26,11 @@ def get_default_document_index(
|
||||
To be used for retrieval only. Indexing should be done through both indices
|
||||
until Vespa is deprecated.
|
||||
|
||||
Pre-existing docstring for this function, although secondary indices are not
|
||||
currently supported:
|
||||
Primary index is the index that is used for querying/updating etc. Secondary
|
||||
index is for when both the currently used index and the upcoming index both
|
||||
need to be updated. Updates are applied to both indices.
|
||||
WARNING: In that case, get_all_document_indices should be used.
|
||||
need to be updated, updates are applied to both indices.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return DisabledDocumentIndex(
|
||||
@@ -50,26 +51,11 @@ def get_default_document_index(
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
IndexingSetting.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None
|
||||
)
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=secondary_index_name,
|
||||
secondary_embedding_dim=(
|
||||
secondary_indexing_setting.final_embedding_dim
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
secondary_embedding_precision=(
|
||||
secondary_indexing_setting.embedding_precision
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
|
||||
multitenant=MULTI_TENANT,
|
||||
@@ -100,7 +86,8 @@ def get_all_document_indices(
|
||||
Used for indexing only. Until Vespa is deprecated we will index into both
|
||||
document indices. Retrieval is done through only one index however.
|
||||
|
||||
Large chunks are not currently supported so we hardcode appropriate values.
|
||||
Large chunks and secondary indices are not currently supported so we
|
||||
hardcode appropriate values.
|
||||
|
||||
NOTE: Make sure the Vespa index object is returned first. In the rare event
|
||||
that there is some conflict between indexing and the migration task, it is
|
||||
@@ -136,36 +123,13 @@ def get_all_document_indices(
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
IndexingSetting.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None
|
||||
)
|
||||
opensearch_document_index = OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
secondary_embedding_dim=(
|
||||
secondary_indexing_setting.final_embedding_dim
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
secondary_embedding_precision=(
|
||||
secondary_indexing_setting.embedding_precision
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=(
|
||||
secondary_search_settings.large_chunks_enabled
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
|
||||
@@ -271,9 +271,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_name: str | None,
|
||||
secondary_embedding_dim: int | None,
|
||||
secondary_embedding_precision: EmbeddingPrecision | None,
|
||||
# NOTE: We do not support large chunks right now.
|
||||
large_chunks_enabled: bool, # noqa: ARG002
|
||||
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
|
||||
multitenant: bool = False,
|
||||
@@ -289,25 +286,12 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
f"Expected {MULTI_TENANT}, got {multitenant}."
|
||||
)
|
||||
tenant_id = get_current_tenant_id()
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=multitenant)
|
||||
self._real_index = OpenSearchDocumentIndex(
|
||||
tenant_state=tenant_state,
|
||||
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
|
||||
index_name=index_name,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_precision=embedding_precision,
|
||||
)
|
||||
self._secondary_real_index: OpenSearchDocumentIndex | None = None
|
||||
if self.secondary_index_name:
|
||||
if secondary_embedding_dim is None or secondary_embedding_precision is None:
|
||||
raise ValueError(
|
||||
"Bug: Secondary index embedding dimension and precision are not set."
|
||||
)
|
||||
self._secondary_real_index = OpenSearchDocumentIndex(
|
||||
tenant_state=tenant_state,
|
||||
index_name=self.secondary_index_name,
|
||||
embedding_dim=secondary_embedding_dim,
|
||||
embedding_precision=secondary_embedding_precision,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def register_multitenant_indices(
|
||||
@@ -323,38 +307,19 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
self,
|
||||
primary_embedding_dim: int,
|
||||
primary_embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_embedding_dim: int | None,
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None,
|
||||
secondary_index_embedding_dim: int | None, # noqa: ARG002
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002
|
||||
) -> None:
|
||||
self._real_index.verify_and_create_index_if_necessary(
|
||||
# Only handle primary index for now, ignore secondary.
|
||||
return self._real_index.verify_and_create_index_if_necessary(
|
||||
primary_embedding_dim, primary_embedding_precision
|
||||
)
|
||||
if self.secondary_index_name:
|
||||
if (
|
||||
secondary_index_embedding_dim is None
|
||||
or secondary_index_embedding_precision is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Bug: Secondary index embedding dimension and precision are not set."
|
||||
)
|
||||
assert (
|
||||
self._secondary_real_index is not None
|
||||
), "Bug: Secondary index is not initialized."
|
||||
self._secondary_real_index.verify_and_create_index_if_necessary(
|
||||
secondary_index_embedding_dim, secondary_index_embedding_precision
|
||||
)
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
NOTE: Do NOT consider the secondary index here. A separate indexing
|
||||
pipeline will be responsible for indexing to the secondary index. This
|
||||
design is not ideal and we should reconsider this when revamping index
|
||||
swapping.
|
||||
"""
|
||||
# Convert IndexBatchParams to IndexingMetadata.
|
||||
chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {}
|
||||
for doc_id in index_batch_params.doc_id_to_new_chunk_cnt:
|
||||
@@ -386,20 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
tenant_id: str, # noqa: ARG002
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
NOTE: Remember to handle the secondary index here. There is no separate
|
||||
pipeline for deleting chunks in the secondary index. This design is not
|
||||
ideal and we should reconsider this when revamping index swapping.
|
||||
"""
|
||||
total_chunks_deleted = self._real_index.delete(doc_id, chunk_count)
|
||||
if self.secondary_index_name:
|
||||
assert (
|
||||
self._secondary_real_index is not None
|
||||
), "Bug: Secondary index is not initialized."
|
||||
total_chunks_deleted += self._secondary_real_index.delete(
|
||||
doc_id, chunk_count
|
||||
)
|
||||
return total_chunks_deleted
|
||||
return self._real_index.delete(doc_id, chunk_count)
|
||||
|
||||
def update_single(
|
||||
self,
|
||||
@@ -410,11 +362,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
fields: VespaDocumentFields | None,
|
||||
user_fields: VespaDocumentUserFields | None,
|
||||
) -> None:
|
||||
"""
|
||||
NOTE: Remember to handle the secondary index here. There is no separate
|
||||
pipeline for updating chunks in the secondary index. This design is not
|
||||
ideal and we should reconsider this when revamping index swapping.
|
||||
"""
|
||||
if fields is None and user_fields is None:
|
||||
logger.warning(
|
||||
f"Tried to update document {doc_id} with no updated fields or user fields."
|
||||
@@ -445,11 +392,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
|
||||
try:
|
||||
self._real_index.update([update_request])
|
||||
if self.secondary_index_name:
|
||||
assert (
|
||||
self._secondary_real_index is not None
|
||||
), "Bug: Secondary index is not initialized."
|
||||
self._secondary_real_index.update([update_request])
|
||||
except NotFoundError:
|
||||
logger.exception(
|
||||
f"Tried to update document {doc_id} but at least one of its chunks was not found in OpenSearch. "
|
||||
|
||||
@@ -465,12 +465,6 @@ class VespaIndex(DocumentIndex):
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
NOTE: Do NOT consider the secondary index here. A separate indexing
|
||||
pipeline will be responsible for indexing to the secondary index. This
|
||||
design is not ideal and we should reconsider this when revamping index
|
||||
swapping.
|
||||
"""
|
||||
if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
|
||||
index_batch_params.doc_id_to_new_chunk_cnt
|
||||
):
|
||||
@@ -665,10 +659,6 @@ class VespaIndex(DocumentIndex):
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
function will complete with no errors or exceptions.
|
||||
Handle other exceptions if you wish to implement retry behavior
|
||||
|
||||
NOTE: Remember to handle the secondary index here. There is no separate
|
||||
pipeline for updating chunks in the secondary index. This design is not
|
||||
ideal and we should reconsider this when revamping index swapping.
|
||||
"""
|
||||
if fields is None and user_fields is None:
|
||||
logger.warning(
|
||||
@@ -689,6 +679,13 @@ class VespaIndex(DocumentIndex):
|
||||
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
|
||||
)
|
||||
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=self.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=self.large_chunks_enabled,
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
|
||||
project_ids: set[int] | None = None
|
||||
if user_fields is not None and user_fields.user_projects is not None:
|
||||
project_ids = set(user_fields.user_projects)
|
||||
@@ -708,20 +705,7 @@ class VespaIndex(DocumentIndex):
|
||||
persona_ids=persona_ids,
|
||||
)
|
||||
|
||||
indices = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
indices.append(self.secondary_index_name)
|
||||
|
||||
for index_name in indices:
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
|
||||
index_name, False
|
||||
),
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
vespa_document_index.update([update_request])
|
||||
vespa_document_index.update([update_request])
|
||||
|
||||
def delete_single(
|
||||
self,
|
||||
@@ -730,11 +714,6 @@ class VespaIndex(DocumentIndex):
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
NOTE: Remember to handle the secondary index here. There is no separate
|
||||
pipeline for deleting chunks in the secondary index. This design is not
|
||||
ideal and we should reconsider this when revamping index swapping.
|
||||
"""
|
||||
tenant_state = TenantState(
|
||||
tenant_id=get_current_tenant_id(),
|
||||
multitenant=MULTI_TENANT,
|
||||
@@ -747,25 +726,13 @@ class VespaIndex(DocumentIndex):
|
||||
raise ValueError(
|
||||
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
|
||||
)
|
||||
indices = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
indices.append(self.secondary_index_name)
|
||||
|
||||
total_chunks_deleted = 0
|
||||
for index_name in indices:
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
|
||||
index_name, False
|
||||
),
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
total_chunks_deleted += vespa_document_index.delete(
|
||||
document_id=doc_id, chunk_count=chunk_count
|
||||
)
|
||||
|
||||
return total_chunks_deleted
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=self.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=self.large_chunks_enabled,
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
return vespa_document_index.delete(document_id=doc_id, chunk_count=chunk_count)
|
||||
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
|
||||
@@ -3,6 +3,7 @@ from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import select
|
||||
@@ -52,8 +53,6 @@ from onyx.db.permission_sync_attempt import (
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
get_recent_doc_permission_sync_attempts_for_cc_pair,
|
||||
)
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_utils import get_deletion_attempt_snapshot
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -88,9 +87,8 @@ def get_cc_pair_index_attempts(
|
||||
cc_pair_id, db_session, user, get_editable=False
|
||||
)
|
||||
if not user_has_access:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
"CC Pair not found for current user permissions",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="CC Pair not found for current user permissions"
|
||||
)
|
||||
|
||||
total_count = count_index_attempts_for_cc_pair(
|
||||
@@ -125,9 +123,8 @@ def get_cc_pair_permission_sync_attempts(
|
||||
cc_pair_id, db_session, user, get_editable=False
|
||||
)
|
||||
if not user_has_access:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
"CC Pair not found for current user permissions",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="CC Pair not found for current user permissions"
|
||||
)
|
||||
|
||||
# Get all permission sync attempts for this cc pair
|
||||
@@ -163,9 +160,8 @@ def get_cc_pair_full_info(
|
||||
cc_pair_id, db_session, user, get_editable=False
|
||||
)
|
||||
if not cc_pair:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
"CC Pair not found for current user permissions",
|
||||
raise HTTPException(
|
||||
status_code=404, detail="CC Pair not found for current user permissions"
|
||||
)
|
||||
editable_cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id, db_session, user, get_editable=True
|
||||
@@ -268,9 +264,9 @@ def update_cc_pair_status(
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
"Connection not found for current user's permissions",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
@@ -343,9 +339,8 @@ def update_cc_pair_name(
|
||||
get_editable=True,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
"CC Pair not found for current user's permissions",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="CC Pair not found for current user's permissions"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -356,7 +351,7 @@ def update_cc_pair_name(
|
||||
)
|
||||
except IntegrityError:
|
||||
db_session.rollback()
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Name must be unique")
|
||||
raise HTTPException(status_code=400, detail="Name must be unique")
|
||||
|
||||
|
||||
@router.put("/admin/cc-pair/{cc_pair_id}/property")
|
||||
@@ -373,9 +368,8 @@ def update_cc_pair_property(
|
||||
get_editable=True,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
"CC Pair not found for current user's permissions",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="CC Pair not found for current user's permissions"
|
||||
)
|
||||
|
||||
# Can we centralize logic for updating connector properties
|
||||
@@ -393,9 +387,8 @@ def update_cc_pair_property(
|
||||
|
||||
msg = "Pruning frequency updated successfully"
|
||||
else:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Property name {update_request.name} is not valid.",
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Property name {update_request.name} is not valid."
|
||||
)
|
||||
|
||||
return StatusResponse(success=True, message=msg, data=cc_pair_id)
|
||||
@@ -414,9 +407,9 @@ def get_cc_pair_last_pruned(
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
"cc_pair not found for current user's permissions",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="cc_pair not found for current user's permissions",
|
||||
)
|
||||
|
||||
return cc_pair.last_pruned
|
||||
@@ -438,16 +431,19 @@ def prune_cc_pair(
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
"Connection not found for current user's permissions",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if redis_connector.prune.fenced:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Pruning task already in progress.")
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
detail="Pruning task already in progress.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Pruning cc_pair: cc_pair={cc_pair_id} "
|
||||
@@ -459,7 +455,10 @@ def prune_cc_pair(
|
||||
client_app, cc_pair, db_session, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Pruning task creation failed.")
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
detail="Pruning task creation failed.",
|
||||
)
|
||||
|
||||
logger.info(f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}")
|
||||
|
||||
@@ -589,21 +588,20 @@ def associate_credential_to_connector(
|
||||
delete_connector(db_session, connector_id)
|
||||
db_session.commit()
|
||||
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONNECTOR_VALIDATION_FAILED,
|
||||
"Connector validation error: " + str(e),
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Connector validation error: " + str(e)
|
||||
)
|
||||
except IntegrityError as e:
|
||||
logger.error(f"IntegrityError: {e}")
|
||||
delete_connector(db_session, connector_id)
|
||||
db_session.commit()
|
||||
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Name must be unique")
|
||||
raise HTTPException(status_code=400, detail="Name must be unique")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error: {e}")
|
||||
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Unexpected error")
|
||||
raise HTTPException(status_code=500, detail="Unexpected error")
|
||||
|
||||
|
||||
@router.delete(
|
||||
|
||||
@@ -11,6 +11,7 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Form
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
@@ -114,8 +115,6 @@ from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.file_types import PLAIN_TEXT_MIME_TYPE
|
||||
from onyx.file_processing.file_types import WORD_PROCESSING_MIME_TYPE
|
||||
from onyx.file_store.file_store import FileStore
|
||||
@@ -181,7 +180,7 @@ def check_google_app_gmail_credentials_exist(
|
||||
try:
|
||||
return {"client_id": get_google_app_cred(DocumentSource.GMAIL).web.client_id}
|
||||
except KvKeyNotFoundError:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Google App Credentials not found")
|
||||
raise HTTPException(status_code=404, detail="Google App Credentials not found")
|
||||
|
||||
|
||||
@router.put("/admin/connector/gmail/app-credential")
|
||||
@@ -191,7 +190,7 @@ def upsert_google_app_gmail_credentials(
|
||||
try:
|
||||
upsert_google_app_cred(app_credentials, DocumentSource.GMAIL)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return StatusResponse(
|
||||
success=True, message="Successfully saved Google App Credentials"
|
||||
@@ -207,7 +206,7 @@ def delete_google_app_gmail_credentials(
|
||||
delete_google_app_cred(DocumentSource.GMAIL)
|
||||
cleanup_gmail_credentials(db_session=db_session)
|
||||
except KvKeyNotFoundError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return StatusResponse(
|
||||
success=True, message="Successfully deleted Google App Credentials"
|
||||
@@ -223,7 +222,7 @@ def check_google_app_credentials_exist(
|
||||
"client_id": get_google_app_cred(DocumentSource.GOOGLE_DRIVE).web.client_id
|
||||
}
|
||||
except KvKeyNotFoundError:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Google App Credentials not found")
|
||||
raise HTTPException(status_code=404, detail="Google App Credentials not found")
|
||||
|
||||
|
||||
@router.put("/admin/connector/google-drive/app-credential")
|
||||
@@ -233,7 +232,7 @@ def upsert_google_app_credentials(
|
||||
try:
|
||||
upsert_google_app_cred(app_credentials, DocumentSource.GOOGLE_DRIVE)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return StatusResponse(
|
||||
success=True, message="Successfully saved Google App Credentials"
|
||||
@@ -249,7 +248,7 @@ def delete_google_app_credentials(
|
||||
delete_google_app_cred(DocumentSource.GOOGLE_DRIVE)
|
||||
cleanup_google_drive_credentials(db_session=db_session)
|
||||
except KvKeyNotFoundError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return StatusResponse(
|
||||
success=True, message="Successfully deleted Google App Credentials"
|
||||
@@ -267,7 +266,9 @@ def check_google_service_gmail_account_key_exist(
|
||||
).client_email
|
||||
}
|
||||
except KvKeyNotFoundError:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Google Service Account Key not found")
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Google Service Account Key not found"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/admin/connector/gmail/service-account-key")
|
||||
@@ -277,7 +278,7 @@ def upsert_google_service_gmail_account_key(
|
||||
try:
|
||||
upsert_service_account_key(service_account_key, DocumentSource.GMAIL)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return StatusResponse(
|
||||
success=True, message="Successfully saved Google Service Account Key"
|
||||
@@ -293,7 +294,7 @@ def delete_google_service_gmail_account_key(
|
||||
delete_service_account_key(DocumentSource.GMAIL)
|
||||
cleanup_gmail_credentials(db_session=db_session)
|
||||
except KvKeyNotFoundError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return StatusResponse(
|
||||
success=True, message="Successfully deleted Google Service Account Key"
|
||||
@@ -311,7 +312,9 @@ def check_google_service_account_key_exist(
|
||||
).client_email
|
||||
}
|
||||
except KvKeyNotFoundError:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Google Service Account Key not found")
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Google Service Account Key not found"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/admin/connector/google-drive/service-account-key")
|
||||
@@ -321,7 +324,7 @@ def upsert_google_service_account_key(
|
||||
try:
|
||||
upsert_service_account_key(service_account_key, DocumentSource.GOOGLE_DRIVE)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return StatusResponse(
|
||||
success=True, message="Successfully saved Google Service Account Key"
|
||||
@@ -337,7 +340,7 @@ def delete_google_service_account_key(
|
||||
delete_service_account_key(DocumentSource.GOOGLE_DRIVE)
|
||||
cleanup_google_drive_credentials(db_session=db_session)
|
||||
except KvKeyNotFoundError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return StatusResponse(
|
||||
success=True, message="Successfully deleted Google Service Account Key"
|
||||
@@ -360,7 +363,7 @@ def upsert_service_account_credential(
|
||||
name="Service Account (uploaded)",
|
||||
)
|
||||
except KvKeyNotFoundError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# first delete all existing service account credentials
|
||||
delete_service_account_credentials(user, db_session, DocumentSource.GOOGLE_DRIVE)
|
||||
@@ -386,7 +389,7 @@ def upsert_gmail_service_account_credential(
|
||||
primary_admin_email=service_account_credential_request.google_primary_admin,
|
||||
)
|
||||
except KvKeyNotFoundError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# first delete all existing service account credentials
|
||||
delete_service_account_credentials(user, db_session, DocumentSource.GMAIL)
|
||||
@@ -437,9 +440,9 @@ def save_zip_metadata_to_file_store(
|
||||
json.loads(metadata_bytes)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Unable to load {ONYX_METADATA_FILENAME}: {e}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Unable to load {ONYX_METADATA_FILENAME}: {e}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unable to load {ONYX_METADATA_FILENAME}: {e}",
|
||||
)
|
||||
|
||||
# Save to file store
|
||||
@@ -497,7 +500,7 @@ def upload_files(
|
||||
|
||||
if is_zip_file(file):
|
||||
if seen_zip:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, SEEN_ZIP_DETAIL)
|
||||
raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL)
|
||||
seen_zip = True
|
||||
with zipfile.ZipFile(file.file, "r") as zf:
|
||||
zip_metadata_file_id = save_zip_metadata_to_file_store(
|
||||
@@ -551,7 +554,7 @@ def upload_files(
|
||||
deduped_file_names.append(file.filename)
|
||||
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return FileUploadResponse(
|
||||
file_paths=deduped_file_paths,
|
||||
file_names=deduped_file_names,
|
||||
@@ -578,9 +581,9 @@ def _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
) -> ConnectorCredentialPair:
|
||||
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
|
||||
if cc_pair is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
"No Connector-Credential Pair found for this connector",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No Connector-Credential Pair found for this connector",
|
||||
)
|
||||
|
||||
has_requested_access = verify_user_has_access_to_cc_pair(
|
||||
@@ -601,9 +604,9 @@ def _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
):
|
||||
return cc_pair
|
||||
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHORIZED,
|
||||
"Access denied. User cannot manage files for this connector.",
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied. User cannot manage files for this connector.",
|
||||
)
|
||||
|
||||
|
||||
@@ -624,12 +627,11 @@ def list_connector_files(
|
||||
"""List all files in a file connector."""
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
if connector is None:
|
||||
raise OnyxError(OnyxErrorCode.CONNECTOR_NOT_FOUND, "Connector not found")
|
||||
raise HTTPException(status_code=404, detail="Connector not found")
|
||||
|
||||
if connector.source != DocumentSource.FILE:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"This endpoint only works with file connectors",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="This endpoint only works with file connectors"
|
||||
)
|
||||
|
||||
_ = _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
@@ -698,12 +700,11 @@ def update_connector_files(
|
||||
files = files or []
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
if connector is None:
|
||||
raise OnyxError(OnyxErrorCode.CONNECTOR_NOT_FOUND, "Connector not found")
|
||||
raise HTTPException(status_code=404, detail="Connector not found")
|
||||
|
||||
if connector.source != DocumentSource.FILE:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"This endpoint only works with file connectors",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="This endpoint only works with file connectors"
|
||||
)
|
||||
|
||||
# Get the connector-credential pair for indexing/pruning triggers
|
||||
@@ -719,14 +720,12 @@ def update_connector_files(
|
||||
try:
|
||||
file_ids_list = json.loads(file_ids_to_remove)
|
||||
except json.JSONDecodeError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR, "Invalid file_ids_to_remove format"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Invalid file_ids_to_remove format")
|
||||
|
||||
if not isinstance(file_ids_list, list):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"file_ids_to_remove must be a JSON-encoded list",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="file_ids_to_remove must be a JSON-encoded list",
|
||||
)
|
||||
|
||||
# Get current connector config
|
||||
@@ -751,9 +750,9 @@ def update_connector_files(
|
||||
current_zip_metadata = loaded_metadata
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load existing metadata file: {e}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to load existing connector metadata file",
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to load existing connector metadata file",
|
||||
)
|
||||
|
||||
# Upload new files if any
|
||||
@@ -808,9 +807,9 @@ def update_connector_files(
|
||||
|
||||
# Validate that at least one file remains
|
||||
if not final_file_locations:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Cannot remove all files from connector. At least one file must remain.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot remove all files from connector. At least one file must remain.",
|
||||
)
|
||||
|
||||
# Merge and filter metadata (remove metadata for deleted files)
|
||||
@@ -853,8 +852,8 @@ def update_connector_files(
|
||||
|
||||
updated_connector = update_connector(connector_id, connector_base, db_session)
|
||||
if updated_connector is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR, "Failed to update connector configuration"
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to update connector configuration"
|
||||
)
|
||||
|
||||
# Trigger re-indexing for new files and pruning for removed files
|
||||
@@ -1542,7 +1541,7 @@ def create_connector_from_model(
|
||||
return connector_response
|
||||
except ValueError as e:
|
||||
logger.error(f"Error creating connector: {e}")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/admin/connector-with-mock-credential")
|
||||
@@ -1620,12 +1619,11 @@ def create_connector_with_mock_credential(
|
||||
return response
|
||||
|
||||
except ConnectorValidationError as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONNECTOR_VALIDATION_FAILED,
|
||||
"Connector validation error: " + str(e),
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Connector validation error: " + str(e)
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/admin/connector/{connector_id}", tags=PUBLIC_API_TAGS)
|
||||
@@ -1650,13 +1648,12 @@ def update_connector_from_model(
|
||||
)
|
||||
connector_base = connector_data.to_connector_base()
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
updated_connector = update_connector(connector_id, connector_base, db_session)
|
||||
if updated_connector is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONNECTOR_NOT_FOUND,
|
||||
f"Connector {connector_id} does not exist",
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Connector {connector_id} does not exist"
|
||||
)
|
||||
|
||||
return ConnectorSnapshot(
|
||||
@@ -1693,7 +1690,7 @@ def delete_connector_by_id(
|
||||
connector_id=connector_id,
|
||||
)
|
||||
except AssertionError:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Connector is not deletable")
|
||||
raise HTTPException(status_code=400, detail="Connector is not deletable")
|
||||
|
||||
|
||||
@router.post("/admin/connector/run-once", tags=PUBLIC_API_TAGS)
|
||||
@@ -1714,9 +1711,9 @@ def connector_run_once(
|
||||
run_info.connector_id, db_session
|
||||
)
|
||||
except ValueError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONNECTOR_NOT_FOUND,
|
||||
f"Connector by id {connector_id} does not exist.",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Connector by id {connector_id} does not exist.",
|
||||
)
|
||||
|
||||
if not specified_credential_ids:
|
||||
@@ -1725,15 +1722,15 @@ def connector_run_once(
|
||||
if set(specified_credential_ids).issubset(set(possible_credential_ids)):
|
||||
credential_ids = specified_credential_ids
|
||||
else:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Not all specified credentials are associated with connector",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Not all specified credentials are associated with connector",
|
||||
)
|
||||
|
||||
if not credential_ids:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Connector has no valid credentials, cannot create index attempts.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connector has no valid credentials, cannot create index attempts.",
|
||||
)
|
||||
try:
|
||||
num_triggers = trigger_indexing_for_cc_pair(
|
||||
@@ -1744,7 +1741,7 @@ def connector_run_once(
|
||||
db_session,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
logger.info("connector_run_once - running check_for_indexing")
|
||||
|
||||
@@ -1798,8 +1795,8 @@ def gmail_callback(
|
||||
) -> StatusResponse:
|
||||
credential_id_cookie = request.cookies.get(_GMAIL_CREDENTIAL_ID_COOKIE_NAME)
|
||||
if credential_id_cookie is None or not credential_id_cookie.isdigit():
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CSRF_FAILURE, "Request did not pass CSRF verification."
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Request did not pass CSRF verification."
|
||||
)
|
||||
credential_id = int(credential_id_cookie)
|
||||
verify_csrf(credential_id, callback.state)
|
||||
@@ -1812,8 +1809,8 @@ def gmail_callback(
|
||||
GoogleOAuthAuthenticationMethod.UPLOADED,
|
||||
)
|
||||
if credentials is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR, "Unable to fetch Gmail access tokens"
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Unable to fetch Gmail access tokens"
|
||||
)
|
||||
|
||||
return StatusResponse(success=True, message="Updated Gmail access tokens")
|
||||
@@ -1828,8 +1825,8 @@ def google_drive_callback(
|
||||
) -> StatusResponse:
|
||||
credential_id_cookie = request.cookies.get(_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME)
|
||||
if credential_id_cookie is None or not credential_id_cookie.isdigit():
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CSRF_FAILURE, "Request did not pass CSRF verification."
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Request did not pass CSRF verification."
|
||||
)
|
||||
credential_id = int(credential_id_cookie)
|
||||
verify_csrf(credential_id, callback.state)
|
||||
@@ -1843,9 +1840,8 @@ def google_drive_callback(
|
||||
GoogleOAuthAuthenticationMethod.UPLOADED,
|
||||
)
|
||||
if credentials is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Unable to fetch Google Drive access tokens",
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Unable to fetch Google Drive access tokens"
|
||||
)
|
||||
|
||||
return StatusResponse(success=True, message="Updated Google Drive access tokens")
|
||||
@@ -1885,9 +1881,8 @@ def get_connector_by_id(
|
||||
) -> ConnectorSnapshot | StatusResponse[int]:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
if connector is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONNECTOR_NOT_FOUND,
|
||||
f"Connector {connector_id} does not exist",
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Connector {connector_id} does not exist"
|
||||
)
|
||||
|
||||
return ConnectorSnapshot(
|
||||
@@ -1920,9 +1915,7 @@ def submit_connector_request(
|
||||
connector_name = request_data.connector_name.strip()
|
||||
|
||||
if not connector_name:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR, "Connector name cannot be empty"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Connector name cannot be empty")
|
||||
|
||||
# Get user identifier for telemetry
|
||||
user_email = user.email if user else None
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Form
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -27,8 +28,6 @@ from onyx.db.credentials import update_credential
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import DocumentSource
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.server.documents.models import CredentialDataUpdateRequest
|
||||
from onyx.server.documents.models import CredentialSnapshot
|
||||
@@ -177,18 +176,18 @@ def create_credential_with_private_key(
|
||||
try:
|
||||
credential_data = json.loads(credential_json)
|
||||
except json.JSONDecodeError as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid JSON in credential_json: {str(e)}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid JSON in credential_json: {str(e)}",
|
||||
)
|
||||
|
||||
private_key_processor: ProcessPrivateKeyFileProtocol | None = (
|
||||
FILE_TYPE_TO_FILE_PROCESSOR.get(PrivateKeyFileTypes(type_definition_key))
|
||||
)
|
||||
if private_key_processor is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Invalid type definition key for private key file",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid type definition key for private key file",
|
||||
)
|
||||
private_key_content: str = private_key_processor(uploaded_file)
|
||||
|
||||
@@ -252,9 +251,9 @@ def get_credential_by_id(
|
||||
get_editable=False,
|
||||
)
|
||||
if credential is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_NOT_FOUND,
|
||||
f"Credential {credential_id} does not exist or does not belong to user",
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Credential {credential_id} does not exist or does not belong to user",
|
||||
)
|
||||
|
||||
return CredentialSnapshot.from_credential_db_model(credential)
|
||||
@@ -276,9 +275,9 @@ def update_credential_data(
|
||||
)
|
||||
|
||||
if credential is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_NOT_FOUND,
|
||||
f"Credential {credential_id} does not exist or does not belong to user",
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Credential {credential_id} does not exist or does not belong to user",
|
||||
)
|
||||
|
||||
return CredentialSnapshot.from_credential_db_model(credential)
|
||||
@@ -298,18 +297,18 @@ def update_credential_private_key(
|
||||
try:
|
||||
credential_data = json.loads(credential_json)
|
||||
except json.JSONDecodeError as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid JSON in credential_json: {str(e)}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid JSON in credential_json: {str(e)}",
|
||||
)
|
||||
|
||||
private_key_processor: ProcessPrivateKeyFileProtocol | None = (
|
||||
FILE_TYPE_TO_FILE_PROCESSOR.get(PrivateKeyFileTypes(type_definition_key))
|
||||
)
|
||||
if private_key_processor is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Invalid type definition key for private key file",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid type definition key for private key file",
|
||||
)
|
||||
private_key_content: str = private_key_processor(uploaded_file)
|
||||
credential_data[field_key] = private_key_content
|
||||
@@ -323,9 +322,9 @@ def update_credential_private_key(
|
||||
)
|
||||
|
||||
if credential is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_NOT_FOUND,
|
||||
f"Credential {credential_id} does not exist or does not belong to user",
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Credential {credential_id} does not exist or does not belong to user",
|
||||
)
|
||||
|
||||
return CredentialSnapshot.from_credential_db_model(credential)
|
||||
@@ -342,9 +341,9 @@ def update_credential_from_model(
|
||||
credential_id, credential_data, user, db_session
|
||||
)
|
||||
if updated_credential is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_NOT_FOUND,
|
||||
f"Credential {credential_id} does not exist or does not belong to user",
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Credential {credential_id} does not exist or does not belong to user",
|
||||
)
|
||||
|
||||
# Get credential_json value - use masking for API responses
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -13,8 +14,6 @@ from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.prompt_utils import build_doc_context_str
|
||||
from onyx.server.documents.models import ChunkInfo
|
||||
@@ -44,7 +43,7 @@ def get_document_info(
|
||||
)
|
||||
|
||||
if not inference_chunks:
|
||||
raise OnyxError(OnyxErrorCode.DOCUMENT_NOT_FOUND, "Document not found")
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
contents = [chunk.content for chunk in inference_chunks]
|
||||
|
||||
@@ -96,7 +95,7 @@ def get_chunk_info(
|
||||
)
|
||||
|
||||
if not inference_chunks:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Chunk not found")
|
||||
raise HTTPException(status_code=404, detail="Chunk not found")
|
||||
|
||||
chunk_content = inference_chunks[0].content
|
||||
|
||||
|
||||
@@ -2,10 +2,9 @@ import base64
|
||||
from enum import Enum
|
||||
from typing import Protocol
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.documents.document_utils import validate_pkcs12_content
|
||||
|
||||
|
||||
@@ -32,9 +31,8 @@ def process_sharepoint_private_key_file(file: UploadFile) -> str:
|
||||
"""
|
||||
# First check file extension (basic filter)
|
||||
if not (file.filename and file.filename.lower().endswith(".pfx")):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Invalid file type. Only .pfx files are supported.",
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid file type. Only .pfx files are supported."
|
||||
)
|
||||
|
||||
# Read file content for validation and processing
|
||||
@@ -42,9 +40,9 @@ def process_sharepoint_private_key_file(file: UploadFile) -> str:
|
||||
|
||||
# Validate file content to prevent extension spoofing attacks
|
||||
if not validate_pkcs12_content(private_key_bytes):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Invalid file content. The uploaded file does not appear to be a valid PKCS#12 (.pfx) file.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid file content. The uploaded file does not appear to be a valid PKCS#12 (.pfx) file.",
|
||||
)
|
||||
|
||||
# Convert to base64 if validation passes
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel
|
||||
@@ -18,8 +19,6 @@ from onyx.connectors.interfaces import OAuthConnector
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -70,10 +69,12 @@ def _get_additional_kwargs(
|
||||
# validate
|
||||
connector_cls.AdditionalOauthKwargs(**additional_kwargs_dict)
|
||||
except ValidationError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid additional kwargs. Got {additional_kwargs_dict}, expected "
|
||||
f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Invalid additional kwargs. Got {additional_kwargs_dict}, expected "
|
||||
f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}"
|
||||
),
|
||||
)
|
||||
|
||||
return additional_kwargs_dict
|
||||
@@ -96,9 +97,7 @@ def oauth_authorize(
|
||||
oauth_connectors = _discover_oauth_connectors()
|
||||
|
||||
if source not in oauth_connectors:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR, f"Unknown OAuth source: {source}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")
|
||||
|
||||
connector_cls = oauth_connectors[source]
|
||||
base_url = WEB_DOMAIN
|
||||
@@ -148,9 +147,7 @@ def oauth_callback(
|
||||
oauth_connectors = _discover_oauth_connectors()
|
||||
|
||||
if source not in oauth_connectors:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR, f"Unknown OAuth source: {source}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")
|
||||
|
||||
connector_cls = oauth_connectors[source]
|
||||
|
||||
@@ -160,7 +157,7 @@ def oauth_callback(
|
||||
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
|
||||
)
|
||||
if not oauth_state_bytes:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Invalid OAuth state")
|
||||
raise HTTPException(status_code=400, detail="Invalid OAuth state")
|
||||
oauth_state = json.loads(oauth_state_bytes.decode("utf-8"))
|
||||
|
||||
desired_return_url = cast(str, oauth_state[_DESIRED_RETURN_URL_KEY])
|
||||
|
||||
@@ -6,11 +6,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.context.search.models import SavedSearchSettings
|
||||
from onyx.context.search.models import SearchSettingsCreationRequest
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import resync_cc_pair
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.index_attempt import expire_index_attempts
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
@@ -18,25 +15,20 @@ from onyx.db.llm import update_default_contextual_model
|
||||
from onyx.db.llm import update_no_default_contextual_rag_provider
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import create_search_settings
|
||||
from onyx.db.search_settings import delete_search_settings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_embedding_provider_from_provider_type
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.search_settings import update_search_settings_status
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_processing.unstructured import delete_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import get_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import update_unstructured_api_key
|
||||
from onyx.natural_language_processing.search_nlp_models import clean_model_name
|
||||
from onyx.server.manage.embedding.models import SearchSettingsDeleteRequest
|
||||
from onyx.server.manage.models import FullModelVersionResponse
|
||||
from onyx.server.models import IdReturn
|
||||
from onyx.server.utils_vector_db import require_vector_db
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import ALT_INDEX_SUFFIX
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
router = APIRouter(prefix="/search-settings")
|
||||
@@ -49,99 +41,110 @@ def set_new_search_settings(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session), # noqa: ARG001
|
||||
) -> IdReturn:
|
||||
"""Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
|
||||
Gives an error if the same model name is used as the current or secondary index
|
||||
"""
|
||||
Creates a new SearchSettings row and cancels the previous secondary indexing
|
||||
if any exists.
|
||||
"""
|
||||
if search_settings_new.index_name:
|
||||
logger.warning("Index name was specified by request, this is not suggested")
|
||||
|
||||
# Disallow contextual RAG for cloud deployments.
|
||||
if MULTI_TENANT and search_settings_new.enable_contextual_rag:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Contextual RAG disabled in Onyx Cloud",
|
||||
)
|
||||
|
||||
# Validate cloud provider exists or create new LiteLLM provider.
|
||||
if search_settings_new.provider_type is not None:
|
||||
cloud_provider = get_embedding_provider_from_provider_type(
|
||||
db_session, provider_type=search_settings_new.provider_type
|
||||
)
|
||||
|
||||
if cloud_provider is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
|
||||
)
|
||||
|
||||
validate_contextual_rag_model(
|
||||
provider_name=search_settings_new.contextual_rag_llm_provider,
|
||||
model_name=search_settings_new.contextual_rag_llm_name,
|
||||
db_session=db_session,
|
||||
# TODO(andrei): Re-enable.
|
||||
# NOTE Enable integration external dependency tests in test_search_settings.py
|
||||
# when this is reenabled. They are currently skipped
|
||||
logger.error("Setting new search settings is temporarily disabled.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Setting new search settings is temporarily disabled.",
|
||||
)
|
||||
# if search_settings_new.index_name:
|
||||
# logger.warning("Index name was specified by request, this is not suggested")
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
# # Disallow contextual RAG for cloud deployments
|
||||
# if MULTI_TENANT and search_settings_new.enable_contextual_rag:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# detail="Contextual RAG disabled in Onyx Cloud",
|
||||
# )
|
||||
|
||||
if search_settings_new.index_name is None:
|
||||
# We define index name here.
|
||||
index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
|
||||
if (
|
||||
search_settings_new.model_name == search_settings.model_name
|
||||
and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
|
||||
):
|
||||
index_name += ALT_INDEX_SUFFIX
|
||||
search_values = search_settings_new.model_dump()
|
||||
search_values["index_name"] = index_name
|
||||
new_search_settings_request = SavedSearchSettings(**search_values)
|
||||
else:
|
||||
new_search_settings_request = SavedSearchSettings(
|
||||
**search_settings_new.model_dump()
|
||||
)
|
||||
# # Validate cloud provider exists or create new LiteLLM provider
|
||||
# if search_settings_new.provider_type is not None:
|
||||
# cloud_provider = get_embedding_provider_from_provider_type(
|
||||
# db_session, provider_type=search_settings_new.provider_type
|
||||
# )
|
||||
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
# if cloud_provider is None:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
|
||||
# )
|
||||
|
||||
if secondary_search_settings:
|
||||
# Cancel any background indexing jobs.
|
||||
expire_index_attempts(
|
||||
search_settings_id=secondary_search_settings.id, db_session=db_session
|
||||
)
|
||||
# validate_contextual_rag_model(
|
||||
# provider_name=search_settings_new.contextual_rag_llm_provider,
|
||||
# model_name=search_settings_new.contextual_rag_llm_name,
|
||||
# db_session=db_session,
|
||||
# )
|
||||
|
||||
# Mark previous model as a past model directly.
|
||||
update_search_settings_status(
|
||||
search_settings=secondary_search_settings,
|
||||
new_status=IndexModelStatus.PAST,
|
||||
db_session=db_session,
|
||||
)
|
||||
# search_settings = get_current_search_settings(db_session)
|
||||
|
||||
new_search_settings = create_search_settings(
|
||||
search_settings=new_search_settings_request, db_session=db_session
|
||||
)
|
||||
# if search_settings_new.index_name is None:
|
||||
# # We define index name here
|
||||
# index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
|
||||
# if (
|
||||
# search_settings_new.model_name == search_settings.model_name
|
||||
# and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
|
||||
# ):
|
||||
# index_name += ALT_INDEX_SUFFIX
|
||||
# search_values = search_settings_new.model_dump()
|
||||
# search_values["index_name"] = index_name
|
||||
# new_search_settings_request = SavedSearchSettings(**search_values)
|
||||
# else:
|
||||
# new_search_settings_request = SavedSearchSettings(
|
||||
# **search_settings_new.model_dump()
|
||||
# )
|
||||
|
||||
# Ensure the document indices have the new index immediately.
|
||||
document_indices = get_all_document_indices(search_settings, new_search_settings)
|
||||
for document_index in document_indices:
|
||||
document_index.ensure_indices_exist(
|
||||
primary_embedding_dim=search_settings.final_embedding_dim,
|
||||
primary_embedding_precision=search_settings.embedding_precision,
|
||||
secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
|
||||
secondary_index_embedding_precision=new_search_settings.embedding_precision,
|
||||
)
|
||||
# secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
# Pause index attempts for the currently in-use index to preserve resources.
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
expire_index_attempts(
|
||||
search_settings_id=search_settings.id, db_session=db_session
|
||||
)
|
||||
for cc_pair in get_connector_credential_pairs(db_session):
|
||||
resync_cc_pair(
|
||||
cc_pair=cc_pair,
|
||||
search_settings_id=new_search_settings.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
# if secondary_search_settings:
|
||||
# # Cancel any background indexing jobs
|
||||
# expire_index_attempts(
|
||||
# search_settings_id=secondary_search_settings.id, db_session=db_session
|
||||
# )
|
||||
|
||||
db_session.commit()
|
||||
return IdReturn(id=new_search_settings.id)
|
||||
# # Mark previous model as a past model directly
|
||||
# update_search_settings_status(
|
||||
# search_settings=secondary_search_settings,
|
||||
# new_status=IndexModelStatus.PAST,
|
||||
# db_session=db_session,
|
||||
# )
|
||||
|
||||
# new_search_settings = create_search_settings(
|
||||
# search_settings=new_search_settings_request, db_session=db_session
|
||||
# )
|
||||
|
||||
# # Ensure Vespa has the new index immediately
|
||||
# get_multipass_config(search_settings)
|
||||
# get_multipass_config(new_search_settings)
|
||||
# document_index = get_default_document_index(
|
||||
# search_settings, new_search_settings, db_session
|
||||
# )
|
||||
|
||||
# document_index.ensure_indices_exist(
|
||||
# primary_embedding_dim=search_settings.final_embedding_dim,
|
||||
# primary_embedding_precision=search_settings.embedding_precision,
|
||||
# secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
|
||||
# secondary_index_embedding_precision=new_search_settings.embedding_precision,
|
||||
# )
|
||||
|
||||
# # Pause index attempts for the currently in use index to preserve resources
|
||||
# if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
# expire_index_attempts(
|
||||
# search_settings_id=search_settings.id, db_session=db_session
|
||||
# )
|
||||
# for cc_pair in get_connector_credential_pairs(db_session):
|
||||
# resync_cc_pair(
|
||||
# cc_pair=cc_pair,
|
||||
# search_settings_id=new_search_settings.id,
|
||||
# db_session=db_session,
|
||||
# )
|
||||
|
||||
# db_session.commit()
|
||||
# return IdReturn(id=new_search_settings.id)
|
||||
|
||||
|
||||
@router.post("/cancel-new-embedding", dependencies=[Depends(require_vector_db)])
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
@@ -60,6 +61,7 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.file_processing.extract_file_text import docx_to_txt_filename
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -810,6 +812,18 @@ def fetch_chat_file(
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
original_file_name = file_record.display_name
|
||||
if file_record.file_type.startswith(
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
# Check if a converted text file exists for .docx files
|
||||
txt_file_name = docx_to_txt_filename(original_file_name)
|
||||
txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name)
|
||||
txt_file_record = file_store.read_file_record(txt_file_id)
|
||||
if txt_file_record:
|
||||
file_record = txt_file_record
|
||||
file_id = txt_file_id
|
||||
|
||||
media_type = file_record.file_type
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ class StreamingType(Enum):
|
||||
REASONING_DONE = "reasoning_done"
|
||||
CITATION_INFO = "citation_info"
|
||||
TOOL_CALL_DEBUG = "tool_call_debug"
|
||||
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta"
|
||||
|
||||
MEMORY_TOOL_START = "memory_tool_start"
|
||||
MEMORY_TOOL_DELTA = "memory_tool_delta"
|
||||
@@ -259,6 +260,16 @@ class CustomToolDelta(BaseObj):
|
||||
file_ids: list[str] | None = None
|
||||
|
||||
|
||||
class ToolCallArgumentDelta(BaseObj):
|
||||
type: Literal["tool_call_argument_delta"] = (
|
||||
StreamingType.TOOL_CALL_ARGUMENT_DELTA.value
|
||||
)
|
||||
|
||||
tool_type: str
|
||||
tool_id: str
|
||||
argument_deltas: dict[str, Any]
|
||||
|
||||
|
||||
################################################
|
||||
# File Reader Packets
|
||||
################################################
|
||||
@@ -379,6 +390,7 @@ PacketObj = Union[
|
||||
# Citation Packets
|
||||
CitationInfo,
|
||||
ToolCallDebug,
|
||||
ToolCallArgumentDelta,
|
||||
# Deep Research Packets
|
||||
DeepResearchPlanStart,
|
||||
DeepResearchPlanDelta,
|
||||
|
||||
@@ -60,11 +60,9 @@ class Settings(BaseModel):
|
||||
deep_research_enabled: bool | None = None
|
||||
search_ui_enabled: bool | None = None
|
||||
|
||||
# Whether EE features are unlocked for use.
|
||||
# Depends on license status: True when the user has a valid license
|
||||
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
|
||||
# or the license is expired (GATED_ACCESS).
|
||||
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
|
||||
# Enterprise features flag - set by license enforcement at runtime
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
|
||||
ee_features_enabled: bool = False
|
||||
|
||||
temperature_override_enabled: bool | None = False
|
||||
|
||||
@@ -56,3 +56,23 @@ def get_built_in_tool_ids() -> list[str]:
|
||||
|
||||
def get_built_in_tool_by_id(in_code_tool_id: str) -> Type[BUILT_IN_TOOL_TYPES]:
|
||||
return BUILT_IN_TOOL_MAP[in_code_tool_id]
|
||||
|
||||
|
||||
def _build_tool_name_to_class() -> dict[str, Type[BUILT_IN_TOOL_TYPES]]:
|
||||
"""Build a mapping from LLM-facing tool name to tool class."""
|
||||
result: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {}
|
||||
for cls in BUILT_IN_TOOL_MAP.values():
|
||||
name_attr = cls.__dict__.get("name")
|
||||
if isinstance(name_attr, property) and name_attr.fget is not None:
|
||||
tool_name = name_attr.fget(cls)
|
||||
elif isinstance(name_attr, str):
|
||||
tool_name = name_attr
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Built-in tool {cls.__name__} must define a valid LLM-facing tool name"
|
||||
)
|
||||
result[tool_name] = cls
|
||||
return result
|
||||
|
||||
|
||||
TOOL_NAME_TO_CLASS: dict[str, Type[BUILT_IN_TOOL_TYPES]] = _build_tool_name_to_class()
|
||||
|
||||
@@ -92,3 +92,7 @@ class Tool(abc.ABC, Generic[TOverride]):
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def do_emit_argument_deltas(cls) -> bool:
|
||||
return False
|
||||
|
||||
@@ -376,3 +376,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
rich_response=None,
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def do_emit_argument_deltas(cls) -> bool:
|
||||
return True
|
||||
|
||||
@@ -1,20 +1,10 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
COMPOSE_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.yml"
|
||||
COMPOSE_DEV_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.dev.yml"
|
||||
|
||||
stop_and_remove_containers() {
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled stop opensearch 2>/dev/null || true
|
||||
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled rm -f opensearch 2>/dev/null || true
|
||||
}
|
||||
|
||||
cleanup() {
|
||||
echo "Error occurred. Cleaning up..."
|
||||
stop_and_remove_containers
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Trap errors and output a message, then cleanup
|
||||
@@ -22,26 +12,16 @@ trap 'echo "Error occurred on line $LINENO. Exiting script." >&2; cleanup' ERR
|
||||
|
||||
# Usage of the script with optional volume arguments
|
||||
# ./restart_containers.sh [vespa_volume] [postgres_volume] [redis_volume]
|
||||
# [minio_volume] [--keep-opensearch-data]
|
||||
|
||||
KEEP_OPENSEARCH_DATA=false
|
||||
POSITIONAL_ARGS=()
|
||||
for arg in "$@"; do
|
||||
if [[ "$arg" == "--keep-opensearch-data" ]]; then
|
||||
KEEP_OPENSEARCH_DATA=true
|
||||
else
|
||||
POSITIONAL_ARGS+=("$arg")
|
||||
fi
|
||||
done
|
||||
|
||||
VESPA_VOLUME=${POSITIONAL_ARGS[0]:-""}
|
||||
POSTGRES_VOLUME=${POSITIONAL_ARGS[1]:-""}
|
||||
REDIS_VOLUME=${POSITIONAL_ARGS[2]:-""}
|
||||
MINIO_VOLUME=${POSITIONAL_ARGS[3]:-""}
|
||||
VESPA_VOLUME=${1:-""} # Default is empty if not provided
|
||||
POSTGRES_VOLUME=${2:-""} # Default is empty if not provided
|
||||
REDIS_VOLUME=${3:-""} # Default is empty if not provided
|
||||
MINIO_VOLUME=${4:-""} # Default is empty if not provided
|
||||
|
||||
# Stop and remove the existing containers
|
||||
echo "Stopping and removing existing containers..."
|
||||
stop_and_remove_containers
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
|
||||
# Start the PostgreSQL container with optional volume
|
||||
echo "Starting PostgreSQL container..."
|
||||
@@ -59,29 +39,6 @@ else
|
||||
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8
|
||||
fi
|
||||
|
||||
# If OPENSEARCH_ADMIN_PASSWORD is not already set, try loading it from
|
||||
# .vscode/.env so existing dev setups that stored it there aren't silently
|
||||
# broken.
|
||||
VSCODE_ENV="$SCRIPT_DIR/../../.vscode/.env"
|
||||
if [[ -z "${OPENSEARCH_ADMIN_PASSWORD:-}" && -f "$VSCODE_ENV" ]]; then
|
||||
set -a
|
||||
# shellcheck source=/dev/null
|
||||
source "$VSCODE_ENV"
|
||||
set +a
|
||||
fi
|
||||
|
||||
# Start the OpenSearch container using the same service from docker-compose that
|
||||
# our users use, setting OPENSEARCH_INITIAL_ADMIN_PASSWORD from the env's
|
||||
# OPENSEARCH_ADMIN_PASSWORD if it exists, else defaulting to StrongPassword123!.
|
||||
# Pass --keep-opensearch-data to preserve the opensearch-data volume across
|
||||
# restarts, else the volume is deleted so the container starts fresh.
|
||||
if [[ "$KEEP_OPENSEARCH_DATA" == "false" ]]; then
|
||||
echo "Deleting opensearch-data volume..."
|
||||
docker volume rm onyx_opensearch-data 2>/dev/null || true
|
||||
fi
|
||||
echo "Starting OpenSearch container..."
|
||||
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled up --force-recreate -d opensearch
|
||||
|
||||
# Start the Redis container with optional volume
|
||||
echo "Starting Redis container..."
|
||||
if [[ -n "$REDIS_VOLUME" ]]; then
|
||||
@@ -103,6 +60,7 @@ echo "Starting Code Interpreter container..."
|
||||
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
|
||||
|
||||
# Ensure alembic runs in the correct directory (backend/)
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
cd "$PARENT_DIR"
|
||||
|
||||
|
||||
10
backend/scripts/restart_opensearch_container.sh
Normal file
10
backend/scripts/restart_opensearch_container.sh
Normal file
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
# We get OPENSEARCH_ADMIN_PASSWORD from the repo .env file.
|
||||
source "$(dirname "$0")/../../.vscode/.env"
|
||||
|
||||
cd "$(dirname "$0")/../../deployment/docker_compose"
|
||||
|
||||
# Start OpenSearch.
|
||||
echo "Forcefully starting fresh OpenSearch container..."
|
||||
docker compose -f docker-compose.opensearch.yml up --force-recreate -d opensearch
|
||||
@@ -5,8 +5,6 @@ Verifies that:
|
||||
1. extract_ids_from_runnable_connector correctly separates hierarchy nodes from doc IDs
|
||||
2. Extracted hierarchy nodes are correctly upserted to Postgres via upsert_hierarchy_nodes_batch
|
||||
3. Upserting is idempotent (running twice doesn't duplicate nodes)
|
||||
4. Document-to-hierarchy-node linkage is updated during pruning
|
||||
5. link_hierarchy_nodes_to_documents links nodes that are also documents
|
||||
|
||||
Uses a mock SlimConnectorWithPermSync that yields known hierarchy nodes and slim documents,
|
||||
combined with a real PostgreSQL database for verifying persistence.
|
||||
@@ -29,13 +27,9 @@ from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.hierarchy import ensure_source_node_exists
|
||||
from onyx.db.hierarchy import get_all_hierarchy_nodes_for_source
|
||||
from onyx.db.hierarchy import get_hierarchy_node_by_raw_id
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import Document as DbDocument
|
||||
from onyx.db.models import HierarchyNode as DBHierarchyNode
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.kg.models import KGStage
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -95,18 +89,8 @@ def _make_hierarchy_nodes() -> list[PydanticHierarchyNode]:
|
||||
]
|
||||
|
||||
|
||||
DOC_PARENT_MAP = {
|
||||
"msg-001": CHANNEL_A_ID,
|
||||
"msg-002": CHANNEL_A_ID,
|
||||
"msg-003": CHANNEL_B_ID,
|
||||
}
|
||||
|
||||
|
||||
def _make_slim_docs() -> list[SlimDocument | PydanticHierarchyNode]:
|
||||
return [
|
||||
SlimDocument(id=doc_id, parent_hierarchy_raw_node_id=DOC_PARENT_MAP.get(doc_id))
|
||||
for doc_id in SLIM_DOC_IDS
|
||||
]
|
||||
return [SlimDocument(id=doc_id) for doc_id in SLIM_DOC_IDS]
|
||||
|
||||
|
||||
class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
|
||||
@@ -142,31 +126,14 @@ class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _cleanup_test_data(db_session: Session) -> None:
|
||||
"""Remove all test hierarchy nodes and documents to isolate tests."""
|
||||
for doc_id in SLIM_DOC_IDS:
|
||||
db_session.query(DbDocument).filter(DbDocument.id == doc_id).delete()
|
||||
def _cleanup_test_hierarchy_nodes(db_session: Session) -> None:
|
||||
"""Remove all hierarchy nodes for TEST_SOURCE to isolate tests."""
|
||||
db_session.query(DBHierarchyNode).filter(
|
||||
DBHierarchyNode.source == TEST_SOURCE
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _create_test_documents(db_session: Session) -> list[DbDocument]:
|
||||
"""Insert minimal Document rows for our test doc IDs."""
|
||||
docs = []
|
||||
for doc_id in SLIM_DOC_IDS:
|
||||
doc = DbDocument(
|
||||
id=doc_id,
|
||||
semantic_id=doc_id,
|
||||
kg_stage=KGStage.NOT_STARTED,
|
||||
)
|
||||
db_session.add(doc)
|
||||
docs.append(doc)
|
||||
db_session.commit()
|
||||
return docs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -180,14 +147,14 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
|
||||
result = extract_ids_from_runnable_connector(connector, callback=None)
|
||||
|
||||
# Doc IDs should include both slim doc IDs and hierarchy node raw_node_ids
|
||||
# (hierarchy node IDs are added to raw_id_to_parent so they aren't pruned)
|
||||
# (hierarchy node IDs are added to doc_ids so they aren't pruned)
|
||||
expected_ids = {
|
||||
CHANNEL_A_ID,
|
||||
CHANNEL_B_ID,
|
||||
CHANNEL_C_ID,
|
||||
*SLIM_DOC_IDS,
|
||||
}
|
||||
assert result.raw_id_to_parent.keys() == expected_ids
|
||||
assert result.doc_ids == expected_ids
|
||||
|
||||
# Hierarchy nodes should be the 3 channels
|
||||
assert len(result.hierarchy_nodes) == 3
|
||||
@@ -198,7 +165,7 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
|
||||
def test_pruning_upserts_hierarchy_nodes_to_db(db_session: Session) -> None:
|
||||
"""Full flow: extract hierarchy nodes from mock connector, upsert to Postgres,
|
||||
then verify the DB state (node count, parent relationships, permissions)."""
|
||||
_cleanup_test_data(db_session)
|
||||
_cleanup_test_hierarchy_nodes(db_session)
|
||||
|
||||
# Step 1: ensure the SOURCE node exists (mirrors what the pruning task does)
|
||||
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
@@ -263,7 +230,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector(
|
||||
) -> None:
|
||||
"""When the connector's access type is PUBLIC, all hierarchy nodes must be
|
||||
marked is_public=True regardless of their external_access settings."""
|
||||
_cleanup_test_data(db_session)
|
||||
_cleanup_test_hierarchy_nodes(db_session)
|
||||
|
||||
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
|
||||
@@ -290,7 +257,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector(
|
||||
def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None:
|
||||
"""Upserting the same hierarchy nodes twice must not create duplicates.
|
||||
The second call should update existing rows in place."""
|
||||
_cleanup_test_data(db_session)
|
||||
_cleanup_test_hierarchy_nodes(db_session)
|
||||
|
||||
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
|
||||
@@ -328,7 +295,7 @@ def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None:
|
||||
|
||||
def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> None:
|
||||
"""Upserting a hierarchy node with changed fields should update the existing row."""
|
||||
_cleanup_test_data(db_session)
|
||||
_cleanup_test_hierarchy_nodes(db_session)
|
||||
|
||||
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
|
||||
@@ -375,193 +342,3 @@ def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> No
|
||||
assert db_node.is_public is True
|
||||
assert db_node.external_user_emails is not None
|
||||
assert set(db_node.external_user_emails) == {"new_user@example.com"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document-to-hierarchy-node linkage tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_extraction_preserves_parent_hierarchy_raw_node_id(
|
||||
db_session: Session, # noqa: ARG001
|
||||
) -> None:
|
||||
"""extract_ids_from_runnable_connector should carry the
|
||||
parent_hierarchy_raw_node_id from SlimDocument into the raw_id_to_parent dict."""
|
||||
connector = MockSlimConnectorWithPermSync()
|
||||
result = extract_ids_from_runnable_connector(connector, callback=None)
|
||||
|
||||
for doc_id, expected_parent in DOC_PARENT_MAP.items():
|
||||
assert (
|
||||
result.raw_id_to_parent[doc_id] == expected_parent
|
||||
), f"raw_id_to_parent[{doc_id}] should be {expected_parent}"
|
||||
|
||||
# Hierarchy node entries have None parent (they aren't documents)
|
||||
for channel_id in [CHANNEL_A_ID, CHANNEL_B_ID, CHANNEL_C_ID]:
|
||||
assert result.raw_id_to_parent[channel_id] is None
|
||||
|
||||
|
||||
def test_update_document_parent_hierarchy_nodes(db_session: Session) -> None:
|
||||
"""update_document_parent_hierarchy_nodes should set
|
||||
Document.parent_hierarchy_node_id for each document in the mapping."""
|
||||
_cleanup_test_data(db_session)
|
||||
|
||||
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
upserted = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=_make_hierarchy_nodes(),
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
node_id_by_raw = {n.raw_node_id: n.id for n in upserted}
|
||||
|
||||
# Create documents with no parent set
|
||||
docs = _create_test_documents(db_session)
|
||||
for doc in docs:
|
||||
assert doc.parent_hierarchy_node_id is None
|
||||
|
||||
# Build resolved map (same logic as _resolve_and_update_document_parents)
|
||||
resolved: dict[str, int | None] = {}
|
||||
for doc_id, raw_parent in DOC_PARENT_MAP.items():
|
||||
resolved[doc_id] = node_id_by_raw.get(raw_parent, source_node.id)
|
||||
|
||||
updated = update_document_parent_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
doc_parent_map=resolved,
|
||||
commit=True,
|
||||
)
|
||||
assert updated == len(SLIM_DOC_IDS)
|
||||
|
||||
# Verify each document now points to the correct hierarchy node
|
||||
db_session.expire_all()
|
||||
for doc_id, raw_parent in DOC_PARENT_MAP.items():
|
||||
tmp_doc = db_session.get(DbDocument, doc_id)
|
||||
assert tmp_doc is not None
|
||||
doc = tmp_doc
|
||||
expected_node_id = node_id_by_raw[raw_parent]
|
||||
assert (
|
||||
doc.parent_hierarchy_node_id == expected_node_id
|
||||
), f"Document {doc_id} should point to node for {raw_parent}"
|
||||
|
||||
|
||||
def test_update_document_parent_is_idempotent(db_session: Session) -> None:
|
||||
"""Running update_document_parent_hierarchy_nodes a second time with the
|
||||
same mapping should update zero rows."""
|
||||
_cleanup_test_data(db_session)
|
||||
|
||||
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
upserted = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=_make_hierarchy_nodes(),
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
node_id_by_raw = {n.raw_node_id: n.id for n in upserted}
|
||||
_create_test_documents(db_session)
|
||||
|
||||
resolved: dict[str, int | None] = {
|
||||
doc_id: node_id_by_raw[raw_parent]
|
||||
for doc_id, raw_parent in DOC_PARENT_MAP.items()
|
||||
}
|
||||
|
||||
first_updated = update_document_parent_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
doc_parent_map=resolved,
|
||||
commit=True,
|
||||
)
|
||||
assert first_updated == len(SLIM_DOC_IDS)
|
||||
|
||||
second_updated = update_document_parent_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
doc_parent_map=resolved,
|
||||
commit=True,
|
||||
)
|
||||
assert second_updated == 0
|
||||
|
||||
|
||||
def test_link_hierarchy_nodes_to_documents_for_confluence(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""For sources in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS (e.g. Confluence),
|
||||
link_hierarchy_nodes_to_documents should set HierarchyNode.document_id
|
||||
when a hierarchy node's raw_node_id matches a document ID."""
|
||||
_cleanup_test_data(db_session)
|
||||
confluence_source = DocumentSource.CONFLUENCE
|
||||
|
||||
# Clean up any existing Confluence hierarchy nodes
|
||||
db_session.query(DBHierarchyNode).filter(
|
||||
DBHierarchyNode.source == confluence_source
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
ensure_source_node_exists(db_session, confluence_source, commit=True)
|
||||
|
||||
# Create a hierarchy node whose raw_node_id matches a document ID
|
||||
page_node_id = "confluence-page-123"
|
||||
nodes = [
|
||||
PydanticHierarchyNode(
|
||||
raw_node_id=page_node_id,
|
||||
raw_parent_id=None,
|
||||
display_name="Test Page",
|
||||
link="https://wiki.example.com/page/123",
|
||||
node_type=HierarchyNodeType.PAGE,
|
||||
),
|
||||
]
|
||||
upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=nodes,
|
||||
source=confluence_source,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
|
||||
# Verify the node exists but has no document_id yet
|
||||
db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source)
|
||||
assert db_node is not None
|
||||
assert db_node.document_id is None
|
||||
|
||||
# Create a document with the same ID as the hierarchy node
|
||||
doc = DbDocument(
|
||||
id=page_node_id,
|
||||
semantic_id="Test Page",
|
||||
kg_stage=KGStage.NOT_STARTED,
|
||||
)
|
||||
db_session.add(doc)
|
||||
db_session.commit()
|
||||
|
||||
# Link nodes to documents
|
||||
linked = link_hierarchy_nodes_to_documents(
|
||||
db_session=db_session,
|
||||
document_ids=[page_node_id],
|
||||
source=confluence_source,
|
||||
commit=True,
|
||||
)
|
||||
assert linked == 1
|
||||
|
||||
# Verify the hierarchy node now has document_id set
|
||||
db_session.expire_all()
|
||||
db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source)
|
||||
assert db_node is not None
|
||||
assert db_node.document_id == page_node_id
|
||||
|
||||
# Cleanup
|
||||
db_session.query(DbDocument).filter(DbDocument.id == page_node_id).delete()
|
||||
db_session.query(DBHierarchyNode).filter(
|
||||
DBHierarchyNode.source == confluence_source
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def test_link_hierarchy_nodes_skips_non_hierarchy_sources(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""link_hierarchy_nodes_to_documents should return 0 for sources that
|
||||
don't support hierarchy-node-as-document (e.g. Slack, Google Drive)."""
|
||||
linked = link_hierarchy_nodes_to_documents(
|
||||
db_session=db_session,
|
||||
document_ids=SLIM_DOC_IDS,
|
||||
source=TEST_SOURCE, # Slack — not in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS
|
||||
commit=False,
|
||||
)
|
||||
assert linked == 0
|
||||
|
||||
@@ -11,7 +11,6 @@ from onyx.context.search.models import SavedSearchSettings
|
||||
from onyx.context.search.models import SearchSettingsCreationRequest
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.llm import fetch_default_contextual_rag_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_contextual_model
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import IndexModelStatus
|
||||
@@ -38,8 +37,6 @@ def _create_llm_provider_and_model(
|
||||
model_name: str,
|
||||
) -> None:
|
||||
"""Insert an LLM provider with a single visible model configuration."""
|
||||
if fetch_existing_llm_provider(name=provider_name, db_session=db_session):
|
||||
return
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
@@ -149,8 +146,8 @@ def baseline_search_settings(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
@patch("onyx.db.swap_index.get_all_document_indices")
|
||||
@patch("onyx.server.manage.search_settings.get_all_document_indices")
|
||||
@patch("onyx.server.manage.search_settings.get_default_document_index")
|
||||
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
|
||||
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
|
||||
@@ -158,7 +155,6 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
|
||||
mock_index_handler: MagicMock,
|
||||
mock_get_llm: MagicMock,
|
||||
mock_get_doc_index: MagicMock, # noqa: ARG001
|
||||
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
|
||||
mock_get_all_doc_indices: MagicMock,
|
||||
baseline_search_settings: None, # noqa: ARG001
|
||||
db_session: Session,
|
||||
@@ -200,8 +196,8 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
@patch("onyx.db.swap_index.get_all_document_indices")
|
||||
@patch("onyx.server.manage.search_settings.get_all_document_indices")
|
||||
@patch("onyx.server.manage.search_settings.get_default_document_index")
|
||||
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
|
||||
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
|
||||
@@ -209,7 +205,6 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
|
||||
mock_index_handler: MagicMock,
|
||||
mock_get_llm: MagicMock,
|
||||
mock_get_doc_index: MagicMock, # noqa: ARG001
|
||||
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
|
||||
mock_get_all_doc_indices: MagicMock,
|
||||
baseline_search_settings: None, # noqa: ARG001
|
||||
db_session: Session,
|
||||
@@ -271,7 +266,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.server.manage.search_settings.get_all_document_indices")
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
@patch("onyx.server.manage.search_settings.get_default_document_index")
|
||||
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
|
||||
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
|
||||
@@ -279,7 +274,6 @@ def test_indexing_pipeline_skips_llm_when_contextual_rag_disabled(
|
||||
mock_index_handler: MagicMock,
|
||||
mock_get_llm: MagicMock,
|
||||
mock_get_doc_index: MagicMock, # noqa: ARG001
|
||||
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
|
||||
baseline_search_settings: None, # noqa: ARG001
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
|
||||
@@ -950,6 +950,7 @@ from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from tests.external_dependency_unit.answer.stream_test_builder import StreamTestBuilder
|
||||
from tests.external_dependency_unit.answer.stream_test_utils import create_chat_session
|
||||
@@ -1294,9 +1295,19 @@ def test_code_interpreter_replay_packets_include_code_and_output(
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=PythonToolStart(code=code),
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type="python",
|
||||
tool_id="call_replay_test",
|
||||
argument_deltas={"code": code},
|
||||
),
|
||||
),
|
||||
forward=2,
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=PythonToolStart(code=code),
|
||||
),
|
||||
forward=False,
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
@@ -364,6 +365,7 @@ def test_update_contextual_rag_missing_model_name(
|
||||
assert "Provider name and model name are required" in response.json()["detail"]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
def test_set_new_search_settings_with_contextual_rag(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
@@ -392,6 +394,7 @@ def test_set_new_search_settings_with_contextual_rag(
|
||||
_cancel_new_embedding(admin_user)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
def test_set_new_search_settings_without_contextual_rag(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
@@ -416,6 +419,7 @@ def test_set_new_search_settings_without_contextual_rag(
|
||||
_cancel_new_embedding(admin_user)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
def test_set_new_then_update_inference_settings(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
@@ -453,6 +457,7 @@ def test_set_new_then_update_inference_settings(
|
||||
_cancel_new_embedding(admin_user)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
def test_set_new_search_settings_replaces_previous_secondary(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
|
||||
@@ -281,10 +281,9 @@ class TestApplyLicenseStatusToSettings:
|
||||
}
|
||||
|
||||
|
||||
class TestSettingsDefaults:
|
||||
"""Verify Settings model defaults for CE deployments."""
|
||||
class TestSettingsDefaultEEDisabled:
|
||||
"""Verify the Settings model defaults ee_features_enabled to False."""
|
||||
|
||||
def test_default_ee_features_disabled(self) -> None:
|
||||
"""CE default: ee_features_enabled is False."""
|
||||
settings = Settings()
|
||||
assert settings.ee_features_enabled is False
|
||||
|
||||
584
backend/tests/unit/onyx/chat/test_argument_delta_streaming.py
Normal file
584
backend/tests/unit/onyx/chat/test_argument_delta_streaming.py
Normal file
@@ -0,0 +1,584 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
|
||||
|
||||
def _make_tool_call_delta(
|
||||
index: int = 0,
|
||||
tool_id: str | None = None,
|
||||
name: str | None = None,
|
||||
arguments: str | None = None,
|
||||
function_is_none: bool = False,
|
||||
) -> MagicMock:
|
||||
"""Create a mock tool_call_delta matching the LiteLLM streaming shape."""
|
||||
delta = MagicMock()
|
||||
delta.index = index
|
||||
delta.id = tool_id
|
||||
if function_is_none:
|
||||
delta.function = None
|
||||
else:
|
||||
delta.function = MagicMock()
|
||||
delta.function.name = name
|
||||
delta.function.arguments = arguments
|
||||
return delta
|
||||
|
||||
|
||||
def _make_placement() -> Placement:
|
||||
return Placement(turn_index=0, tab_index=0)
|
||||
|
||||
|
||||
def _mock_tool_class(emit: bool = True) -> MagicMock:
|
||||
cls = MagicMock()
|
||||
cls.do_emit_argument_deltas.return_value = emit
|
||||
return cls
|
||||
|
||||
|
||||
def _collect(
|
||||
tc_map: dict[int, dict[str, Any]],
|
||||
delta: MagicMock,
|
||||
placement: Placement | None = None,
|
||||
scan_offsets: dict[int, int] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Run maybe_emit_argument_delta and return the yielded packets."""
|
||||
return list(
|
||||
maybe_emit_argument_delta(
|
||||
tc_map,
|
||||
delta,
|
||||
placement or _make_placement(),
|
||||
scan_offsets if scan_offsets is not None else {},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _stream_fragments(
|
||||
fragments: list[str],
|
||||
tc_map: dict[int, dict[str, Any]],
|
||||
placement: Placement | None = None,
|
||||
) -> list[str]:
|
||||
"""Feed fragments into maybe_emit_argument_delta one by one, returning
|
||||
all emitted content values concatenated per-key as a flat list."""
|
||||
pl = placement or _make_placement()
|
||||
scan_offsets: dict[int, int] = {}
|
||||
emitted: list[str] = []
|
||||
for frag in fragments:
|
||||
tc_map[0]["arguments"] += frag
|
||||
delta = _make_tool_call_delta(arguments=frag)
|
||||
for packet in maybe_emit_argument_delta(
|
||||
tc_map, delta, pl, scan_offsets=scan_offsets
|
||||
):
|
||||
obj = packet.obj
|
||||
assert isinstance(obj, ToolCallArgumentDelta)
|
||||
for value in obj.argument_deltas.values():
|
||||
emitted.append(value)
|
||||
return emitted
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaGuards:
|
||||
"""Tests for conditions that cause no packet to be emitted."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_tool_does_not_opt_in(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""Tools that return False from do_emit_argument_deltas emit nothing."""
|
||||
mock_get_tool.return_value = _mock_tool_class(emit=False)
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_tool_class_unknown(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
mock_get_tool.return_value = None
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "unknown", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_no_argument_fragment(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments=None)) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_key_value_not_started(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""Key exists in JSON but its string value hasn't begun yet."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code":'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments=":")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_before_any_key(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Only the opening brace has arrived — no key to stream yet."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": "{"}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="{")) == []
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaBasic:
|
||||
"""Tests for correct packet content and incremental emission."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_emits_packet_with_correct_fields(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {
|
||||
"id": "tc_1",
|
||||
"name": "python",
|
||||
"arguments": '{"code": "print(1)',
|
||||
}
|
||||
}
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments="print(1)"))
|
||||
|
||||
assert len(packets) == 1
|
||||
obj = packets[0].obj
|
||||
assert isinstance(obj, ToolCallArgumentDelta)
|
||||
assert obj.tool_type == "python"
|
||||
assert obj.tool_id == "tc_1"
|
||||
assert obj.argument_deltas == {"code": "print(1)"}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_emits_only_new_content_on_subsequent_call(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""After a first emission, subsequent calls emit only the diff."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "abc'}
|
||||
}
|
||||
|
||||
packets_1 = _collect(tc_map, _make_tool_call_delta(arguments="abc"))
|
||||
assert packets_1[0].obj.argument_deltas == {"code": "abc"}
|
||||
|
||||
tc_map[0]["arguments"] = '{"code": "abcdef'
|
||||
packets_2 = _collect(tc_map, _make_tool_call_delta(arguments="def"))
|
||||
assert packets_2[0].obj.argument_deltas == {"code": "def"}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_handles_multiple_keys_sequentially(self, mock_get_tool: MagicMock) -> None:
|
||||
"""When a second key starts, emissions switch to that key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
|
||||
packets_1 = _collect(tc_map, _make_tool_call_delta(arguments="x"))
|
||||
assert packets_1[0].obj.argument_deltas == {"code": "x"}
|
||||
|
||||
tc_map[0]["arguments"] = '{"code": "x", "output": "hello'
|
||||
packets_2 = _collect(tc_map, _make_tool_call_delta(arguments="hello"))
|
||||
assert packets_2[0].obj.argument_deltas == {"output": "hello"}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_spans_key_boundary(self, mock_get_tool: MagicMock) -> None:
|
||||
"""A single delta contains the end of one value and the start of the next key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
|
||||
packets_1 = _collect(tc_map, _make_tool_call_delta(arguments="x"))
|
||||
assert packets_1[0].obj.argument_deltas == {"code": "x"}
|
||||
|
||||
# Delta carries closing of "code" value + opening of "lang" key + start of value
|
||||
tc_map[0]["arguments"] = '{"code": "xy", "lang": "py'
|
||||
packets_2 = _collect(tc_map, _make_tool_call_delta(arguments='y", "lang": "py'))
|
||||
assert len(packets_2) == 1
|
||||
assert packets_2[0].obj.argument_deltas == {"code": "y", "lang": "py"}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_empty_value_emits_nothing(self, mock_get_tool: MagicMock) -> None:
|
||||
"""An empty string value has nothing to emit."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "'}
|
||||
}
|
||||
# Opening quote just arrived, value is empty
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments='"')) == []
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaDecoding:
|
||||
"""Tests verifying that JSON escape sequences are properly decoded."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_newlines(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {
|
||||
"id": "tc_1",
|
||||
"name": "python",
|
||||
"arguments": '{"code": "line1\\nline2',
|
||||
}
|
||||
}
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments="line1\\nline2"))
|
||||
assert packets[0].obj.argument_deltas == {"code": "line1\nline2"}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_tabs(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {
|
||||
"id": "tc_1",
|
||||
"name": "python",
|
||||
"arguments": '{"code": "\\tindented',
|
||||
}
|
||||
}
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments="\\tindented"))
|
||||
assert packets[0].obj.argument_deltas == {"code": "\tindented"}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_escaped_quotes(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {
|
||||
"id": "tc_1",
|
||||
"name": "python",
|
||||
"arguments": '{"code": "say \\"hi\\"',
|
||||
}
|
||||
}
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments='say \\"hi\\"'))
|
||||
assert packets[0].obj.argument_deltas == {"code": 'say "hi"'}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_escaped_backslashes(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {
|
||||
"id": "tc_1",
|
||||
"name": "python",
|
||||
"arguments": '{"code": "path\\\\dir',
|
||||
}
|
||||
}
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments="path\\\\dir"))
|
||||
assert packets[0].obj.argument_deltas == {"code": "path\\dir"}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_unicode_escape(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {
|
||||
"id": "tc_1",
|
||||
"name": "python",
|
||||
"arguments": '{"code": "\\u0041',
|
||||
}
|
||||
}
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments="\\u0041"))
|
||||
assert packets[0].obj.argument_deltas == {"code": "A"}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_incomplete_escape_at_end_trims_safely(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""A trailing backslash (incomplete escape) is handled gracefully."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {
|
||||
"id": "tc_1",
|
||||
"name": "python",
|
||||
"arguments": '{"code": "hello\\',
|
||||
}
|
||||
}
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments="hello\\"))
|
||||
# "hello" can be decoded; the trailing backslash is trimmed
|
||||
assert packets[0].obj.argument_deltas == {"code": "hello"}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_incomplete_unicode_escape_trims_safely(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""A partial \\uXX sequence is trimmed, emitting what can be decoded."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {
|
||||
"id": "tc_1",
|
||||
"name": "python",
|
||||
"arguments": '{"code": "hello\\u00',
|
||||
}
|
||||
}
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments="hello\\u00"))
|
||||
assert packets[0].obj.argument_deltas == {"code": "hello"}
|
||||
|
||||
|
||||
class TestArgumentDeltaStreamingE2E:
|
||||
"""Simulates realistic sequences of LLM argument deltas to verify
|
||||
the full pipeline produces correct decoded output."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_realistic_python_code_streaming(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams: {"code": "print('hello')\\nprint('world')"}"""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"',
|
||||
"code",
|
||||
'": "',
|
||||
"print(",
|
||||
"'hello')",
|
||||
"\\n",
|
||||
"print(",
|
||||
"'world')",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "print('hello')\nprint('world')"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_streaming_with_tabs_and_newlines(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams code with tabs and newlines."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"if True:",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"pass",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "if True:\n\tpass"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_split_escape_sequence(self, mock_get_tool: MagicMock) -> None:
|
||||
"""An escape sequence split across two fragments (backslash in one,
|
||||
'n' in the next) should still decode correctly."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "hello',
|
||||
"\\",
|
||||
"n",
|
||||
'world"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "hello\nworld"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_multiple_newlines_and_indentation(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams a multi-line function with multiple escape sequences."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"def foo():",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"x = 1",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"return x",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "def foo():\n\tx = 1\n\treturn x"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_two_keys_streamed_sequentially(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams code first, then a second key (language) — both decoded."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"x = 1",
|
||||
'", "language": "',
|
||||
"python",
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
# Should have emissions for both keys
|
||||
full = "".join(emitted)
|
||||
assert "x = 1" in full
|
||||
assert "python" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_code_containing_dict_literal(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Python code like `x = {"key": "val"}` contains JSON-like patterns.
|
||||
The escaped quotes inside the *outer* JSON value should prevent the
|
||||
inner `"key":` from being mistaken for a top-level JSON key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
# The LLM sends: {"code": "x = {\"key\": \"val\"}"}
|
||||
# The inner quotes are escaped as \" in the JSON value.
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"x = {",
|
||||
'\\"key\\"',
|
||||
": ",
|
||||
'\\"val\\"',
|
||||
"}",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == 'x = {"key": "val"}'
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_code_with_colon_in_value(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Colons inside the string value should not confuse key detection."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"url = ",
|
||||
'\\"https://example.com\\"',
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == 'url = "https://example.com"'
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaEdgeCases:
|
||||
"""Edge cases not covered by the standard test classes."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_function_is_none(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Some delta chunks have function=None (e.g. role-only deltas)."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
delta = _make_tool_call_delta(arguments=None, function_is_none=True)
|
||||
assert _collect(tc_map, delta) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_multiple_concurrent_tool_calls(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Two tool calls streaming at different indices in parallel."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "aaa'},
|
||||
1: {"id": "tc_2", "name": "python", "arguments": '{"code": "bbb'},
|
||||
}
|
||||
|
||||
# Delta for index 0
|
||||
packets_0 = _collect(tc_map, _make_tool_call_delta(index=0, arguments="aaa"))
|
||||
assert len(packets_0) == 1
|
||||
assert packets_0[0].obj.tool_id == "tc_1"
|
||||
assert packets_0[0].obj.argument_deltas == {"code": "aaa"}
|
||||
|
||||
# Delta for index 1
|
||||
packets_1 = _collect(tc_map, _make_tool_call_delta(index=1, arguments="bbb"))
|
||||
assert len(packets_1) == 1
|
||||
assert packets_1[0].obj.tool_id == "tc_2"
|
||||
assert packets_1[0].obj.argument_deltas == {"code": "bbb"}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_with_four_arguments(self, mock_get_tool: MagicMock) -> None:
|
||||
"""A single delta contains four complete key-value pairs."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
accumulated = '{"a": "one", "b": "two", "c": "three", "d": "four'
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": accumulated}
|
||||
}
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments=accumulated))
|
||||
|
||||
assert len(packets) == 1
|
||||
assert packets[0].obj.argument_deltas == {
|
||||
"a": "one",
|
||||
"b": "two",
|
||||
"c": "three",
|
||||
"d": "four",
|
||||
}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_on_second_arg_after_first_complete(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""First argument is fully complete; delta only adds to the second."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {
|
||||
"id": "tc_1",
|
||||
"name": "python",
|
||||
"arguments": '{"code": "print(1)", "lang": "py',
|
||||
}
|
||||
}
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments="py"))
|
||||
|
||||
assert len(packets) == 1
|
||||
assert packets[0].obj.argument_deltas == {"lang": "py"}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_non_string_values_skipped(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Non-string values (numbers, booleans, null) are skipped — they are
|
||||
available in the final tool-call kickoff packet. String arguments
|
||||
following them are still emitted."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {
|
||||
"id": "tc_1",
|
||||
"name": "python",
|
||||
"arguments": '{"timeout": 30, "code": "hello',
|
||||
}
|
||||
}
|
||||
packets = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments='30, "code": "hello')
|
||||
)
|
||||
|
||||
assert len(packets) == 1
|
||||
assert packets[0].obj.argument_deltas == {"code": "hello"}
|
||||
@@ -1,12 +1,11 @@
|
||||
"use client";
|
||||
|
||||
import React, { useEffect, useMemo, useRef, useState } from "react";
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import { FileDescriptor } from "@/app/app/interfaces";
|
||||
import "katex/dist/katex.min.css";
|
||||
import MessageSwitcher from "@/app/app/message/MessageSwitcher";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import useScreenSize from "@/hooks/useScreenSize";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgEdit } from "@opal/icons";
|
||||
@@ -138,7 +137,6 @@ const HumanMessage = React.memo(function HumanMessage({
|
||||
const [content, setContent] = useState(initialContent);
|
||||
|
||||
const [isEditing, setIsEditing] = useState(false);
|
||||
const { isMobile } = useScreenSize();
|
||||
|
||||
// Use nodeId for switching (finding position in siblings)
|
||||
const indexInSiblings = otherMessagesCanSwitchTo?.indexOf(nodeId);
|
||||
@@ -170,104 +168,119 @@ const HumanMessage = React.memo(function HumanMessage({
|
||||
return undefined;
|
||||
};
|
||||
|
||||
const copyEditButton = useMemo(
|
||||
() => (
|
||||
<div className="flex flex-row flex-shrink px-1 opacity-0 group-hover:opacity-100 transition-opacity">
|
||||
<CopyIconButton
|
||||
getCopyText={() => content}
|
||||
prominence="tertiary"
|
||||
data-testid="HumanMessage/copy-button"
|
||||
/>
|
||||
<Button
|
||||
icon={SvgEdit}
|
||||
prominence="tertiary"
|
||||
tooltip="Edit"
|
||||
onClick={() => setIsEditing(true)}
|
||||
data-testid="HumanMessage/edit-button"
|
||||
/>
|
||||
</div>
|
||||
),
|
||||
[content]
|
||||
);
|
||||
|
||||
return (
|
||||
<div
|
||||
id="onyx-human-message"
|
||||
className="group flex flex-col justify-end w-full relative"
|
||||
>
|
||||
<FileDisplay alignBubble files={files || []} />
|
||||
{isEditing ? (
|
||||
<MessageEditing
|
||||
content={content}
|
||||
onSubmitEdit={(editedContent) => {
|
||||
// Don't update UI for edits that can't be persisted
|
||||
if (messageId === undefined || messageId === null) {
|
||||
setIsEditing(false);
|
||||
return;
|
||||
}
|
||||
onEdit?.(editedContent, messageId);
|
||||
setContent(editedContent);
|
||||
setIsEditing(false);
|
||||
}}
|
||||
onCancelEdit={() => setIsEditing(false)}
|
||||
/>
|
||||
) : (
|
||||
<div className="flex justify-end">
|
||||
{onEdit && !isMobile && copyEditButton}
|
||||
<div className="md:max-w-[37.5rem]">
|
||||
<div
|
||||
className={
|
||||
"max-w-[30rem] md:max-w-[37.5rem] whitespace-break-spaces break-anywhere rounded-t-16 rounded-bl-16 bg-background-tint-02 py-2 px-3"
|
||||
<div className="md:flex md:flex-wrap relative justify-end break-words">
|
||||
{isEditing ? (
|
||||
<MessageEditing
|
||||
content={content}
|
||||
onSubmitEdit={(editedContent) => {
|
||||
// Don't update UI for edits that can't be persisted
|
||||
if (messageId === undefined || messageId === null) {
|
||||
setIsEditing(false);
|
||||
return;
|
||||
}
|
||||
onCopy={(e) => {
|
||||
const selection = window.getSelection();
|
||||
if (selection) {
|
||||
e.preventDefault();
|
||||
const text = selection
|
||||
.toString()
|
||||
.replace(/\n{2,}/g, "\n")
|
||||
.trim();
|
||||
e.clipboardData.setData("text/plain", text);
|
||||
onEdit?.(editedContent, messageId);
|
||||
setContent(editedContent);
|
||||
setIsEditing(false);
|
||||
}}
|
||||
onCancelEdit={() => setIsEditing(false)}
|
||||
/>
|
||||
) : typeof content === "string" ? (
|
||||
<>
|
||||
<div className="md:max-w-[37.5rem] flex basis-[100%] md:basis-auto justify-end md:order-1">
|
||||
<div
|
||||
className={
|
||||
"max-w-[30rem] md:max-w-[37.5rem] whitespace-break-spaces break-anywhere rounded-t-16 rounded-bl-16 bg-background-tint-02 py-2 px-3"
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Text
|
||||
as="p"
|
||||
className="inline-block align-middle"
|
||||
mainContentBody
|
||||
onCopy={(e) => {
|
||||
const selection = window.getSelection();
|
||||
if (selection) {
|
||||
e.preventDefault();
|
||||
const text = selection
|
||||
.toString()
|
||||
.replace(/\n{2,}/g, "\n")
|
||||
.trim();
|
||||
e.clipboardData.setData("text/plain", text);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{content}
|
||||
</Text>
|
||||
<Text
|
||||
as="p"
|
||||
className="inline-block align-middle"
|
||||
mainContentBody
|
||||
>
|
||||
{content}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{onEdit && !isEditing && (
|
||||
<div className="absolute md:relative right-0 z-content flex flex-row p-1 opacity-0 group-hover:opacity-100 transition-opacity">
|
||||
<CopyIconButton
|
||||
getCopyText={() => content}
|
||||
prominence="tertiary"
|
||||
data-testid="HumanMessage/copy-button"
|
||||
/>
|
||||
<Button
|
||||
icon={SvgEdit}
|
||||
prominence="tertiary"
|
||||
tooltip="Edit"
|
||||
onClick={() => setIsEditing(true)}
|
||||
data-testid="HumanMessage/edit-button"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div
|
||||
className={cn(
|
||||
"my-auto",
|
||||
onEdit && !isEditing
|
||||
? "opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
: "invisible"
|
||||
)}
|
||||
>
|
||||
<Button
|
||||
icon={SvgEdit}
|
||||
onClick={() => setIsEditing(true)}
|
||||
prominence="tertiary"
|
||||
tooltip="Edit"
|
||||
/>
|
||||
</div>
|
||||
<div className="ml-auto rounded-lg p-1">{content}</div>
|
||||
</>
|
||||
)}
|
||||
<div className="md:min-w-[100%] flex justify-end order-1 mt-1">
|
||||
{currentMessageInd !== undefined &&
|
||||
onMessageSelection &&
|
||||
otherMessagesCanSwitchTo &&
|
||||
otherMessagesCanSwitchTo.length > 1 && (
|
||||
<MessageSwitcher
|
||||
disableForStreaming={disableSwitchingForStreaming}
|
||||
currentPage={currentMessageInd + 1}
|
||||
totalPages={otherMessagesCanSwitchTo.length}
|
||||
handlePrevious={() => {
|
||||
stopGenerating();
|
||||
const prevMessage = getPreviousMessage();
|
||||
if (prevMessage !== undefined) {
|
||||
onMessageSelection(prevMessage);
|
||||
}
|
||||
}}
|
||||
handleNext={() => {
|
||||
stopGenerating();
|
||||
const nextMessage = getNextMessage();
|
||||
if (nextMessage !== undefined) {
|
||||
onMessageSelection(nextMessage);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<div className="flex justify-end pt-1">
|
||||
{!isEditing && onEdit && isMobile && copyEditButton}
|
||||
{currentMessageInd !== undefined &&
|
||||
onMessageSelection &&
|
||||
otherMessagesCanSwitchTo &&
|
||||
otherMessagesCanSwitchTo.length > 1 && (
|
||||
<MessageSwitcher
|
||||
disableForStreaming={disableSwitchingForStreaming}
|
||||
currentPage={currentMessageInd + 1}
|
||||
totalPages={otherMessagesCanSwitchTo.length}
|
||||
handlePrevious={() => {
|
||||
stopGenerating();
|
||||
const prevMessage = getPreviousMessage();
|
||||
if (prevMessage !== undefined) {
|
||||
onMessageSelection(prevMessage);
|
||||
}
|
||||
}}
|
||||
handleNext={() => {
|
||||
stopGenerating();
|
||||
const nextMessage = getNextMessage();
|
||||
if (nextMessage !== undefined) {
|
||||
onMessageSelection(nextMessage);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
export default function EEFeatureRedirect() {
|
||||
const router = useRouter();
|
||||
|
||||
useEffect(() => {
|
||||
toast.error(
|
||||
"This feature requires a license. Please upgrade your plan to access."
|
||||
);
|
||||
router.replace("/app");
|
||||
}, [router]);
|
||||
|
||||
return null;
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
import { SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED } from "@/lib/constants";
|
||||
import { fetchStandardSettingsSS } from "@/components/settings/lib";
|
||||
import EEFeatureRedirect from "@/app/ee/EEFeatureRedirect";
|
||||
|
||||
export default async function AdminLayout({
|
||||
children,
|
||||
@@ -9,7 +8,13 @@ export default async function AdminLayout({
|
||||
}) {
|
||||
// First check build-time constant (fast path)
|
||||
if (!SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED) {
|
||||
return <EEFeatureRedirect />;
|
||||
return (
|
||||
<div className="flex h-screen">
|
||||
<div className="mx-auto my-auto text-lg font-bold text-red-500">
|
||||
This functionality is only available in the Enterprise Edition :(
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Then check runtime license status (for license enforcement mode)
|
||||
@@ -26,7 +31,13 @@ export default async function AdminLayout({
|
||||
return children;
|
||||
}
|
||||
|
||||
return <EEFeatureRedirect />;
|
||||
return (
|
||||
<div className="flex h-screen">
|
||||
<div className="mx-auto my-auto text-lg font-bold text-red-500">
|
||||
This functionality requires an active Enterprise license.
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
|
||||
@@ -484,8 +484,12 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
ref={chatInputBarRef}
|
||||
deepResearchEnabled={deepResearchEnabled}
|
||||
toggleDeepResearch={toggleDeepResearch}
|
||||
toggleDocumentSidebar={() => {}}
|
||||
filterManager={filterManager}
|
||||
llmManager={llmManager}
|
||||
removeDocs={() => {}}
|
||||
retrievalEnabled={retrievalEnabled}
|
||||
selectedDocuments={[]}
|
||||
initialMessage={message}
|
||||
stopGenerating={stopGenerating}
|
||||
onSubmit={handleChatInputSubmit}
|
||||
|
||||
@@ -23,7 +23,8 @@ export interface AppModeProviderProps {
|
||||
export function AppModeProvider({ children }: AppModeProviderProps) {
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const { user } = useUser();
|
||||
const { isSearchModeAvailable } = useSettingsContext();
|
||||
const settings = useSettingsContext();
|
||||
const { isSearchModeAvailable } = settings;
|
||||
|
||||
const persistedMode = user?.preferences?.default_app_mode;
|
||||
const [appMode, setAppModeState] = useState<AppMode>("chat");
|
||||
|
||||
@@ -11,8 +11,21 @@ import {
|
||||
* Hook to fetch billing information from Stripe.
|
||||
*
|
||||
* Works for both cloud and self-hosted deployments:
|
||||
* - Cloud: fetches from /api/tenants/billing-information
|
||||
* - Cloud: fetches from /api/tenants/billing-information (legacy endpoint)
|
||||
* - Self-hosted: fetches from /api/admin/billing/billing-information
|
||||
*
|
||||
* Returns subscription status, seats, billing period, etc.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* const { data, isLoading, error, refresh } = useBillingInformation();
|
||||
*
|
||||
* if (isLoading) return <Loading />;
|
||||
* if (error) return <Error />;
|
||||
* if (!data || !hasActiveSubscription(data)) return <NoSubscription />;
|
||||
*
|
||||
* return <BillingDetails billing={data} />;
|
||||
* ```
|
||||
*/
|
||||
export function useBillingInformation() {
|
||||
const url = NEXT_PUBLIC_CLOUD_ENABLED
|
||||
@@ -25,9 +38,16 @@ export function useBillingInformation() {
|
||||
revalidateOnFocus: false,
|
||||
revalidateOnReconnect: false,
|
||||
dedupingInterval: 30000,
|
||||
// Don't auto-retry on errors (circuit breaker will block requests anyway)
|
||||
shouldRetryOnError: false,
|
||||
// Keep previous data while revalidating to prevent UI flashing
|
||||
keepPreviousData: true,
|
||||
});
|
||||
|
||||
return { data, isLoading, error, refresh: mutate };
|
||||
return {
|
||||
data,
|
||||
isLoading,
|
||||
error,
|
||||
refresh: mutate,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -7,9 +7,23 @@ import { LicenseStatus } from "@/lib/billing/interfaces";
|
||||
/**
|
||||
* Hook to fetch license status for self-hosted deployments.
|
||||
*
|
||||
* Skips the fetch on cloud deployments (uses tenant auth instead).
|
||||
* Returns license information including seats, expiry, and status.
|
||||
* Only fetches for self-hosted deployments (cloud uses tenant auth instead).
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* const { data, isLoading, error, refresh } = useLicense();
|
||||
*
|
||||
* if (isLoading) return <Loading />;
|
||||
* if (error) return <Error />;
|
||||
* if (!data?.has_license) return <NoLicense />;
|
||||
*
|
||||
* return <LicenseDetails license={data} />;
|
||||
* ```
|
||||
*/
|
||||
export function useLicense() {
|
||||
// Only fetch license for self-hosted deployments
|
||||
// Cloud deployments use tenant-based auth, not license files
|
||||
const url = NEXT_PUBLIC_CLOUD_ENABLED ? null : "/api/license";
|
||||
|
||||
const { data, error, mutate, isLoading } = useSWR<LicenseStatus>(
|
||||
@@ -24,14 +38,20 @@ export function useLicense() {
|
||||
}
|
||||
);
|
||||
|
||||
if (!url) {
|
||||
// Return empty state for cloud deployments
|
||||
if (NEXT_PUBLIC_CLOUD_ENABLED) {
|
||||
return {
|
||||
data: undefined,
|
||||
data: null,
|
||||
isLoading: false,
|
||||
error: undefined,
|
||||
refresh: () => Promise.resolve(undefined),
|
||||
};
|
||||
}
|
||||
|
||||
return { data, isLoading, error, refresh: mutate };
|
||||
return {
|
||||
data,
|
||||
isLoading,
|
||||
error,
|
||||
refresh: mutate,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -46,8 +46,8 @@ export interface Settings {
|
||||
// Onyx Craft (Build Mode) feature flag
|
||||
onyx_craft_enabled?: boolean;
|
||||
|
||||
// Whether EE features are unlocked (user has a valid enterprise license).
|
||||
// Controls UI visibility of EE features like user groups, analytics, RBAC.
|
||||
// Enterprise features flag - controlled by license enforcement at runtime
|
||||
// True when user has a valid license, False for community edition
|
||||
ee_features_enabled?: boolean;
|
||||
|
||||
// Seat usage - populated when seat limit is exceeded
|
||||
|
||||
@@ -190,17 +190,14 @@ function AttachmentItemLayout({
|
||||
alignItems="center"
|
||||
gap={1.5}
|
||||
>
|
||||
<div className="flex-1 min-w-0">
|
||||
<Content
|
||||
title={title}
|
||||
description={description}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
widthVariant="full"
|
||||
/>
|
||||
</div>
|
||||
<Content
|
||||
title={title}
|
||||
description={description}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
{middleText && (
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex-1">
|
||||
<Truncated text03 secondaryBody>
|
||||
{middleText}
|
||||
</Truncated>
|
||||
|
||||
@@ -42,13 +42,8 @@ export const NEXT_PUBLIC_CUSTOM_REFRESH_URL =
|
||||
|
||||
// NOTE: this should ONLY be used on the server-side. If used client side,
|
||||
// it will not be accurate (will always be false).
|
||||
// Mirrors backend logic: EE is enabled if EITHER the legacy flag OR license
|
||||
// enforcement is active. LICENSE_ENFORCEMENT_ENABLED defaults to true on the
|
||||
// backend, so we treat undefined as enabled here to match.
|
||||
export const SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED =
|
||||
process.env.ENABLE_PAID_ENTERPRISE_EDITION_FEATURES?.toLowerCase() ===
|
||||
"true" ||
|
||||
process.env.LICENSE_ENFORCEMENT_ENABLED?.toLowerCase() !== "false";
|
||||
process.env.ENABLE_PAID_ENTERPRISE_EDITION_FEATURES?.toLowerCase() === "true";
|
||||
// NOTE: since this is a `NEXT_PUBLIC_` variable, it will be set at
|
||||
// build-time
|
||||
// TODO: consider moving this to an API call so that the api_server
|
||||
|
||||
@@ -51,6 +51,16 @@ function ToastContainer() {
|
||||
}, ANIMATION_DURATION);
|
||||
}, []);
|
||||
|
||||
// NOTE (@raunakab):
|
||||
//
|
||||
// Keep this here for debugging purposes.
|
||||
// useOnMount(() => {
|
||||
// toast.success("Test success toast", { duration: Infinity });
|
||||
// toast.error("Test error toast", { duration: Infinity });
|
||||
// toast.warning("Test warning toast", { duration: Infinity });
|
||||
// toast.info("Test info toast", { duration: Infinity });
|
||||
// });
|
||||
|
||||
if (visible.length === 0) return null;
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import {
|
||||
type Table,
|
||||
type ColumnDef,
|
||||
type RowData,
|
||||
type VisibilityState,
|
||||
} from "@tanstack/react-table";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgColumn, SvgCheck } from "@opal/icons";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import Divider from "@/refresh-components/Divider";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Popover UI
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface ColumnVisibilityPopoverProps<TData extends RowData = RowData> {
|
||||
table: Table<TData>;
|
||||
columnVisibility: VisibilityState;
|
||||
size?: "regular" | "small";
|
||||
}
|
||||
|
||||
function ColumnVisibilityPopover<TData extends RowData>({
|
||||
table,
|
||||
columnVisibility,
|
||||
size = "regular",
|
||||
}: ColumnVisibilityPopoverProps<TData>) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const hideableColumns = table
|
||||
.getAllLeafColumns()
|
||||
.filter((col) => col.getCanHide());
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
<Popover.Trigger asChild>
|
||||
<Button
|
||||
icon={SvgColumn}
|
||||
transient={open}
|
||||
size={size === "small" ? "sm" : "md"}
|
||||
prominence="internal"
|
||||
tooltip="Columns"
|
||||
/>
|
||||
</Popover.Trigger>
|
||||
|
||||
<Popover.Content width="lg" align="end" side="bottom">
|
||||
<Divider showTitle text="Shown Columns" />
|
||||
<Popover.Menu>
|
||||
{hideableColumns.map((column) => {
|
||||
const isVisible = columnVisibility[column.id] !== false;
|
||||
const label =
|
||||
typeof column.columnDef.header === "string"
|
||||
? column.columnDef.header
|
||||
: column.id;
|
||||
|
||||
return (
|
||||
<LineItem
|
||||
key={column.id}
|
||||
selected={isVisible}
|
||||
emphasized
|
||||
rightChildren={isVisible ? <SvgCheck size={16} /> : undefined}
|
||||
onClick={() => {
|
||||
column.toggleVisibility();
|
||||
}}
|
||||
>
|
||||
{label}
|
||||
</LineItem>
|
||||
);
|
||||
})}
|
||||
</Popover.Menu>
|
||||
</Popover.Content>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Column definition factory
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface CreateColumnVisibilityColumnOptions {
|
||||
size?: "regular" | "small";
|
||||
}
|
||||
|
||||
function createColumnVisibilityColumn<TData>(
|
||||
options?: CreateColumnVisibilityColumnOptions
|
||||
): ColumnDef<TData, unknown> {
|
||||
return {
|
||||
id: "__columnVisibility",
|
||||
size: 44,
|
||||
enableHiding: false,
|
||||
enableSorting: false,
|
||||
enableResizing: false,
|
||||
header: ({ table }) => (
|
||||
<ColumnVisibilityPopover
|
||||
table={table}
|
||||
columnVisibility={table.getState().columnVisibility}
|
||||
size={options?.size}
|
||||
/>
|
||||
),
|
||||
cell: () => null,
|
||||
};
|
||||
}
|
||||
|
||||
export { ColumnVisibilityPopover, createColumnVisibilityColumn };
|
||||
@@ -1,455 +0,0 @@
|
||||
"use client";
|
||||
"use no memo";
|
||||
|
||||
import { useMemo } from "react";
|
||||
import { flexRender } from "@tanstack/react-table";
|
||||
import useDataTable, {
|
||||
toOnyxSortDirection,
|
||||
} from "@/refresh-components/table/hooks/useDataTable";
|
||||
import useColumnWidths from "@/refresh-components/table/hooks/useColumnWidths";
|
||||
import useDraggableRows from "@/refresh-components/table/hooks/useDraggableRows";
|
||||
import Table from "@/refresh-components/table/Table";
|
||||
import TableHeader from "@/refresh-components/table/TableHeader";
|
||||
import TableBody from "@/refresh-components/table/TableBody";
|
||||
import TableRow from "@/refresh-components/table/TableRow";
|
||||
import TableHead from "@/refresh-components/table/TableHead";
|
||||
import TableCell from "@/refresh-components/table/TableCell";
|
||||
import TableQualifier from "@/refresh-components/table/TableQualifier";
|
||||
import QualifierContainer from "@/refresh-components/table/QualifierContainer";
|
||||
import ActionsContainer from "@/refresh-components/table/ActionsContainer";
|
||||
import DragOverlayRow from "@/refresh-components/table/DragOverlayRow";
|
||||
import Footer from "@/refresh-components/table/Footer";
|
||||
import { TableSizeProvider } from "@/refresh-components/table/TableSizeContext";
|
||||
import { ColumnVisibilityPopover } from "@/refresh-components/table/ColumnVisibilityPopover";
|
||||
import { SortingPopover } from "@/refresh-components/table/SortingPopover";
|
||||
import type { WidthConfig } from "@/refresh-components/table/hooks/useColumnWidths";
|
||||
import type { ColumnDef } from "@tanstack/react-table";
|
||||
import type {
|
||||
DataTableProps,
|
||||
DataTableFooterConfig,
|
||||
OnyxColumnDef,
|
||||
OnyxDataColumn,
|
||||
OnyxQualifierColumn,
|
||||
OnyxActionsColumn,
|
||||
} from "@/refresh-components/table/types";
|
||||
import type { TableSize } from "@/refresh-components/table/TableSizeContext";
|
||||
|
||||
const noopGetRowId = () => "";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal: resolve size-dependent widths and build TanStack columns
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface ProcessedColumns<TData> {
|
||||
tanstackColumns: ColumnDef<TData, any>[];
|
||||
widthConfig: WidthConfig;
|
||||
qualifierColumn: OnyxQualifierColumn<TData> | null;
|
||||
/** Map from column ID → OnyxColumnDef for dispatch in render loops. */
|
||||
columnKindMap: Map<string, OnyxColumnDef<TData>>;
|
||||
}
|
||||
|
||||
function processColumns<TData>(
|
||||
columns: OnyxColumnDef<TData>[],
|
||||
size: TableSize
|
||||
): ProcessedColumns<TData> {
|
||||
const tanstackColumns: ColumnDef<TData, any>[] = [];
|
||||
const fixedColumnIds = new Set<string>();
|
||||
const columnWeights: Record<string, number> = {};
|
||||
const columnMinWidths: Record<string, number> = {};
|
||||
const columnKindMap = new Map<string, OnyxColumnDef<TData>>();
|
||||
let qualifierColumn: OnyxQualifierColumn<TData> | null = null;
|
||||
|
||||
for (const col of columns) {
|
||||
const resolvedWidth =
|
||||
typeof col.width === "function" ? col.width(size) : col.width;
|
||||
|
||||
// Clone def to avoid mutating the caller's column definitions
|
||||
const clonedDef: ColumnDef<TData, any> = {
|
||||
...col.def,
|
||||
id: col.id,
|
||||
size:
|
||||
"fixed" in resolvedWidth ? resolvedWidth.fixed : resolvedWidth.weight,
|
||||
};
|
||||
|
||||
tanstackColumns.push(clonedDef);
|
||||
|
||||
const id = col.id;
|
||||
columnKindMap.set(id, col);
|
||||
|
||||
if ("fixed" in resolvedWidth) {
|
||||
fixedColumnIds.add(id);
|
||||
} else {
|
||||
columnWeights[id] = resolvedWidth.weight;
|
||||
columnMinWidths[id] = resolvedWidth.minWidth ?? 50;
|
||||
}
|
||||
|
||||
if (col.kind === "qualifier") qualifierColumn = col;
|
||||
}
|
||||
|
||||
return {
|
||||
tanstackColumns,
|
||||
widthConfig: { fixedColumnIds, columnWeights, columnMinWidths },
|
||||
qualifierColumn,
|
||||
columnKindMap,
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DataTable component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Config-driven table component that wires together `useDataTable`,
|
||||
* `useColumnWidths`, and `useDraggableRows` automatically.
|
||||
*
|
||||
* Full flexibility via the column definitions from `createTableColumns()`.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* const tc = createTableColumns<TeamMember>();
|
||||
* const columns = [
|
||||
* tc.qualifier({ content: "avatar-user", getInitials: (r) => r.initials }),
|
||||
* tc.column("name", { header: "Name", weight: 23, minWidth: 120 }),
|
||||
* tc.column("email", { header: "Email", weight: 28 }),
|
||||
* tc.actions(),
|
||||
* ];
|
||||
*
|
||||
* <DataTable data={items} columns={columns} footer={{ mode: "selection" }} />
|
||||
* ```
|
||||
*/
|
||||
export default function DataTable<TData>(props: DataTableProps<TData>) {
|
||||
const {
|
||||
data,
|
||||
columns,
|
||||
pageSize,
|
||||
initialSorting,
|
||||
initialColumnVisibility,
|
||||
draggable,
|
||||
footer,
|
||||
size = "regular",
|
||||
onRowClick,
|
||||
height,
|
||||
headerBackground,
|
||||
} = props;
|
||||
|
||||
const effectivePageSize = pageSize ?? (footer ? 10 : data.length);
|
||||
|
||||
// 1. Process columns (memoized on columns + size)
|
||||
const { tanstackColumns, widthConfig, qualifierColumn, columnKindMap } =
|
||||
useMemo(() => processColumns(columns, size), [columns, size]);
|
||||
|
||||
// 2. Call useDataTable
|
||||
const {
|
||||
table,
|
||||
currentPage,
|
||||
totalPages,
|
||||
totalItems,
|
||||
setPage,
|
||||
pageSize: resolvedPageSize,
|
||||
selectionState,
|
||||
selectedCount,
|
||||
clearSelection,
|
||||
toggleAllPageRowsSelected,
|
||||
isAllPageRowsSelected,
|
||||
} = useDataTable({
|
||||
data,
|
||||
columns: tanstackColumns,
|
||||
pageSize: effectivePageSize,
|
||||
initialSorting,
|
||||
initialColumnVisibility,
|
||||
});
|
||||
|
||||
// 3. Call useColumnWidths
|
||||
const { containerRef, columnWidths, createResizeHandler } = useColumnWidths({
|
||||
headers: table.getHeaderGroups()[0]?.headers ?? [],
|
||||
...widthConfig,
|
||||
});
|
||||
|
||||
// 4. Call useDraggableRows (conditional)
|
||||
const draggableReturn = useDraggableRows({
|
||||
data,
|
||||
getRowId: draggable?.getRowId ?? noopGetRowId,
|
||||
enabled: !!draggable && table.getState().sorting.length === 0,
|
||||
onReorder: draggable?.onReorder,
|
||||
});
|
||||
|
||||
const hasDraggable = !!draggable;
|
||||
const rowVariant = hasDraggable ? "table" : "list";
|
||||
|
||||
const isSelectable =
|
||||
qualifierColumn != null && qualifierColumn.selectable !== false;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Render
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function renderContent() {
|
||||
return (
|
||||
<div>
|
||||
<div
|
||||
className="overflow-x-auto"
|
||||
ref={containerRef}
|
||||
style={{
|
||||
...(height != null
|
||||
? {
|
||||
maxHeight:
|
||||
typeof height === "number" ? `${height}px` : height,
|
||||
overflowY: "auto" as const,
|
||||
}
|
||||
: undefined),
|
||||
...(headerBackground
|
||||
? ({
|
||||
"--table-header-bg": headerBackground,
|
||||
} as React.CSSProperties)
|
||||
: undefined),
|
||||
}}
|
||||
>
|
||||
<Table>
|
||||
<TableHeader>
|
||||
{table.getHeaderGroups().map((headerGroup) => (
|
||||
<TableRow key={headerGroup.id}>
|
||||
{headerGroup.headers.map((header, headerIndex) => {
|
||||
const colDef = columnKindMap.get(header.id);
|
||||
|
||||
// Qualifier header
|
||||
if (colDef?.kind === "qualifier") {
|
||||
if (qualifierColumn?.header === false) {
|
||||
return (
|
||||
<QualifierContainer key={header.id} type="head" />
|
||||
);
|
||||
}
|
||||
return (
|
||||
<QualifierContainer key={header.id} type="head">
|
||||
<TableQualifier
|
||||
content={
|
||||
qualifierColumn?.headerContentType ?? "simple"
|
||||
}
|
||||
selectable={isSelectable}
|
||||
selected={isSelectable && isAllPageRowsSelected}
|
||||
onSelectChange={
|
||||
isSelectable
|
||||
? (checked) =>
|
||||
toggleAllPageRowsSelected(checked)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
</QualifierContainer>
|
||||
);
|
||||
}
|
||||
|
||||
// Actions header
|
||||
if (colDef?.kind === "actions") {
|
||||
const actionsDef = colDef as OnyxActionsColumn<TData>;
|
||||
return (
|
||||
<ActionsContainer key={header.id} type="head">
|
||||
{actionsDef.showColumnVisibility !== false && (
|
||||
<ColumnVisibilityPopover
|
||||
table={table}
|
||||
columnVisibility={
|
||||
table.getState().columnVisibility
|
||||
}
|
||||
size={size}
|
||||
/>
|
||||
)}
|
||||
{actionsDef.showSorting !== false && (
|
||||
<SortingPopover
|
||||
table={table}
|
||||
sorting={table.getState().sorting}
|
||||
size={size}
|
||||
footerText={actionsDef.sortingFooterText}
|
||||
/>
|
||||
)}
|
||||
</ActionsContainer>
|
||||
);
|
||||
}
|
||||
|
||||
// Data / Display header
|
||||
const canSort = header.column.getCanSort();
|
||||
const sortDir = header.column.getIsSorted();
|
||||
const nextHeader = headerGroup.headers[headerIndex + 1];
|
||||
const canResize =
|
||||
header.column.getCanResize() &&
|
||||
!!nextHeader &&
|
||||
!widthConfig.fixedColumnIds.has(nextHeader.id);
|
||||
|
||||
const dataCol =
|
||||
colDef?.kind === "data"
|
||||
? (colDef as OnyxDataColumn<TData>)
|
||||
: null;
|
||||
|
||||
return (
|
||||
<TableHead
|
||||
key={header.id}
|
||||
width={columnWidths[header.id]}
|
||||
sorted={
|
||||
canSort ? toOnyxSortDirection(sortDir) : undefined
|
||||
}
|
||||
onSort={
|
||||
canSort
|
||||
? () => header.column.toggleSorting()
|
||||
: undefined
|
||||
}
|
||||
icon={dataCol?.icon}
|
||||
resizable={canResize}
|
||||
onResizeStart={
|
||||
canResize
|
||||
? createResizeHandler(header.id, nextHeader.id)
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{flexRender(
|
||||
header.column.columnDef.header,
|
||||
header.getContext()
|
||||
)}
|
||||
</TableHead>
|
||||
);
|
||||
})}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableHeader>
|
||||
|
||||
<TableBody
|
||||
dndSortable={hasDraggable ? draggableReturn : undefined}
|
||||
renderDragOverlay={
|
||||
hasDraggable
|
||||
? (activeId) => {
|
||||
const row = table
|
||||
.getRowModel()
|
||||
.rows.find(
|
||||
(r) => draggable!.getRowId(r.original) === activeId
|
||||
);
|
||||
if (!row) return null;
|
||||
return <DragOverlayRow row={row} variant={rowVariant} />;
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{table.getRowModel().rows.map((row) => {
|
||||
const rowId = hasDraggable
|
||||
? draggable!.getRowId(row.original)
|
||||
: undefined;
|
||||
|
||||
return (
|
||||
<TableRow
|
||||
key={row.id}
|
||||
variant={rowVariant}
|
||||
sortableId={rowId}
|
||||
selected={row.getIsSelected()}
|
||||
onClick={() => {
|
||||
if (onRowClick) {
|
||||
onRowClick(row.original);
|
||||
} else if (isSelectable) {
|
||||
row.toggleSelected();
|
||||
}
|
||||
}}
|
||||
>
|
||||
{row.getVisibleCells().map((cell) => {
|
||||
const cellColDef = columnKindMap.get(cell.column.id);
|
||||
|
||||
// Qualifier cell
|
||||
if (cellColDef?.kind === "qualifier") {
|
||||
const qDef = cellColDef as OnyxQualifierColumn<TData>;
|
||||
return (
|
||||
<QualifierContainer
|
||||
key={cell.id}
|
||||
type="cell"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<TableQualifier
|
||||
content={qDef.content}
|
||||
initials={qDef.getInitials?.(row.original)}
|
||||
icon={qDef.getIcon?.(row.original)}
|
||||
imageSrc={qDef.getImageSrc?.(row.original)}
|
||||
selectable={isSelectable}
|
||||
selected={isSelectable && row.getIsSelected()}
|
||||
onSelectChange={
|
||||
isSelectable
|
||||
? (checked) => {
|
||||
row.toggleSelected(checked);
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
</QualifierContainer>
|
||||
);
|
||||
}
|
||||
|
||||
// Actions cell
|
||||
if (cellColDef?.kind === "actions") {
|
||||
return (
|
||||
<ActionsContainer key={cell.id} type="cell">
|
||||
{flexRender(
|
||||
cell.column.columnDef.cell,
|
||||
cell.getContext()
|
||||
)}
|
||||
</ActionsContainer>
|
||||
);
|
||||
}
|
||||
|
||||
// Data / Display cell
|
||||
return (
|
||||
<TableCell key={cell.id}>
|
||||
{flexRender(
|
||||
cell.column.columnDef.cell,
|
||||
cell.getContext()
|
||||
)}
|
||||
</TableCell>
|
||||
);
|
||||
})}
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
|
||||
{footer && renderFooter(footer)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function renderFooter(footerConfig: DataTableFooterConfig) {
|
||||
if (footerConfig.mode === "selection") {
|
||||
return (
|
||||
<Footer
|
||||
mode="selection"
|
||||
multiSelect={footerConfig.multiSelect !== false}
|
||||
selectionState={selectionState}
|
||||
selectedCount={selectedCount}
|
||||
onClear={footerConfig.onClear ?? clearSelection}
|
||||
onView={footerConfig.onView}
|
||||
pageSize={resolvedPageSize}
|
||||
totalItems={totalItems}
|
||||
currentPage={currentPage}
|
||||
totalPages={totalPages}
|
||||
onPageChange={setPage}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Summary mode
|
||||
const rangeStart =
|
||||
totalItems === 0
|
||||
? 0
|
||||
: !isFinite(resolvedPageSize)
|
||||
? 1
|
||||
: (currentPage - 1) * resolvedPageSize + 1;
|
||||
const rangeEnd = !isFinite(resolvedPageSize)
|
||||
? totalItems
|
||||
: Math.min(currentPage * resolvedPageSize, totalItems);
|
||||
|
||||
return (
|
||||
<Footer
|
||||
mode="summary"
|
||||
rangeStart={rangeStart}
|
||||
rangeEnd={rangeEnd}
|
||||
totalItems={totalItems}
|
||||
currentPage={currentPage}
|
||||
totalPages={totalPages}
|
||||
onPageChange={setPage}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return <TableSizeProvider size={size}>{renderContent()}</TableSizeProvider>;
|
||||
}
|
||||
@@ -1,260 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Button } from "@opal/components";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Pagination from "@/refresh-components/table/Pagination";
|
||||
import { useTableSize } from "@/refresh-components/table/TableSizeContext";
|
||||
import type { TableSize } from "@/refresh-components/table/TableSizeContext";
|
||||
import { SvgEye, SvgXCircle } from "@opal/icons";
|
||||
|
||||
type SelectionState = "none" | "partial" | "all";
|
||||
|
||||
/**
|
||||
* Footer mode for tables with selectable rows.
|
||||
* Displays a selection message on the left (with optional view/clear actions)
|
||||
* and a `count`-type pagination on the right.
|
||||
*/
|
||||
interface FooterSelectionModeProps {
|
||||
mode: "selection";
|
||||
/** Whether the table supports selecting multiple rows. */
|
||||
multiSelect: boolean;
|
||||
/** Current selection state: `"none"`, `"partial"`, or `"all"`. */
|
||||
selectionState: SelectionState;
|
||||
/** Number of currently selected items. */
|
||||
selectedCount: number;
|
||||
/** If provided, renders a "View" icon button when items are selected. */
|
||||
onView?: () => void;
|
||||
/** If provided, renders a "Clear" icon button when items are selected. */
|
||||
onClear?: () => void;
|
||||
/** Number of items displayed per page. */
|
||||
pageSize: number;
|
||||
/** Total number of items across all pages. */
|
||||
totalItems: number;
|
||||
/** The 1-based current page number. */
|
||||
currentPage: number;
|
||||
/** Total number of pages. */
|
||||
totalPages: number;
|
||||
/** Called when the user navigates to a different page. */
|
||||
onPageChange: (page: number) => void;
|
||||
/** Controls overall footer sizing. `"regular"` (default) or `"small"`. */
|
||||
size?: TableSize;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Footer mode for read-only tables (no row selection).
|
||||
* Displays "Showing X~Y of Z" on the left and a `list`-type pagination
|
||||
* on the right.
|
||||
*/
|
||||
interface FooterSummaryModeProps {
|
||||
mode: "summary";
|
||||
/** First item number in the current page (e.g. `1`). */
|
||||
rangeStart: number;
|
||||
/** Last item number in the current page (e.g. `25`). */
|
||||
rangeEnd: number;
|
||||
/** Total number of items across all pages. */
|
||||
totalItems: number;
|
||||
/** The 1-based current page number. */
|
||||
currentPage: number;
|
||||
/** Total number of pages. */
|
||||
totalPages: number;
|
||||
/** Called when the user navigates to a different page. */
|
||||
onPageChange: (page: number) => void;
|
||||
/** Controls overall footer sizing. `"regular"` (default) or `"small"`. */
|
||||
size?: TableSize;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discriminated union of footer modes.
|
||||
* Use `mode: "selection"` for tables with selectable rows, or
|
||||
* `mode: "summary"` for read-only tables.
|
||||
*/
|
||||
export type FooterProps = FooterSelectionModeProps | FooterSummaryModeProps;
|
||||
|
||||
function getSelectionMessage(
|
||||
state: SelectionState,
|
||||
multi: boolean,
|
||||
count: number
|
||||
): string {
|
||||
if (state === "none") {
|
||||
return multi ? "Select items to continue" : "Select an item to continue";
|
||||
}
|
||||
if (!multi) return "Item selected";
|
||||
return `${count} item${count !== 1 ? "s" : ""} selected`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Table footer combining status information on the left with pagination on the
|
||||
* right. Use `mode: "selection"` for tables with selectable rows, or
|
||||
* `mode: "summary"` for read-only tables.
|
||||
*/
|
||||
export default function Footer(props: FooterProps) {
|
||||
const contextSize = useTableSize();
|
||||
const resolvedSize = props.size ?? contextSize;
|
||||
const isSmall = resolvedSize === "small";
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"table-footer",
|
||||
"flex w-full items-center justify-between border-t border-border-01",
|
||||
props.className
|
||||
)}
|
||||
data-size={resolvedSize}
|
||||
>
|
||||
{/* Left side */}
|
||||
<div className="flex items-center gap-1 px-1">
|
||||
{props.mode === "selection" ? (
|
||||
<SelectionLeft
|
||||
selectionState={props.selectionState}
|
||||
multiSelect={props.multiSelect}
|
||||
selectedCount={props.selectedCount}
|
||||
onView={props.onView}
|
||||
onClear={props.onClear}
|
||||
isSmall={isSmall}
|
||||
/>
|
||||
) : (
|
||||
<SummaryLeft
|
||||
rangeStart={props.rangeStart}
|
||||
rangeEnd={props.rangeEnd}
|
||||
totalItems={props.totalItems}
|
||||
isSmall={isSmall}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Right side */}
|
||||
<div className="flex items-center gap-2 px-1 py-2">
|
||||
{props.mode === "selection" ? (
|
||||
<Pagination
|
||||
type="count"
|
||||
pageSize={props.pageSize}
|
||||
totalItems={props.totalItems}
|
||||
currentPage={props.currentPage}
|
||||
totalPages={props.totalPages}
|
||||
onPageChange={props.onPageChange}
|
||||
showUnits
|
||||
size={isSmall ? "sm" : "md"}
|
||||
/>
|
||||
) : (
|
||||
<Pagination
|
||||
type="list"
|
||||
currentPage={props.currentPage}
|
||||
totalPages={props.totalPages}
|
||||
onPageChange={props.onPageChange}
|
||||
size={isSmall ? "md" : "lg"}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface SelectionLeftProps {
|
||||
selectionState: SelectionState;
|
||||
multiSelect: boolean;
|
||||
selectedCount: number;
|
||||
onView?: () => void;
|
||||
onClear?: () => void;
|
||||
isSmall: boolean;
|
||||
}
|
||||
|
||||
function SelectionLeft({
|
||||
selectionState,
|
||||
multiSelect,
|
||||
selectedCount,
|
||||
onView,
|
||||
onClear,
|
||||
isSmall,
|
||||
}: SelectionLeftProps) {
|
||||
const message = getSelectionMessage(
|
||||
selectionState,
|
||||
multiSelect,
|
||||
selectedCount
|
||||
);
|
||||
const hasSelection = selectionState !== "none";
|
||||
|
||||
return (
|
||||
<div className="flex flex-row gap-1 items-center justify-center w-fit flex-shrink-0 h-fit px-1">
|
||||
{isSmall ? (
|
||||
<Text
|
||||
secondaryAction={hasSelection}
|
||||
secondaryBody={!hasSelection}
|
||||
text03
|
||||
>
|
||||
{message}
|
||||
</Text>
|
||||
) : (
|
||||
<Text mainUiBody={hasSelection} mainUiMuted={!hasSelection} text03>
|
||||
{message}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
{hasSelection && (
|
||||
<div className="flex flex-row items-center w-fit flex-shrink-0 h-fit">
|
||||
{onView && (
|
||||
<Button
|
||||
icon={SvgEye}
|
||||
onClick={onView}
|
||||
tooltip="View"
|
||||
size={isSmall ? "sm" : "md"}
|
||||
prominence="tertiary"
|
||||
/>
|
||||
)}
|
||||
{onClear && (
|
||||
<Button
|
||||
icon={SvgXCircle}
|
||||
onClick={onClear}
|
||||
tooltip="Clear selection"
|
||||
size={isSmall ? "sm" : "md"}
|
||||
prominence="tertiary"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface SummaryLeftProps {
|
||||
rangeStart: number;
|
||||
rangeEnd: number;
|
||||
totalItems: number;
|
||||
isSmall: boolean;
|
||||
}
|
||||
|
||||
function SummaryLeft({
|
||||
rangeStart,
|
||||
rangeEnd,
|
||||
totalItems,
|
||||
isSmall,
|
||||
}: SummaryLeftProps) {
|
||||
return (
|
||||
<div className="flex flex-row gap-1 items-center w-fit h-fit px-1">
|
||||
{isSmall ? (
|
||||
<Text secondaryBody text03>
|
||||
Showing{" "}
|
||||
<Text as="span" secondaryMono text03>
|
||||
{rangeStart}~{rangeEnd}
|
||||
</Text>{" "}
|
||||
of{" "}
|
||||
<Text as="span" secondaryMono text03>
|
||||
{totalItems}
|
||||
</Text>
|
||||
</Text>
|
||||
) : (
|
||||
<Text mainUiMuted text03>
|
||||
Showing{" "}
|
||||
<Text as="span" mainUiMono text03>
|
||||
{rangeStart}~{rangeEnd}
|
||||
</Text>{" "}
|
||||
of{" "}
|
||||
<Text as="span" mainUiMono text03>
|
||||
{totalItems}
|
||||
</Text>
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,393 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@opal/components";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { SvgChevronLeft, SvgChevronRight } from "@opal/icons";
|
||||
|
||||
type PaginationSize = "lg" | "md" | "sm";
|
||||
|
||||
/**
|
||||
* Minimal page navigation showing `currentPage / totalPages` with prev/next arrows.
|
||||
* Use when you only need simple forward/backward navigation.
|
||||
*/
|
||||
interface SimplePaginationProps {
|
||||
type: "simple";
|
||||
/** The 1-based current page number. */
|
||||
currentPage: number;
|
||||
/** Total number of pages. */
|
||||
totalPages: number;
|
||||
/** Called when the user navigates to a different page. */
|
||||
onPageChange: (page: number) => void;
|
||||
/** When `true`, displays the word "pages" after the page indicator. */
|
||||
showUnits?: boolean;
|
||||
/** When `false`, hides the page indicator between the prev/next arrows. Defaults to `true`. */
|
||||
showPageIndicator?: boolean;
|
||||
/** Controls button and text sizing. Defaults to `"lg"`. */
|
||||
size?: PaginationSize;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Item-count pagination showing `currentItems of totalItems` with optional page
|
||||
* controls and a "Go to" button. Use inside table footers that need to communicate
|
||||
* how many items the user is viewing.
|
||||
*/
|
||||
interface CountPaginationProps {
|
||||
type: "count";
|
||||
/** Number of items displayed per page. Used to compute the visible range. */
|
||||
pageSize: number;
|
||||
/** Total number of items across all pages. */
|
||||
totalItems: number;
|
||||
/** The 1-based current page number. */
|
||||
currentPage: number;
|
||||
/** Total number of pages. */
|
||||
totalPages: number;
|
||||
/** Called when the user navigates to a different page. */
|
||||
onPageChange: (page: number) => void;
|
||||
/** When `false`, hides the page number between the prev/next arrows (arrows still visible). Defaults to `true`. */
|
||||
showPageIndicator?: boolean;
|
||||
/** When `true`, renders a "Go to" button. Requires `onGoTo`. */
|
||||
showGoTo?: boolean;
|
||||
/** Callback invoked when the "Go to" button is clicked. */
|
||||
onGoTo?: () => void;
|
||||
/** When `true`, displays the word "items" after the total count. */
|
||||
showUnits?: boolean;
|
||||
/** Controls button and text sizing. Defaults to `"lg"`. */
|
||||
size?: PaginationSize;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Numbered page-list pagination with clickable page buttons and ellipsis
|
||||
* truncation for large page counts. Does not support `"sm"` size.
|
||||
*/
|
||||
interface ListPaginationProps {
|
||||
type: "list";
|
||||
/** The 1-based current page number. */
|
||||
currentPage: number;
|
||||
/** Total number of pages. */
|
||||
totalPages: number;
|
||||
/** Called when the user navigates to a different page. */
|
||||
onPageChange: (page: number) => void;
|
||||
/** When `false`, hides the page buttons between the prev/next arrows. Defaults to `true`. */
|
||||
showPageIndicator?: boolean;
|
||||
/** Controls button and text sizing. Defaults to `"lg"`. Only `"lg"` and `"md"` are supported. */
|
||||
size?: Exclude<PaginationSize, "sm">;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discriminated union of all pagination variants.
|
||||
* Use the `type` prop to select between `"simple"`, `"count"`, and `"list"`.
|
||||
*/
|
||||
export type PaginationProps =
|
||||
| SimplePaginationProps
|
||||
| CountPaginationProps
|
||||
| ListPaginationProps;
|
||||
|
||||
function getPageNumbers(currentPage: number, totalPages: number) {
|
||||
const pages: (number | string)[] = [];
|
||||
const maxPagesToShow = 7;
|
||||
|
||||
if (totalPages <= maxPagesToShow) {
|
||||
for (let i = 1; i <= totalPages; i++) {
|
||||
pages.push(i);
|
||||
}
|
||||
} else {
|
||||
pages.push(1);
|
||||
|
||||
let startPage = Math.max(2, currentPage - 1);
|
||||
let endPage = Math.min(totalPages - 1, currentPage + 1);
|
||||
|
||||
if (currentPage <= 3) {
|
||||
endPage = 5;
|
||||
} else if (currentPage >= totalPages - 2) {
|
||||
startPage = totalPages - 4;
|
||||
}
|
||||
|
||||
if (startPage > 2) {
|
||||
if (startPage === 3) {
|
||||
pages.push(2);
|
||||
} else {
|
||||
pages.push("start-ellipsis");
|
||||
}
|
||||
}
|
||||
|
||||
for (let i = startPage; i <= endPage; i++) {
|
||||
pages.push(i);
|
||||
}
|
||||
|
||||
if (endPage < totalPages - 1) {
|
||||
if (endPage === totalPages - 2) {
|
||||
pages.push(totalPages - 1);
|
||||
} else {
|
||||
pages.push("end-ellipsis");
|
||||
}
|
||||
}
|
||||
|
||||
pages.push(totalPages);
|
||||
}
|
||||
|
||||
return pages;
|
||||
}
|
||||
|
||||
function sizedTextProps(isSmall: boolean, variant: "mono" | "muted") {
|
||||
if (variant === "mono") {
|
||||
return isSmall ? { secondaryMono: true } : { mainUiMono: true };
|
||||
}
|
||||
return isSmall ? { secondaryBody: true } : { mainUiMuted: true };
|
||||
}
|
||||
|
||||
interface NavButtonsProps {
|
||||
currentPage: number;
|
||||
totalPages: number;
|
||||
onPageChange: (page: number) => void;
|
||||
size: PaginationSize;
|
||||
children?: React.ReactNode;
|
||||
}
|
||||
|
||||
function NavButtons({
|
||||
currentPage,
|
||||
totalPages,
|
||||
onPageChange,
|
||||
size,
|
||||
children,
|
||||
}: NavButtonsProps) {
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
icon={SvgChevronLeft}
|
||||
onClick={() => onPageChange(currentPage - 1)}
|
||||
disabled={currentPage <= 1}
|
||||
size={size}
|
||||
prominence="tertiary"
|
||||
tooltip="Previous page"
|
||||
/>
|
||||
{children}
|
||||
<Button
|
||||
icon={SvgChevronRight}
|
||||
onClick={() => onPageChange(currentPage + 1)}
|
||||
disabled={currentPage >= totalPages}
|
||||
size={size}
|
||||
prominence="tertiary"
|
||||
tooltip="Next page"
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Table pagination component with three variants: `simple`, `count`, and `list`.
 * Pass the `type` prop to select the variant, and the component will render the
 * appropriate UI.
 */
export default function Pagination(props: PaginationProps) {
  // Clamp totalPages to at least 1 so a "page X of 0" state never renders.
  const normalized = { ...props, totalPages: Math.max(1, props.totalPages) };
  // `type` discriminates the prop union; each case gets its own inner component.
  switch (normalized.type) {
    case "simple":
      return <SimplePaginationInner {...normalized} />;
    case "count":
      return <CountPaginationInner {...normalized} />;
    case "list":
      return <ListPaginationInner {...normalized} />;
  }
}
|
||||
|
||||
// "simple" variant: prev/next chevrons around an optional "current / total"
// indicator, with an optional "pages" unit label.
function SimplePaginationInner({
  currentPage,
  totalPages,
  onPageChange,
  showUnits,
  showPageIndicator = true,
  size = "lg",
  className,
}: SimplePaginationProps) {
  const isSmall = size === "sm";

  return (
    <div className={cn("flex items-center gap-1", className)}>
      <NavButtons
        currentPage={currentPage}
        totalPages={totalPages}
        onPageChange={onPageChange}
        size={size}
      >
        {showPageIndicator && (
          <>
            {/* "current / total" — the slash uses the muted style inline. */}
            <Text {...sizedTextProps(isSmall, "mono")} text03>
              {currentPage}
              <Text as="span" {...sizedTextProps(isSmall, "muted")} text03>
                /
              </Text>
              {totalPages}
            </Text>
            {showUnits && (
              <Text {...sizedTextProps(isSmall, "muted")} text03>
                pages
              </Text>
            )}
          </>
        )}
      </NavButtons>
    </div>
  );
}
|
||||
|
||||
// "count" variant: shows the visible item range ("X~Y of Z"), optional units,
// prev/next navigation with an optional current-page indicator, and an
// optional "Go to" button.
function CountPaginationInner({
  pageSize,
  totalItems,
  currentPage,
  totalPages,
  onPageChange,
  showPageIndicator = true,
  showGoTo,
  onGoTo,
  showUnits,
  size = "lg",
  className,
}: CountPaginationProps) {
  const isSmall = size === "sm";
  // 1-indexed start of the visible range; 0 when there are no items at all.
  const rangeStart = totalItems === 0 ? 0 : (currentPage - 1) * pageSize + 1;
  // End of the visible range, clamped so the last page doesn't overshoot.
  const rangeEnd = Math.min(currentPage * pageSize, totalItems);
  const currentItems = `${rangeStart}~${rangeEnd}`;

  return (
    <div className={cn("flex items-center gap-1", className)}>
      <Text {...sizedTextProps(isSmall, "mono")} text03>
        {currentItems}
      </Text>
      <Text {...sizedTextProps(isSmall, "muted")} text03>
        of
      </Text>
      <Text {...sizedTextProps(isSmall, "mono")} text03>
        {totalItems}
      </Text>
      {showUnits && (
        <Text {...sizedTextProps(isSmall, "muted")} text03>
          items
        </Text>
      )}

      <NavButtons
        currentPage={currentPage}
        totalPages={totalPages}
        onPageChange={onPageChange}
        size={size}
      >
        {showPageIndicator && (
          <Text {...sizedTextProps(isSmall, "mono")} text03>
            {currentPage}
          </Text>
        )}
      </NavButtons>

      {/* "Go to" requires both the flag and a handler. */}
      {showGoTo && onGoTo && (
        <Button onClick={onGoTo} size={size} prominence="tertiary">
          Go to
        </Button>
      )}
    </div>
  );
}
|
||||
|
||||
// Props for a single page-number label rendered inside a Button's icon slot.
interface PageNumberIconProps {
  // Class name forwarded from the Button icon render prop.
  className?: string;
  // The page number to display.
  pageNum: number;
  // Whether this page is the currently selected one (affects text styling).
  isActive: boolean;
  // Whether the surrounding pagination is the "lg" size variant.
  isLarge: boolean;
}
|
||||
|
||||
// Renders a page number styled for the current size/active state.
// Active pages get emphasized text (text04); inactive ones are muted (text02).
function PageNumberIcon({
  className: iconClassName,
  pageNum,
  isActive,
  isLarge,
}: PageNumberIconProps) {
  return (
    <div className={cn(iconClassName, "flex flex-col justify-center")}>
      {isLarge ? (
        <Text
          mainUiBody={isActive}
          mainUiMuted={!isActive}
          text04={isActive}
          text02={!isActive}
        >
          {pageNum}
        </Text>
      ) : (
        <Text
          secondaryAction={isActive}
          secondaryBody={!isActive}
          text04={isActive}
          text02={!isActive}
        >
          {pageNum}
        </Text>
      )}
    </div>
  );
}
|
||||
|
||||
// "list" variant: prev/next chevrons around a row of clickable page-number
// buttons with ellipsis placeholders for collapsed ranges.
function ListPaginationInner({
  currentPage,
  totalPages,
  onPageChange,
  showPageIndicator = true,
  size = "lg",
  className,
}: ListPaginationProps) {
  // Mixed array of page numbers and "start-ellipsis"/"end-ellipsis" markers.
  const pageNumbers = getPageNumbers(currentPage, totalPages);
  const isLarge = size === "lg";

  return (
    <div className={cn("flex items-center gap-1", className)}>
      <NavButtons
        currentPage={currentPage}
        totalPages={totalPages}
        onPageChange={onPageChange}
        size={size}
      >
        {showPageIndicator && (
          <div className="flex items-center">
            {pageNumbers.map((page) => {
              // String entries are ellipsis markers — render as plain text.
              if (typeof page === "string") {
                return (
                  <Text
                    key={page}
                    mainUiMuted={isLarge}
                    secondaryBody={!isLarge}
                    text03
                  >
                    ...
                  </Text>
                );
              }

              const pageNum = page as number;
              const isActive = pageNum === currentPage;

              return (
                <Button
                  key={pageNum}
                  onClick={() => onPageChange(pageNum)}
                  size={size}
                  prominence="tertiary"
                  transient={isActive}
                  // The number itself is rendered via the icon slot so it
                  // inherits the Button's sizing/spacing.
                  icon={({ className: iconClassName }) => (
                    <PageNumberIcon
                      className={iconClassName}
                      pageNum={pageNum}
                      isActive={isActive}
                      isLarge={isLarge}
                    />
                  )}
                />
              );
            })}
          </div>
        )}
      </NavButtons>
    </div>
  );
}
|
||||
@@ -1,317 +0,0 @@
|
||||
# DataTable
|
||||
|
||||
Config-driven table built on [TanStack Table](https://tanstack.com/table). Handles column sizing (weight-based proportional distribution), drag-and-drop row reordering, pagination, row selection, column visibility, and sorting out of the box.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```tsx
|
||||
import DataTable from "@/refresh-components/table/DataTable";
|
||||
import { createTableColumns } from "@/refresh-components/table/columns";
|
||||
|
||||
interface Person {
|
||||
name: string;
|
||||
email: string;
|
||||
role: string;
|
||||
}
|
||||
|
||||
// Define columns at module scope (stable reference, no re-renders)
|
||||
const tc = createTableColumns<Person>();
|
||||
const columns = [
|
||||
tc.qualifier(),
|
||||
tc.column("name", { header: "Name", weight: 30, minWidth: 120 }),
|
||||
tc.column("email", { header: "Email", weight: 40, minWidth: 150 }),
|
||||
tc.column("role", { header: "Role", weight: 30, minWidth: 80 }),
|
||||
tc.actions(),
|
||||
];
|
||||
|
||||
function PeopleTable({ data }: { data: Person[] }) {
|
||||
return (
|
||||
<DataTable
|
||||
data={data}
|
||||
columns={columns}
|
||||
pageSize={10}
|
||||
footer={{ mode: "selection" }}
|
||||
/>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
## Column Builder API
|
||||
|
||||
`createTableColumns<TData>()` returns a typed builder with four methods. Each returns an `OnyxColumnDef<TData>` that you pass to the `columns` prop.
|
||||
|
||||
### `tc.qualifier(config?)`
|
||||
|
||||
Leading column for avatars, icons, images, or checkboxes.
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|---|---|---|---|
|
||||
| `content` | `"simple" \| "icon" \| "image" \| "avatar-icon" \| "avatar-user"` | `"simple"` | Body row content type |
|
||||
| `headerContentType` | same as `content` | `"simple"` | Header row content type |
|
||||
| `getInitials` | `(row: TData) => string` | - | Extract initials (for `"avatar-user"`) |
|
||||
| `getIcon` | `(row: TData) => IconFunctionComponent` | - | Extract icon (for `"icon"` / `"avatar-icon"`) |
|
||||
| `getImageSrc` | `(row: TData) => string` | - | Extract image src (for `"image"`) |
|
||||
| `selectable` | `boolean` | `true` | Show selection checkboxes |
|
||||
| `header` | `boolean` | `true` | Render qualifier content in the header |
|
||||
|
||||
Width is fixed: 56px at `"regular"` size, 40px at `"small"`.
|
||||
|
||||
```ts
|
||||
tc.qualifier({
|
||||
content: "avatar-user",
|
||||
getInitials: (row) => row.initials,
|
||||
})
|
||||
```
|
||||
|
||||
### `tc.column(accessor, config)`
|
||||
|
||||
Data column with sorting, resizing, and hiding. The `accessor` is a type-safe deep key into `TData`.
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|---|---|---|---|
|
||||
| `header` | `string` | **required** | Column header label |
|
||||
| `cell` | `(value: TValue, row: TData) => ReactNode` | renders value as string | Custom cell renderer |
|
||||
| `enableSorting` | `boolean` | `true` | Allow sorting |
|
||||
| `enableResizing` | `boolean` | `true` | Allow column resize |
|
||||
| `enableHiding` | `boolean` | `true` | Allow hiding via actions popover |
|
||||
| `icon` | `(sorted: SortDirection) => IconFunctionComponent` | - | Override the sort indicator icon |
|
||||
| `weight` | `number` | `20` | Proportional width weight |
|
||||
| `minWidth` | `number` | `50` | Minimum width in pixels |
|
||||
|
||||
```ts
|
||||
tc.column("email", {
|
||||
header: "Email",
|
||||
weight: 28,
|
||||
minWidth: 150,
|
||||
cell: (value) => <Content sizePreset="main-ui" variant="body" title={value} prominence="muted" />,
|
||||
})
|
||||
```
|
||||
|
||||
### `tc.displayColumn(config)`
|
||||
|
||||
Non-accessor column for custom content (e.g. computed values, action buttons per row).
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|---|---|---|---|
|
||||
| `id` | `string` | **required** | Unique column ID |
|
||||
| `header` | `string` | - | Optional header label |
|
||||
| `cell` | `(row: TData) => ReactNode` | **required** | Cell renderer |
|
||||
| `width` | `ColumnWidth` | **required** | `{ weight, minWidth? }` or `{ fixed }` |
|
||||
| `enableHiding` | `boolean` | `true` | Allow hiding |
|
||||
|
||||
```ts
|
||||
tc.displayColumn({
|
||||
id: "fullName",
|
||||
header: "Full Name",
|
||||
cell: (row) => `${row.firstName} ${row.lastName}`,
|
||||
width: { weight: 25, minWidth: 100 },
|
||||
})
|
||||
```
|
||||
|
||||
### `tc.actions(config?)`
|
||||
|
||||
Fixed-width column rendered at the trailing edge. Houses column visibility and sorting popovers in the header.
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|---|---|---|---|
|
||||
| `showColumnVisibility` | `boolean` | `true` | Show the column visibility popover |
|
||||
| `showSorting` | `boolean` | `true` | Show the sorting popover |
|
||||
| `sortingFooterText` | `string` | - | Footer text inside the sorting popover |
|
||||
|
||||
Width is fixed: 88px at `"regular"`, 20px at `"small"`.
|
||||
|
||||
```ts
|
||||
tc.actions({
|
||||
sortingFooterText: "Everyone will see agents in this order.",
|
||||
})
|
||||
```
|
||||
|
||||
## DataTable Props
|
||||
|
||||
`DataTableProps<TData>`:
|
||||
|
||||
| Prop | Type | Default | Description |
|
||||
|---|---|---|---|
|
||||
| `data` | `TData[]` | **required** | Row data |
|
||||
| `columns` | `OnyxColumnDef<TData>[]` | **required** | Columns from `createTableColumns()` |
|
||||
| `pageSize` | `number` | `10` (with footer) or `data.length` (without) | Rows per page. `Infinity` disables pagination |
|
||||
| `initialSorting` | `SortingState` | `[]` | TanStack sorting state |
|
||||
| `initialColumnVisibility` | `VisibilityState` | `{}` | Map of column ID to `false` to hide initially |
|
||||
| `draggable` | `DataTableDraggableConfig<TData>` | - | Enable drag-and-drop (see below) |
|
||||
| `footer` | `DataTableFooterConfig` | - | Footer mode (see below) |
|
||||
| `size` | `"regular" \| "small"` | `"regular"` | Table density variant |
|
||||
| `onRowClick` | `(row: TData) => void` | toggles selection | Called on row click, replaces default selection toggle |
|
||||
| `height` | `number \| string` | - | Max height for scrollable body (header stays pinned). `300` or `"50vh"` |
|
||||
| `headerBackground` | `string` | - | CSS color for the sticky header (prevents content showing through) |
|
||||
|
||||
## Footer Config
|
||||
|
||||
The `footer` prop accepts a discriminated union on `mode`.
|
||||
|
||||
### Selection mode
|
||||
|
||||
For tables with selectable rows. Shows a selection message + count pagination.
|
||||
|
||||
```ts
|
||||
footer={{
|
||||
mode: "selection",
|
||||
multiSelect: true, // default true
|
||||
onView: () => { ... }, // optional "View" button
|
||||
onClear: () => { ... }, // optional "Clear" button (falls back to default clearSelection)
|
||||
}}
|
||||
```
|
||||
|
||||
### Summary mode
|
||||
|
||||
For read-only tables. Shows "Showing X~Y of Z" + list pagination.
|
||||
|
||||
```ts
|
||||
footer={{ mode: "summary" }}
|
||||
```
|
||||
|
||||
## Draggable Config
|
||||
|
||||
Enable drag-and-drop row reordering. DnD is automatically disabled when column sorting is active.
|
||||
|
||||
```ts
|
||||
<DataTable
|
||||
data={items}
|
||||
columns={columns}
|
||||
draggable={{
|
||||
getRowId: (row) => row.id,
|
||||
onReorder: (ids, changedOrders) => {
|
||||
// ids: new ordered array of all row IDs
|
||||
// changedOrders: { [id]: newIndex } for rows that moved
|
||||
setItems(ids.map((id) => items.find((r) => r.id === id)!));
|
||||
},
|
||||
}}
|
||||
/>
|
||||
```
|
||||
|
||||
| Option | Type | Description |
|
||||
|---|---|---|
|
||||
| `getRowId` | `(row: TData) => string` | Extract a unique string ID from each row |
|
||||
| `onReorder` | `(ids: string[], changedOrders: Record<string, number>) => void \| Promise<void>` | Called after a successful reorder |
|
||||
|
||||
## Sizing
|
||||
|
||||
The `size` prop (`"regular"` or `"small"`) affects:
|
||||
|
||||
- Qualifier column width (56px vs 40px)
|
||||
- Actions column width (88px vs 20px)
|
||||
- Footer text styles and pagination size
|
||||
- All child components via `TableSizeContext`
|
||||
|
||||
Column widths can be responsive to size using a function:
|
||||
|
||||
```ts
|
||||
// In types.ts, width accepts:
|
||||
width: ColumnWidth | ((size: TableSize) => ColumnWidth)
|
||||
|
||||
// Example (this is what qualifier/actions use internally):
|
||||
width: (size) => size === "small" ? { fixed: 40 } : { fixed: 56 }
|
||||
```
|
||||
|
||||
### Width system
|
||||
|
||||
Data columns use **weight-based proportional distribution**. A column with `weight: 40` gets twice the space of one with `weight: 20`. When the container is narrower than the sum of `minWidth` values, columns clamp to their minimums.
|
||||
|
||||
Fixed columns (`{ fixed: N }`) take exactly N pixels and don't participate in proportional distribution.
|
||||
|
||||
Resizing uses **splitter semantics**: dragging a column border grows that column and shrinks its neighbor by the same amount, keeping total width constant.
|
||||
|
||||
## Advanced Examples
|
||||
|
||||
### Scrollable table with pinned header
|
||||
|
||||
```tsx
|
||||
<DataTable
|
||||
data={allRows}
|
||||
columns={columns}
|
||||
height={300}
|
||||
headerBackground="var(--background-tint-00)"
|
||||
/>
|
||||
```
|
||||
|
||||
### Hidden columns on load
|
||||
|
||||
```tsx
|
||||
<DataTable
|
||||
data={data}
|
||||
columns={columns}
|
||||
initialColumnVisibility={{ department: false, joinDate: false }}
|
||||
footer={{ mode: "selection" }}
|
||||
/>
|
||||
```
|
||||
|
||||
### Icon-based data column
|
||||
|
||||
```tsx
|
||||
const STATUS_ICONS = {
|
||||
active: SvgCheckCircle,
|
||||
pending: SvgClock,
|
||||
inactive: SvgAlertCircle,
|
||||
} as const;
|
||||
|
||||
tc.column("status", {
|
||||
header: "Status",
|
||||
weight: 14,
|
||||
minWidth: 80,
|
||||
cell: (value) => (
|
||||
<Content
|
||||
sizePreset="main-ui"
|
||||
variant="body"
|
||||
icon={STATUS_ICONS[value]}
|
||||
title={value.charAt(0).toUpperCase() + value.slice(1)}
|
||||
/>
|
||||
),
|
||||
})
|
||||
```
|
||||
|
||||
### Non-selectable qualifier with icons
|
||||
|
||||
```ts
|
||||
tc.qualifier({
|
||||
content: "icon",
|
||||
getIcon: (row) => row.icon,
|
||||
selectable: false,
|
||||
header: false,
|
||||
})
|
||||
```
|
||||
|
||||
### Small variant in a bordered container
|
||||
|
||||
```tsx
|
||||
<div className="border border-border-01 rounded-lg overflow-hidden">
|
||||
<DataTable
|
||||
data={data}
|
||||
columns={columns}
|
||||
size="small"
|
||||
pageSize={10}
|
||||
footer={{ mode: "selection" }}
|
||||
/>
|
||||
</div>
|
||||
```
|
||||
|
||||
### Custom row click handler
|
||||
|
||||
```tsx
|
||||
<DataTable
|
||||
data={data}
|
||||
columns={columns}
|
||||
onRowClick={(row) => router.push(`/users/${row.id}`)}
|
||||
/>
|
||||
```
|
||||
|
||||
## Source Files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `DataTable.tsx` | Main component |
|
||||
| `columns.ts` | `createTableColumns` builder |
|
||||
| `types.ts` | All TypeScript interfaces |
|
||||
| `hooks/useDataTable.ts` | TanStack table wrapper hook |
|
||||
| `hooks/useColumnWidths.ts` | Weight-based width system |
|
||||
| `hooks/useDraggableRows.ts` | DnD hook (`@dnd-kit`) |
|
||||
| `Footer.tsx` | Selection / Summary footer modes |
|
||||
| `TableSizeContext.tsx` | Size context provider |
|
||||
@@ -1,181 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import {
|
||||
type Table,
|
||||
type ColumnDef,
|
||||
type RowData,
|
||||
type SortingState,
|
||||
} from "@tanstack/react-table";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgArrowUpDown, SvgSortOrder, SvgCheck } from "@opal/icons";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import Divider from "@/refresh-components/Divider";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Popover UI
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Props for the column-sorting popover attached to a TanStack table.
interface SortingPopoverProps<TData extends RowData = RowData> {
  // TanStack table instance whose sorting state is controlled here.
  table: Table<TData>;
  // Current sorting state; only the first entry is used (single-column sort).
  sorting: SortingState;
  // Visual density of the trigger button.
  size?: "regular" | "small";
  // Optional footer text shown at the bottom of the popover menu.
  footerText?: string;
  // Label for the ascending option (defaults to "Ascending").
  ascendingLabel?: string;
  // Label for the descending option (defaults to "Descending").
  descendingLabel?: string;
}
|
||||
|
||||
// Popover listing sortable columns plus a "Manual Ordering" (no-sort) option.
// When a column is sorted, an extra section toggles ascending/descending.
// Only single-column sorting is supported: sorting[0] is the active sort.
function SortingPopover<TData extends RowData>({
  table,
  sorting,
  size = "regular",
  footerText,
  ascendingLabel = "Ascending",
  descendingLabel = "Descending",
}: SortingPopoverProps<TData>) {
  const [open, setOpen] = useState(false);
  // Only columns that opted into sorting appear in the menu.
  const sortableColumns = table
    .getAllLeafColumns()
    .filter((col) => col.getCanSort());

  // null means manual ordering (no active sort).
  const currentSort = sorting[0] ?? null;

  return (
    <Popover open={open} onOpenChange={setOpen}>
      <Popover.Trigger asChild>
        <Button
          // Icon switches to indicate whether a sort is active.
          icon={currentSort === null ? SvgArrowUpDown : SvgSortOrder}
          transient={open}
          size={size === "small" ? "sm" : "md"}
          prominence="internal"
          tooltip="Sort"
        />
      </Popover.Trigger>

      <Popover.Content width="lg" align="end" side="bottom">
        <Popover.Menu
          footer={
            footerText ? (
              <div className="px-2 py-1">
                <Text secondaryBody text03>
                  {footerText}
                </Text>
              </div>
            ) : undefined
          }
        >
          <Divider showTitle text="Sort by" />

          {/* "Manual Ordering" clears sorting entirely. */}
          <LineItem
            selected={currentSort === null}
            emphasized
            rightChildren={
              currentSort === null ? <SvgCheck size={16} /> : undefined
            }
            onClick={() => {
              table.resetSorting();
            }}
          >
            Manual Ordering
          </LineItem>

          {sortableColumns.map((column) => {
            const isSorted = currentSort?.id === column.id;
            // Fall back to the column id when the header isn't a plain string.
            const label =
              typeof column.columnDef.header === "string"
                ? column.columnDef.header
                : column.id;

            return (
              <LineItem
                key={column.id}
                selected={isSorted}
                emphasized
                rightChildren={isSorted ? <SvgCheck size={16} /> : undefined}
                onClick={() => {
                  // Clicking the active column toggles sorting off;
                  // otherwise sort this column ascending (desc = false).
                  if (isSorted) {
                    table.resetSorting();
                    return;
                  }
                  column.toggleSorting(false);
                }}
              >
                {label}
              </LineItem>
            );
          })}

          {/* Direction section only appears when some column is sorted. */}
          {currentSort !== null && (
            <>
              <Divider showTitle text="Sorting Order" />

              <LineItem
                selected={!currentSort.desc}
                emphasized
                rightChildren={
                  !currentSort.desc ? <SvgCheck size={16} /> : undefined
                }
                onClick={() => {
                  table.setSorting([{ id: currentSort.id, desc: false }]);
                }}
              >
                {ascendingLabel}
              </LineItem>

              <LineItem
                selected={currentSort.desc}
                emphasized
                rightChildren={
                  currentSort.desc ? <SvgCheck size={16} /> : undefined
                }
                onClick={() => {
                  table.setSorting([{ id: currentSort.id, desc: true }]);
                }}
              >
                {descendingLabel}
              </LineItem>
            </>
          )}
        </Popover.Menu>
      </Popover.Content>
    </Popover>
  );
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Column definition factory
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Options forwarded from the column factory to the SortingPopover it renders.
interface CreateSortingColumnOptions {
  // Visual density of the popover trigger button.
  size?: "regular" | "small";
  // Optional footer text inside the sorting popover menu.
  footerText?: string;
  // Label override for the ascending option.
  ascendingLabel?: string;
  // Label override for the descending option.
  descendingLabel?: string;
}
|
||||
|
||||
// Builds a fixed-width (44px) utility column whose header hosts the sorting
// popover and whose body cells render nothing. The column itself cannot be
// hidden, sorted, or resized.
function createSortingColumn<TData>(
  options?: CreateSortingColumnOptions
): ColumnDef<TData, unknown> {
  return {
    id: "__sorting",
    size: 44,
    enableHiding: false,
    enableSorting: false,
    enableResizing: false,
    header: ({ table }) => (
      <SortingPopover
        table={table}
        sorting={table.getState().sorting}
        size={options?.size}
        footerText={options?.footerText}
        ascendingLabel={options?.ascendingLabel}
        descendingLabel={options?.descendingLabel}
      />
    ),
    // Body cells are intentionally empty — this column is header-only UI.
    cell: () => null,
  };
}
|
||||
|
||||
export { SortingPopover, createSortingColumn };
|
||||
@@ -830,8 +830,12 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
ref={chatInputBarRef}
|
||||
deepResearchEnabled={deepResearchEnabled}
|
||||
toggleDeepResearch={toggleDeepResearch}
|
||||
toggleDocumentSidebar={toggleDocumentSidebar}
|
||||
filterManager={filterManager}
|
||||
llmManager={llmManager}
|
||||
removeDocs={() => setSelectedDocuments([])}
|
||||
retrievalEnabled={retrievalEnabled}
|
||||
selectedDocuments={selectedDocuments}
|
||||
initialMessage={
|
||||
searchParams?.get(SEARCH_PARAM_NAMES.USER_PROMPT) ||
|
||||
""
|
||||
|
||||
@@ -173,21 +173,19 @@ export function FileCard({
|
||||
removeFile && doneUploading ? () => removeFile(file.id) : undefined
|
||||
}
|
||||
>
|
||||
<div className="min-w-0 max-w-[12rem]">
|
||||
<div className="max-w-[12rem]">
|
||||
<Interactive.Container border heightVariant="fit">
|
||||
<div className="[&_.opal-content-md-body]:min-w-0 [&_.opal-content-md-title]:break-all">
|
||||
<AttachmentItemLayout
|
||||
icon={isProcessing ? SimpleLoader : SvgFileText}
|
||||
title={file.name}
|
||||
description={
|
||||
isProcessing
|
||||
? file.status === UserFileStatus.UPLOADING
|
||||
? "Uploading..."
|
||||
: "Processing..."
|
||||
: typeLabel
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<AttachmentItemLayout
|
||||
icon={isProcessing ? SimpleLoader : SvgFileText}
|
||||
title={file.name}
|
||||
description={
|
||||
isProcessing
|
||||
? file.status === UserFileStatus.UPLOADING
|
||||
? "Uploading..."
|
||||
: "Processing..."
|
||||
: typeLabel
|
||||
}
|
||||
/>
|
||||
<Spacer horizontal rem={0.5} />
|
||||
</Interactive.Container>
|
||||
</div>
|
||||
|
||||
@@ -16,18 +16,16 @@ import { FilterManager, LlmManager, useFederatedConnectors } from "@/lib/hooks";
|
||||
import usePromptShortcuts from "@/hooks/usePromptShortcuts";
|
||||
import useFilter from "@/hooks/useFilter";
|
||||
import useCCPairs from "@/hooks/useCCPairs";
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import { OnyxDocument, MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import { ChatState } from "@/app/app/interfaces";
|
||||
import { useForcedTools } from "@/lib/hooks/useForcedTools";
|
||||
import { useAppMode } from "@/providers/AppModeProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { cn, isImageFile } from "@/lib/utils";
|
||||
import { getFormattedDateRangeString } from "@/lib/dateUtils";
|
||||
import { truncateString, cn, isImageFile } from "@/lib/utils";
|
||||
import { Disabled } from "@/refresh-components/Disabled";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import {
|
||||
SettingsContext,
|
||||
useVectorDbEnabled,
|
||||
} from "@/providers/SettingsProvider";
|
||||
import { SettingsContext } from "@/providers/SettingsProvider";
|
||||
import { useProjectsContext } from "@/providers/ProjectsContext";
|
||||
import { FileCard } from "@/sections/cards/FileCard";
|
||||
import {
|
||||
@@ -42,6 +40,9 @@ import {
|
||||
} from "@/app/app/services/actionUtils";
|
||||
import {
|
||||
SvgArrowUp,
|
||||
SvgCalendar,
|
||||
SvgFiles,
|
||||
SvgFileText,
|
||||
SvgGlobe,
|
||||
SvgHourglass,
|
||||
SvgPlus,
|
||||
@@ -50,22 +51,64 @@ import {
|
||||
SvgStop,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import { Button, OpenButton } from "@opal/components";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
|
||||
const LINE_HEIGHT = 24;
|
||||
const MIN_INPUT_HEIGHT = 44;
|
||||
const MAX_INPUT_HEIGHT = 200;
|
||||
|
||||
// Props for a compact removable chip (e.g. an attached source/filter tag).
export interface SourceChipProps {
  // Optional leading icon node.
  icon?: React.ReactNode;
  // Chip label text.
  title: string;
  // When provided, renders an X button that invokes this on click.
  onRemove?: () => void;
  // When provided, the whole chip becomes clickable.
  onClick?: () => void;
  // Truncate the title to 20 characters (default true).
  truncateTitle?: boolean;
}
|
||||
|
||||
// Small pill-shaped chip showing an icon + (optionally truncated) title,
// with an optional remove (X) affordance.
export function SourceChip({
  icon,
  title,
  onRemove,
  onClick,
  truncateTitle = true,
}: SourceChipProps) {
  return (
    <div
      onClick={onClick ? onClick : undefined}
      className={cn(
        "flex-none flex items-center px-1 bg-background-neutral-01 text-xs text-text-04 border border-border-01 rounded-08 box-border gap-x-1 h-6",
        onClick && "cursor-pointer"
      )}
    >
      {icon}
      {truncateTitle ? truncateString(title, 20) : title}
      {onRemove && (
        <SvgX
          size={12}
          className="text-text-01 ml-auto cursor-pointer"
          onClick={(e: React.MouseEvent<SVGSVGElement>) => {
            // Don't let the remove click also trigger the chip's onClick.
            e.stopPropagation();
            onRemove();
          }}
        />
      )}
    </div>
  );
}
|
||||
|
||||
// Imperative handle exposed by AppInputBar via ref.
export interface AppInputBarHandle {
  // Clears the input's internal message state.
  reset: () => void;
  // Focuses the underlying textarea.
  focus: () => void;
}
|
||||
|
||||
export interface AppInputBarProps {
|
||||
removeDocs: () => void;
|
||||
selectedDocuments: OnyxDocument[];
|
||||
initialMessage?: string;
|
||||
stopGenerating: () => void;
|
||||
onSubmit: (message: string) => void;
|
||||
@@ -77,8 +120,10 @@ export interface AppInputBarProps {
|
||||
// agents
|
||||
selectedAgent: MinimalPersonaSnapshot | undefined;
|
||||
|
||||
toggleDocumentSidebar: () => void;
|
||||
handleFileUpload: (files: File[]) => void;
|
||||
filterManager: FilterManager;
|
||||
retrievalEnabled: boolean;
|
||||
deepResearchEnabled: boolean;
|
||||
setPresentingDocument?: (document: MinimalOnyxDocument) => void;
|
||||
toggleDeepResearch: () => void;
|
||||
@@ -92,13 +137,18 @@ export interface AppInputBarProps {
|
||||
|
||||
const AppInputBar = React.memo(
|
||||
({
|
||||
retrievalEnabled,
|
||||
removeDocs,
|
||||
toggleDocumentSidebar,
|
||||
filterManager,
|
||||
selectedDocuments,
|
||||
initialMessage = "",
|
||||
stopGenerating,
|
||||
onSubmit,
|
||||
chatState,
|
||||
currentSessionFileTokenCount,
|
||||
availableContextTokens,
|
||||
// agents
|
||||
selectedAgent,
|
||||
|
||||
handleFileUpload,
|
||||
@@ -115,9 +165,6 @@ const AppInputBar = React.memo(
|
||||
// Internal message state - kept local to avoid parent re-renders on every keystroke
|
||||
const [message, setMessage] = useState(initialMessage);
|
||||
const textAreaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const textAreaWrapperRef = useRef<HTMLDivElement>(null);
|
||||
const filesWrapperRef = useRef<HTMLDivElement>(null);
|
||||
const filesContentRef = useRef<HTMLDivElement>(null);
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const { user } = useUser();
|
||||
const { isClassifying, classification } = useQueryController();
|
||||
@@ -131,16 +178,6 @@ const AppInputBar = React.memo(
|
||||
textAreaRef.current?.focus();
|
||||
},
|
||||
}));
|
||||
|
||||
// Sync non-empty prop changes to internal state (e.g. NRFPage reads URL params
|
||||
// after mount). Intentionally skips empty strings — clearing is handled via the
|
||||
// imperative ref.reset() method, not by passing initialMessage="".
|
||||
useEffect(() => {
|
||||
if (initialMessage) {
|
||||
setMessage(initialMessage);
|
||||
}
|
||||
}, [initialMessage]);
|
||||
|
||||
const { appMode } = useAppMode();
|
||||
const appFocus = useAppFocus();
|
||||
const isSearchMode =
|
||||
@@ -190,39 +227,46 @@ const AppInputBar = React.memo(
|
||||
|
||||
const combinedSettings = useContext(SettingsContext);
|
||||
|
||||
// TODO(@raunakab): Replace this useEffect with CSS `field-sizing: content` once
|
||||
// Firefox ships it unflagged (currently behind `layout.css.field-sizing.enabled`).
|
||||
// Auto-resize textarea based on content (chat mode only).
|
||||
// Reset to min-height first so scrollHeight reflects actual content size,
|
||||
// then clamp between min and max. This handles both growing and shrinking.
|
||||
useEffect(() => {
|
||||
const wrapper = textAreaWrapperRef.current;
|
||||
const textarea = textAreaRef.current;
|
||||
if (!wrapper || !textarea) return;
|
||||
// Track previous message to detect when lines might decrease
|
||||
const prevMessageRef = useRef("");
|
||||
|
||||
wrapper.style.height = `${MIN_INPUT_HEIGHT}px`;
|
||||
wrapper.style.height = `${Math.min(
|
||||
Math.max(textarea.scrollHeight, MIN_INPUT_HEIGHT),
|
||||
MAX_INPUT_HEIGHT
|
||||
)}px`;
|
||||
// Auto-resize textarea based on content
|
||||
useEffect(() => {
|
||||
if (isSearchMode) return;
|
||||
const textarea = textAreaRef.current;
|
||||
if (textarea) {
|
||||
const prevLineCount = (prevMessageRef.current.match(/\n/g) || [])
|
||||
.length;
|
||||
const currLineCount = (message.match(/\n/g) || []).length;
|
||||
const lineRemoved = currLineCount < prevLineCount;
|
||||
prevMessageRef.current = message;
|
||||
|
||||
if (message.length === 0) {
|
||||
textarea.style.height = `${MIN_INPUT_HEIGHT}px`;
|
||||
return;
|
||||
} else if (lineRemoved) {
|
||||
const linesRemoved = prevLineCount - currLineCount;
|
||||
textarea.style.height = `${Math.max(
|
||||
MIN_INPUT_HEIGHT,
|
||||
Math.min(
|
||||
textarea.scrollHeight - LINE_HEIGHT * linesRemoved,
|
||||
MAX_INPUT_HEIGHT
|
||||
)
|
||||
)}px`;
|
||||
} else {
|
||||
textarea.style.height = `${Math.min(
|
||||
textarea.scrollHeight,
|
||||
MAX_INPUT_HEIGHT
|
||||
)}px`;
|
||||
}
|
||||
}
|
||||
}, [message, isSearchMode]);
|
||||
|
||||
// Animate attached files wrapper to its content height so CSS transitions
|
||||
// can interpolate between concrete pixel values (0px ↔ Npx).
|
||||
const showFiles = !isSearchMode && currentMessageFiles.length > 0;
|
||||
useEffect(() => {
|
||||
const wrapper = filesWrapperRef.current;
|
||||
const content = filesContentRef.current;
|
||||
if (!wrapper || !content) return;
|
||||
|
||||
if (showFiles) {
|
||||
// Measure the inner content's actual height, then add padding (p-1 = 8px total)
|
||||
const PADDING = 8;
|
||||
wrapper.style.height = `${content.offsetHeight + PADDING}px`;
|
||||
} else {
|
||||
wrapper.style.height = "0px";
|
||||
if (initialMessage) {
|
||||
setMessage(initialMessage);
|
||||
}
|
||||
}, [showFiles, currentMessageFiles]);
|
||||
}, [initialMessage]);
|
||||
|
||||
function handlePaste(event: React.ClipboardEvent) {
|
||||
const items = event.clipboardData?.items;
|
||||
@@ -250,7 +294,8 @@ const AppInputBar = React.memo(
|
||||
);
|
||||
|
||||
const { activePromptShortcuts } = usePromptShortcuts();
|
||||
const vectorDbEnabled = useVectorDbEnabled();
|
||||
const vectorDbEnabled =
|
||||
combinedSettings?.settings.vector_db_enabled !== false;
|
||||
const { ccPairs, isLoading: ccPairsLoading } = useCCPairs(vectorDbEnabled);
|
||||
const { data: federatedConnectorsData, isLoading: federatedLoading } =
|
||||
useFederatedConnectors();
|
||||
@@ -367,9 +412,7 @@ const AppInputBar = React.memo(
|
||||
combinedSettings?.settings?.deep_research_enabled,
|
||||
]);
|
||||
|
||||
function handleKeyDownForPromptShortcuts(
|
||||
e: React.KeyboardEvent<HTMLTextAreaElement>
|
||||
) {
|
||||
function handleKeyDown(e: React.KeyboardEvent<HTMLTextAreaElement>) {
|
||||
if (!user?.preferences?.shortcut_enabled || !showPrompts) return;
|
||||
|
||||
if (e.key === "Enter") {
|
||||
@@ -404,171 +447,6 @@ const AppInputBar = React.memo(
|
||||
}
|
||||
}
|
||||
|
||||
const chatControls = (
|
||||
<div
|
||||
{...(isSearchMode ? { inert: true } : {})}
|
||||
className={cn(
|
||||
"flex justify-between items-center w-full",
|
||||
isSearchMode
|
||||
? "opacity-0 p-0 h-0 overflow-hidden pointer-events-none"
|
||||
: "opacity-100 p-1 h-[2.75rem] pointer-events-auto",
|
||||
"transition-all duration-150"
|
||||
)}
|
||||
>
|
||||
{/* Bottom left controls */}
|
||||
<div className="flex flex-row items-center">
|
||||
{/* (+) button - always visible */}
|
||||
<FilePickerPopover
|
||||
onFileClick={handleFileClick}
|
||||
onPickRecent={(file: ProjectFile) => {
|
||||
// Check if file with same ID already exists
|
||||
if (
|
||||
!currentMessageFiles.some(
|
||||
(existingFile) => existingFile.file_id === file.file_id
|
||||
)
|
||||
) {
|
||||
setCurrentMessageFiles((prev) => [...prev, file]);
|
||||
}
|
||||
}}
|
||||
onUnpickRecent={(file: ProjectFile) => {
|
||||
setCurrentMessageFiles((prev) =>
|
||||
prev.filter(
|
||||
(existingFile) => existingFile.file_id !== file.file_id
|
||||
)
|
||||
);
|
||||
}}
|
||||
handleUploadChange={handleUploadChange}
|
||||
trigger={(open) => (
|
||||
<Button
|
||||
icon={SvgPlusCircle}
|
||||
tooltip="Attach Files"
|
||||
transient={open}
|
||||
disabled={disabled}
|
||||
prominence="tertiary"
|
||||
/>
|
||||
)}
|
||||
selectedFileIds={currentMessageFiles.map((f) => f.id)}
|
||||
/>
|
||||
|
||||
{/* Controls that load in when data is ready */}
|
||||
<div
|
||||
data-testid="actions-container"
|
||||
className={cn(
|
||||
"flex flex-row items-center",
|
||||
controlsLoading && "invisible"
|
||||
)}
|
||||
>
|
||||
{selectedAgent && selectedAgent.tools.length > 0 && (
|
||||
<ActionsPopover
|
||||
selectedAgent={selectedAgent}
|
||||
filterManager={filterManager}
|
||||
availableSources={memoizedAvailableSources}
|
||||
disabled={disabled}
|
||||
/>
|
||||
)}
|
||||
{onToggleTabReading ? (
|
||||
<Button
|
||||
icon={SvgGlobe}
|
||||
onClick={onToggleTabReading}
|
||||
variant="select"
|
||||
selected={tabReadingEnabled}
|
||||
foldable={!tabReadingEnabled}
|
||||
disabled={disabled}
|
||||
>
|
||||
{tabReadingEnabled
|
||||
? currentTabUrl
|
||||
? (() => {
|
||||
try {
|
||||
return new URL(currentTabUrl).hostname;
|
||||
} catch {
|
||||
return currentTabUrl;
|
||||
}
|
||||
})()
|
||||
: "Reading tab..."
|
||||
: "Read this tab"}
|
||||
</Button>
|
||||
) : (
|
||||
showDeepResearch && (
|
||||
<Button
|
||||
icon={SvgHourglass}
|
||||
onClick={toggleDeepResearch}
|
||||
variant="select"
|
||||
selected={deepResearchEnabled}
|
||||
foldable={!deepResearchEnabled}
|
||||
disabled={disabled}
|
||||
>
|
||||
Deep Research
|
||||
</Button>
|
||||
)
|
||||
)}
|
||||
|
||||
{selectedAgent &&
|
||||
forcedToolIds.length > 0 &&
|
||||
forcedToolIds.map((toolId) => {
|
||||
const tool = selectedAgent.tools.find(
|
||||
(tool) => tool.id === toolId
|
||||
);
|
||||
if (!tool) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<Button
|
||||
key={toolId}
|
||||
icon={getIconForAction(tool)}
|
||||
onClick={() => {
|
||||
setForcedToolIds(
|
||||
forcedToolIds.filter((id) => id !== toolId)
|
||||
);
|
||||
}}
|
||||
variant="select"
|
||||
selected
|
||||
disabled={disabled}
|
||||
>
|
||||
{tool.display_name}
|
||||
</Button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Bottom right controls */}
|
||||
<div className="flex flex-row items-center gap-1">
|
||||
<div
|
||||
data-testid="AppInputBar/llm-popover-trigger"
|
||||
className={cn(controlsLoading && "invisible")}
|
||||
>
|
||||
<LLMPopover
|
||||
llmManager={llmManager}
|
||||
requiresImageInput={hasImageFiles}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
id="onyx-chat-input-send-button"
|
||||
icon={
|
||||
isClassifying
|
||||
? SimpleLoader
|
||||
: chatState === "input"
|
||||
? SvgArrowUp
|
||||
: SvgStop
|
||||
}
|
||||
disabled={
|
||||
(chatState === "input" && !message) ||
|
||||
hasUploadingFiles ||
|
||||
isClassifying
|
||||
}
|
||||
onClick={() => {
|
||||
if (chatState == "streaming") {
|
||||
stopGenerating();
|
||||
} else if (message) {
|
||||
onSubmit(message);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<Disabled disabled={disabled} allowClick>
|
||||
<div
|
||||
@@ -589,17 +467,8 @@ const AppInputBar = React.memo(
|
||||
)}
|
||||
>
|
||||
{/* Attached Files */}
|
||||
<div
|
||||
ref={filesWrapperRef}
|
||||
{...(!showFiles ? { inert: true } : {})}
|
||||
className={cn(
|
||||
"transition-all duration-150",
|
||||
showFiles
|
||||
? "opacity-100 p-1"
|
||||
: "opacity-0 p-0 overflow-hidden pointer-events-none"
|
||||
)}
|
||||
>
|
||||
<div ref={filesContentRef} className="flex flex-wrap gap-1">
|
||||
{currentMessageFiles.length > 0 && (
|
||||
<div className="p-2 rounded-t-16 flex flex-wrap gap-1">
|
||||
{currentMessageFiles.map((file) => (
|
||||
<FileCard
|
||||
key={file.id}
|
||||
@@ -611,61 +480,76 @@ const AppInputBar = React.memo(
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex flex-row items-center w-full">
|
||||
{/* Input area */}
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-row items-center w-full",
|
||||
isSearchMode && "p-1"
|
||||
)}
|
||||
>
|
||||
<Popover
|
||||
open={user?.preferences?.shortcut_enabled && showPrompts}
|
||||
onOpenChange={setShowPrompts}
|
||||
>
|
||||
<Popover.Anchor asChild>
|
||||
<div
|
||||
ref={textAreaWrapperRef}
|
||||
className="px-3 py-2 flex-1 flex h-[2.75rem]"
|
||||
>
|
||||
<textarea
|
||||
id="onyx-chat-input-textarea"
|
||||
role="textarea"
|
||||
ref={textAreaRef}
|
||||
onPaste={handlePaste}
|
||||
onKeyDownCapture={handleKeyDownForPromptShortcuts}
|
||||
onChange={handleInputChange}
|
||||
className={cn(
|
||||
"p-[2px] w-full h-full outline-none bg-transparent resize-none placeholder:text-text-03 whitespace-pre-wrap break-words",
|
||||
"overflow-y-auto"
|
||||
)}
|
||||
autoFocus
|
||||
rows={1}
|
||||
style={{ scrollbarWidth: "thin" }}
|
||||
aria-multiline={true}
|
||||
placeholder={
|
||||
isSearchMode
|
||||
? "Search connected sources"
|
||||
: "How can I help you today?"
|
||||
}
|
||||
value={message}
|
||||
onKeyDown={(event) => {
|
||||
<textarea
|
||||
onPaste={handlePaste}
|
||||
onKeyDownCapture={handleKeyDown}
|
||||
onChange={handleInputChange}
|
||||
ref={textAreaRef}
|
||||
id="onyx-chat-input-textarea"
|
||||
className={cn(
|
||||
"w-full",
|
||||
"outline-none",
|
||||
"bg-transparent",
|
||||
"resize-none",
|
||||
"placeholder:text-text-03",
|
||||
"whitespace-pre-wrap",
|
||||
"break-word",
|
||||
"overscroll-contain",
|
||||
"px-3",
|
||||
isSearchMode
|
||||
? "h-[40px] py-2.5 overflow-hidden"
|
||||
: [
|
||||
"h-[44px]", // Fixed initial height to prevent flash - useEffect will adjust as needed
|
||||
"overflow-y-auto",
|
||||
"pb-2",
|
||||
"pt-3",
|
||||
]
|
||||
)}
|
||||
autoFocus
|
||||
style={{ scrollbarWidth: "thin" }}
|
||||
role="textarea"
|
||||
aria-multiline
|
||||
placeholder={
|
||||
isSearchMode
|
||||
? "Search connected sources"
|
||||
: "How can I help you today"
|
||||
}
|
||||
value={message}
|
||||
onKeyDown={(event) => {
|
||||
if (
|
||||
event.key === "Enter" &&
|
||||
!showPrompts &&
|
||||
!event.shiftKey &&
|
||||
!(event.nativeEvent as any).isComposing
|
||||
) {
|
||||
event.preventDefault();
|
||||
if (
|
||||
event.key === "Enter" &&
|
||||
!showPrompts &&
|
||||
!event.shiftKey &&
|
||||
!(event.nativeEvent as any).isComposing
|
||||
message &&
|
||||
!disabled &&
|
||||
!isClassifying &&
|
||||
!hasUploadingFiles
|
||||
) {
|
||||
event.preventDefault();
|
||||
if (
|
||||
message &&
|
||||
!disabled &&
|
||||
!isClassifying &&
|
||||
!hasUploadingFiles
|
||||
) {
|
||||
onSubmit(message);
|
||||
}
|
||||
onSubmit(message);
|
||||
}
|
||||
}}
|
||||
suppressContentEditableWarning={true}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
}
|
||||
}}
|
||||
suppressContentEditableWarning={true}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</Popover.Anchor>
|
||||
|
||||
<Popover.Content
|
||||
@@ -732,7 +616,214 @@ const AppInputBar = React.memo(
|
||||
)}
|
||||
</div>
|
||||
|
||||
{chatControls}
|
||||
{/* Source chips */}
|
||||
{(selectedDocuments.length > 0 ||
|
||||
filterManager.timeRange ||
|
||||
filterManager.selectedDocumentSets.length > 0) && (
|
||||
<div className="flex gap-x-.5 px-2">
|
||||
<div className="flex gap-x-1 px-2 overflow-visible overflow-x-scroll items-end miniscroll">
|
||||
{filterManager.timeRange && (
|
||||
<SourceChip
|
||||
truncateTitle={false}
|
||||
key="time-range"
|
||||
icon={<SvgCalendar size={12} />}
|
||||
title={`${getFormattedDateRangeString(
|
||||
filterManager.timeRange.from,
|
||||
filterManager.timeRange.to
|
||||
)}`}
|
||||
onRemove={() => {
|
||||
filterManager.setTimeRange(null);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{filterManager.selectedDocumentSets.length > 0 &&
|
||||
filterManager.selectedDocumentSets.map((docSet, index) => (
|
||||
<SourceChip
|
||||
key={`doc-set-${index}`}
|
||||
icon={<SvgFiles size={16} />}
|
||||
title={docSet}
|
||||
onRemove={() => {
|
||||
filterManager.setSelectedDocumentSets(
|
||||
filterManager.selectedDocumentSets.filter(
|
||||
(ds) => ds !== docSet
|
||||
)
|
||||
);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
{selectedDocuments.length > 0 && (
|
||||
<SourceChip
|
||||
key="selected-documents"
|
||||
onClick={() => {
|
||||
toggleDocumentSidebar();
|
||||
}}
|
||||
icon={<SvgFileText size={16} />}
|
||||
title={`${selectedDocuments.length} selected`}
|
||||
onRemove={removeDocs}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!isSearchMode && (
|
||||
<div className="flex justify-between items-center w-full p-1 min-h-[40px]">
|
||||
{/* Bottom left controls */}
|
||||
<div className="flex flex-row items-center">
|
||||
{/* (+) button - always visible */}
|
||||
<FilePickerPopover
|
||||
onFileClick={handleFileClick}
|
||||
onPickRecent={(file: ProjectFile) => {
|
||||
// Check if file with same ID already exists
|
||||
if (
|
||||
!currentMessageFiles.some(
|
||||
(existingFile) => existingFile.file_id === file.file_id
|
||||
)
|
||||
) {
|
||||
setCurrentMessageFiles((prev) => [...prev, file]);
|
||||
}
|
||||
}}
|
||||
onUnpickRecent={(file: ProjectFile) => {
|
||||
setCurrentMessageFiles((prev) =>
|
||||
prev.filter(
|
||||
(existingFile) => existingFile.file_id !== file.file_id
|
||||
)
|
||||
);
|
||||
}}
|
||||
handleUploadChange={handleUploadChange}
|
||||
trigger={(open) => (
|
||||
<Button
|
||||
icon={SvgPlusCircle}
|
||||
tooltip="Attach Files"
|
||||
transient={open}
|
||||
disabled={disabled}
|
||||
prominence="tertiary"
|
||||
/>
|
||||
)}
|
||||
selectedFileIds={currentMessageFiles.map((f) => f.id)}
|
||||
/>
|
||||
|
||||
{/* Controls that load in when data is ready */}
|
||||
<div
|
||||
data-testid="actions-container"
|
||||
className={cn(
|
||||
"flex flex-row items-center",
|
||||
controlsLoading && "invisible"
|
||||
)}
|
||||
>
|
||||
{selectedAgent && selectedAgent.tools.length > 0 && (
|
||||
<ActionsPopover
|
||||
selectedAgent={selectedAgent}
|
||||
filterManager={filterManager}
|
||||
availableSources={memoizedAvailableSources}
|
||||
disabled={disabled}
|
||||
/>
|
||||
)}
|
||||
{onToggleTabReading ? (
|
||||
<Button
|
||||
icon={SvgGlobe}
|
||||
onClick={onToggleTabReading}
|
||||
variant="select"
|
||||
selected={tabReadingEnabled}
|
||||
foldable={!tabReadingEnabled}
|
||||
disabled={disabled}
|
||||
>
|
||||
{tabReadingEnabled
|
||||
? currentTabUrl
|
||||
? (() => {
|
||||
try {
|
||||
return new URL(currentTabUrl).hostname;
|
||||
} catch {
|
||||
return currentTabUrl;
|
||||
}
|
||||
})()
|
||||
: "Reading tab..."
|
||||
: "Read this tab"}
|
||||
</Button>
|
||||
) : (
|
||||
showDeepResearch && (
|
||||
<Button
|
||||
icon={SvgHourglass}
|
||||
onClick={toggleDeepResearch}
|
||||
variant="select"
|
||||
selected={deepResearchEnabled}
|
||||
foldable={!deepResearchEnabled}
|
||||
disabled={disabled}
|
||||
>
|
||||
Deep Research
|
||||
</Button>
|
||||
)
|
||||
)}
|
||||
|
||||
{selectedAgent &&
|
||||
forcedToolIds.length > 0 &&
|
||||
forcedToolIds.map((toolId) => {
|
||||
const tool = selectedAgent.tools.find(
|
||||
(tool) => tool.id === toolId
|
||||
);
|
||||
if (!tool) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<Button
|
||||
key={toolId}
|
||||
icon={getIconForAction(tool)}
|
||||
onClick={() => {
|
||||
setForcedToolIds(
|
||||
forcedToolIds.filter((id) => id !== toolId)
|
||||
);
|
||||
}}
|
||||
variant="select"
|
||||
selected
|
||||
disabled={disabled}
|
||||
>
|
||||
{tool.display_name}
|
||||
</Button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Bottom right controls */}
|
||||
<div className="flex flex-row items-center gap-1">
|
||||
{/* LLM popover - loads when ready */}
|
||||
<div
|
||||
data-testid="AppInputBar/llm-popover-trigger"
|
||||
className={cn(controlsLoading && "invisible")}
|
||||
>
|
||||
<LLMPopover
|
||||
llmManager={llmManager}
|
||||
requiresImageInput={hasImageFiles}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Submit button */}
|
||||
<Button
|
||||
id="onyx-chat-input-send-button"
|
||||
icon={
|
||||
isClassifying
|
||||
? SimpleLoader
|
||||
: chatState === "input"
|
||||
? SvgArrowUp
|
||||
: SvgStop
|
||||
}
|
||||
disabled={
|
||||
(chatState === "input" && !message) ||
|
||||
hasUploadingFiles ||
|
||||
isClassifying
|
||||
}
|
||||
onClick={() => {
|
||||
if (chatState == "streaming") {
|
||||
stopGenerating();
|
||||
} else if (message) {
|
||||
onSubmit(message);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Disabled>
|
||||
);
|
||||
|
||||
@@ -116,6 +116,8 @@ function ViewerOpenApiToolCard({ tool }: { tool: ToolSnapshot }) {
|
||||
);
|
||||
}
|
||||
|
||||
const EMPTY_DOCS: [] = [];
|
||||
|
||||
/**
|
||||
* Floating ChatInputBar below the AgentViewerModal.
|
||||
* On submit, navigates to the agent's chat with the message pre-filled.
|
||||
@@ -135,10 +137,14 @@ function AgentChatInput({ agent, onSubmit }: AgentChatInputProps) {
|
||||
chatState="input"
|
||||
filterManager={filterManager}
|
||||
selectedAgent={agent}
|
||||
selectedDocuments={EMPTY_DOCS}
|
||||
removeDocs={() => {}}
|
||||
stopGenerating={() => {}}
|
||||
handleFileUpload={() => {}}
|
||||
toggleDocumentSidebar={() => {}}
|
||||
currentSessionFileTokenCount={0}
|
||||
availableContextTokens={Infinity}
|
||||
retrievalEnabled={false}
|
||||
deepResearchEnabled={false}
|
||||
toggleDeepResearch={() => {}}
|
||||
disabled={false}
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
import { test, expect } from "@tests/e2e/fixtures/eeFeatures";
|
||||
|
||||
test.describe("EE Feature Redirect", () => {
|
||||
test("redirects to /chat with toast when EE features are not licensed", async ({
|
||||
page,
|
||||
eeEnabled,
|
||||
}) => {
|
||||
test.skip(eeEnabled, "Redirect only happens without Enterprise license");
|
||||
|
||||
await page.goto("/admin/theme");
|
||||
|
||||
await expect(page).toHaveURL(/\/chat/, { timeout: 10_000 });
|
||||
|
||||
const toastContainer = page.getByTestId("toast-container");
|
||||
await expect(toastContainer).toBeVisible({ timeout: 5_000 });
|
||||
await expect(
|
||||
toastContainer.getByText(/only accessible with a paid license/i)
|
||||
).toBeVisible();
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,4 @@
|
||||
import { test, expect } from "@tests/e2e/fixtures/eeFeatures";
|
||||
import { test, expect } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
|
||||
test.describe("Appearance Theme Settings @exclusive", () => {
|
||||
@@ -12,21 +12,24 @@ test.describe("Appearance Theme Settings @exclusive", () => {
|
||||
consentPrompt: "I agree to the terms",
|
||||
};
|
||||
|
||||
test.beforeEach(async ({ page, eeEnabled }) => {
|
||||
test.skip(
|
||||
!eeEnabled,
|
||||
"Enterprise license not active — skipping theme tests"
|
||||
);
|
||||
|
||||
// Fresh session — the eeEnabled fixture already logged in to check the
|
||||
// setting, so clear cookies and re-login for a clean test state.
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin");
|
||||
|
||||
// Navigate first so localStorage is accessible (API-based login
|
||||
// doesn't navigate, leaving the page on about:blank).
|
||||
await page.goto("/admin/theme");
|
||||
await expect(
|
||||
page.locator('[data-label="application-name-input"]')
|
||||
).toBeVisible({ timeout: 10_000 });
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
// Skip the entire test when Enterprise features are not licensed.
|
||||
// The /admin/theme page is gated behind ee_features_enabled and
|
||||
// renders a license-required message instead of the settings form.
|
||||
const eeLocked = page.getByText(
|
||||
"This functionality requires an active Enterprise license."
|
||||
);
|
||||
if (await eeLocked.isVisible({ timeout: 1000 }).catch(() => false)) {
|
||||
test.skip(true, "Enterprise license not active — skipping theme tests");
|
||||
}
|
||||
|
||||
// Clear localStorage to ensure consent modal shows
|
||||
await page.evaluate(() => {
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
/**
|
||||
* Playwright fixture that detects EE (Enterprise Edition) license state.
|
||||
*
|
||||
* Usage:
|
||||
* ```ts
|
||||
* import { test, expect } from "@tests/e2e/fixtures/eeFeatures";
|
||||
*
|
||||
* test("my EE-gated test", async ({ page, eeEnabled }) => {
|
||||
* test.skip(!eeEnabled, "Requires active Enterprise license");
|
||||
* // ... rest of test
|
||||
* });
|
||||
* ```
|
||||
*
|
||||
* The fixture:
|
||||
* - Authenticates as admin
|
||||
* - Fetches /api/settings to check ee_features_enabled
|
||||
* - Provides a boolean to the test BEFORE any navigation happens
|
||||
*
|
||||
* This lets tests call test.skip() synchronously at the top, which is the
|
||||
* correct Playwright pattern — never navigate then decide to skip.
|
||||
*/
|
||||
|
||||
import { test as base, expect } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
|
||||
export const test = base.extend<{
|
||||
/** Whether EE features are enabled (valid enterprise license). */
|
||||
eeEnabled: boolean;
|
||||
}>({
|
||||
eeEnabled: async ({ page }, use) => {
|
||||
await loginAs(page, "admin");
|
||||
const res = await page.request.get("/api/settings");
|
||||
if (!res.ok()) {
|
||||
// Fail open — if we can't determine, assume EE is not enabled
|
||||
await use(false);
|
||||
return;
|
||||
}
|
||||
const settings = await res.json();
|
||||
await use(settings.ee_features_enabled === true);
|
||||
},
|
||||
});
|
||||
|
||||
export { expect };
|
||||
Reference in New Issue
Block a user