Compare commits

..

12 Commits

Author SHA1 Message Date
Nik
2cf2c8bbf8 refactor: replace HTTPException with OnyxError in server/documents 2026-03-05 08:52:58 -08:00
Jamison Lahman
7eabfa125c fix(fe): properly wrap copy and edit buttons on mobile (#9073) 2026-03-05 04:36:11 +00:00
SubashMohan
ee18114739 feat(table): add DataTable config-driven wrapper component (#9020)
Co-authored-by: Nik <nikolas.garza5@gmail.com>
2026-03-05 04:21:38 +00:00
Nikolas Garza
f7630f5648 fix: EE route gating for upgrading CE users (#9026) 2026-03-05 03:44:16 +00:00
Jamison Lahman
e0d91b9ea7 chore(fe): rm unreachable code (#9069) 2026-03-05 03:26:50 +00:00
Raunak Bhagat
2c0a4a60a5 refactor: consolidate AppInputBar search/chat rendering with animated transitions (#9054) 2026-03-05 03:16:36 +00:00
Justin Tahara
3a7d4dad56 fix(ui): Improve text truncation and overflow handling in FileCard layout (#9061) 2026-03-05 03:11:53 +00:00
acaprau
c5c236d098 chore(opensearch): Fix and consolidate the dev script used to start OpenSearch locally (#9036) 2026-03-05 01:54:02 +00:00
Danelegend
b18baff4d0 fix: Correct file_id for docs (#9058) 2026-03-05 01:43:58 +00:00
SubashMohan
eb3e15c195 feat(table): add ColumnVisibilityPopover, Footer, Pagination, and SortingPopover components (#9019)
Co-authored-by: Nik <nikolas.garza5@gmail.com>
2026-03-05 01:43:37 +00:00
acaprau
47d9a9e1ac feat(document index): Re-enable search settings swap (#9005) 2026-03-05 01:41:03 +00:00
Evan Lohn
aca466b35d fix: doc to hierarchynode connection in pruning (#9046) 2026-03-05 01:30:36 +00:00
63 changed files with 3115 additions and 1858 deletions

15
.vscode/launch.json vendored
View File

@@ -485,21 +485,6 @@
"group": "3"
}
},
{
"name": "Clear and Restart OpenSearch Container",
// Generic debugger type, required arg but has no bearing on bash.
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Eval CLI",
"type": "debugpy",

View File

@@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.billing.api import router as billing_router
@@ -153,12 +152,9 @@ def get_application() -> FastAPI:
# License management
include_router_with_global_prefix_prepended(application, license_router)
# Unified billing API - available when license system is enabled
# Works for both self-hosted and cloud deployments
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
# primary billing API and /tenants/* billing endpoints can be removed
if LICENSE_ENFORCEMENT_ENABLED:
include_router_with_global_prefix_prepended(application, billing_router)
# Unified billing API - always registered in EE.
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
include_router_with_global_prefix_prepended(application, billing_router)
if MULTI_TENANT:
# Tenant management

View File

@@ -39,9 +39,13 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
class SlimConnectorExtractionResult(BaseModel):
"""Result of extracting document IDs and hierarchy nodes from a connector."""
"""Result of extracting document IDs and hierarchy nodes from a connector.
doc_ids: set[str]
raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None).
Use raw_id_to_parent.keys() wherever the old set of IDs was needed.
"""
raw_id_to_parent: dict[str, str | None]
hierarchy_nodes: list[HierarchyNode]
@@ -93,30 +97,37 @@ def _get_failure_id(failure: ConnectorFailure) -> str | None:
return None
class BatchResult(BaseModel):
raw_id_to_parent: dict[str, str | None]
hierarchy_nodes: list[HierarchyNode]
def _extract_from_batch(
doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
) -> tuple[set[str], list[HierarchyNode]]:
"""Separate a batch into document IDs and hierarchy nodes.
) -> BatchResult:
"""Separate a batch into document IDs (with parent mapping) and hierarchy nodes.
ConnectorFailure items have their failed document/entity IDs added to the
ID set so that failed-to-retrieve documents are not accidentally pruned.
ID dict so that failed-to-retrieve documents are not accidentally pruned.
"""
ids: set[str] = set()
ids: dict[str, str | None] = {}
hierarchy_nodes: list[HierarchyNode] = []
for item in doc_list:
if isinstance(item, HierarchyNode):
hierarchy_nodes.append(item)
ids.add(item.raw_node_id)
if item.raw_node_id not in ids:
ids[item.raw_node_id] = None
elif isinstance(item, ConnectorFailure):
failed_id = _get_failure_id(item)
if failed_id:
ids.add(failed_id)
ids[failed_id] = None
logger.warning(
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
)
else:
ids.add(item.id)
return ids, hierarchy_nodes
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
ids[item.id] = parent_raw
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
def extract_ids_from_runnable_connector(
@@ -132,7 +143,7 @@ def extract_ids_from_runnable_connector(
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_connector_doc_ids: set[str] = set()
all_raw_id_to_parent: dict[str, str | None] = {}
all_hierarchy_nodes: list[HierarchyNode] = []
# Sequence (covariant) lets all the specific list[...] iterator types unify here
@@ -177,15 +188,20 @@ def extract_ids_from_runnable_connector(
"extract_ids_from_runnable_connector: Stop signal detected"
)
batch_ids, batch_nodes = _extract_from_batch(doc_list)
all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
batch_result = _extract_from_batch(doc_list)
batch_ids = batch_result.raw_id_to_parent
batch_nodes = batch_result.hierarchy_nodes
doc_batch_processing_func(batch_ids)
for k, v in batch_ids.items():
if v is not None or k not in all_raw_id_to_parent:
all_raw_id_to_parent[k] = v
all_hierarchy_nodes.extend(batch_nodes)
if callback:
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
return SlimConnectorExtractionResult(
doc_ids=all_connector_doc_ids,
raw_id_to_parent=all_raw_id_to_parent,
hierarchy_nodes=all_hierarchy_nodes,
)

View File

@@ -29,6 +29,7 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -47,6 +48,8 @@ from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
@@ -57,6 +60,8 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
@@ -113,6 +118,38 @@ class PruneCallback(IndexingCallbackBase):
super().progress(tag, amount)
def _resolve_and_update_document_parents(
db_session: Session,
redis_client: Redis,
source: DocumentSource,
raw_id_to_parent: dict[str, str | None],
) -> None:
"""Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id for
each document and bulk-update the DB. Mirrors the resolution logic in
run_docfetching.py."""
source_node_id = get_source_node_id_from_cache(redis_client, db_session, source)
resolved: dict[str, int | None] = {}
for doc_id, raw_parent_id in raw_id_to_parent.items():
if raw_parent_id is None:
continue
node_id, found = get_node_id_from_raw_id(redis_client, source, raw_parent_id)
resolved[doc_id] = node_id if found else source_node_id
if not resolved:
return
update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
task_logger.info(
f"Pruning: resolved and updated parent hierarchy for "
f"{len(resolved)} documents (source={source.value})"
)
"""Jobs / utils for kicking off pruning tasks."""
@@ -535,22 +572,22 @@ def connector_pruning_generator_task(
extraction_result = extract_ids_from_runnable_connector(
runnable_connector, callback
)
all_connector_doc_ids = extraction_result.doc_ids
all_connector_doc_ids = extraction_result.raw_id_to_parent
# Process hierarchy nodes (same as docfetching):
# upsert to Postgres and cache in Redis
source = cc_pair.connector.source
redis_client = get_redis_client(tenant_id=tenant_id)
if extraction_result.hierarchy_nodes:
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
redis_client = get_redis_client(tenant_id=tenant_id)
ensure_source_node_exists(
redis_client, db_session, cc_pair.connector.source
)
ensure_source_node_exists(redis_client, db_session, source)
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=extraction_result.hierarchy_nodes,
source=cc_pair.connector.source,
source=source,
commit=True,
is_connector_public=is_connector_public,
)
@@ -561,7 +598,7 @@ def connector_pruning_generator_task(
]
cache_hierarchy_nodes_batch(
redis_client=redis_client,
source=cc_pair.connector.source,
source=source,
entries=cache_entries,
)
@@ -570,6 +607,26 @@ def connector_pruning_generator_task(
f"hierarchy nodes for cc_pair={cc_pair_id}"
)
ensure_source_node_exists(redis_client, db_session, source)
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
# and bulk-update documents, mirroring the docfetching resolution
_resolve_and_update_document_parents(
db_session=db_session,
redis_client=redis_client,
source=source,
raw_id_to_parent=all_connector_doc_ids,
)
# Link hierarchy nodes to documents for sources where pages can be
# both hierarchy nodes AND documents (e.g. Notion, Confluence)
all_doc_id_list = list(all_connector_doc_ids.keys())
link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=all_doc_id_list,
source=source,
commit=True,
)
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
@@ -581,7 +638,9 @@ def connector_pruning_generator_task(
}
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids.keys()
)
task_logger.info(
"Pruning set collected: "

View File

@@ -15,7 +15,6 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.emitter import Emitter
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import LlmStepResult
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
from onyx.configs.constants import MessageType
@@ -1019,7 +1018,6 @@ def run_llm_step_pkt_generator(
)
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
arg_scan_offsets: dict[int, int] = {}
reasoning_start = False
answer_start = False
accumulated_reasoning = ""
@@ -1226,14 +1224,7 @@ def run_llm_step_pkt_generator(
yield from _close_reasoning_if_active()
for tool_call_delta in delta.tool_calls:
# maybe_emit depends and update being called first and attaching the delta
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
yield from maybe_emit_argument_delta(
tool_calls_in_progress=id_to_tool_call_map,
tool_call_delta=tool_call_delta,
placement=_current_placement(),
scan_offsets=arg_scan_offsets,
)
# Flush any tail text buffered while checking for split "<function_calls" markers.
filtered_content_tail = xml_tool_call_content_filter.flush()

View File

@@ -1,236 +0,0 @@
import json
from collections.abc import Generator
from collections.abc import Mapping
from typing import Any
from typing import NamedTuple
from typing import Type
from onyx.llm.model_response import ChatCompletionDeltaToolCall
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS
from onyx.tools.interface import Tool
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_tool_class(
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
tool_call_delta: ChatCompletionDeltaToolCall,
) -> Type[Tool] | None:
"""Look up the Tool subclass for a streaming tool call delta."""
tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name")
if not tool_name:
return None
return TOOL_NAME_TO_CLASS.get(tool_name)
class _Token(NamedTuple):
"""A parsed JSON string with position info."""
value: str # raw content between the quotes
start: int # index of first char inside the quotes
end: int # index of closing quote, or len(text) if incomplete
complete: bool # whether the closing quote was found
def _parse_json_string(text: str, pos: int) -> _Token:
"""Parse a JSON string starting at the opening quote at ``pos``."""
i = pos + 1
while i < len(text):
if text[i] == "\\":
i += 2
elif text[i] == '"':
return _Token(text[pos + 1 : i], pos + 1, i, complete=True)
else:
i += 1
return _Token(text[pos + 1 :], pos + 1, len(text), complete=False)
def _skip_json_value(text: str, pos: int) -> int:
"""Skip past a non-string JSON value (number, bool, null, array, object).
Tracks ``[]`` / ``{}`` nesting depth and skips over embedded strings so
that internal commas and braces don't terminate the scan early. Stops
at the next top-level ``,`` or ``}`` (not consumed).
"""
depth = 0
while pos < len(text):
ch = text[pos]
if ch == '"':
tok = _parse_json_string(text, pos)
pos = tok.end + 1 if tok.complete else tok.end
continue
if ch in ("{", "["):
depth += 1
elif ch in ("}", "]"):
if depth == 0:
break
depth -= 1
elif ch == "," and depth == 0:
break
pos += 1
return pos
def _skip(text: str, pos: int, chars: str = " \t\n\r,") -> int:
"""Advance ``pos`` past any characters in ``chars``."""
while pos < len(text) and text[pos] in chars:
pos += 1
return pos
def _decode_partial_json_string(raw: str) -> str:
"""Decode JSON escapes (``\\n`` → newline) from a possibly incomplete value.
Progressively trims up to 6 trailing chars to handle partial escape
sequences (the longest JSON escape is ``\\uXXXX``).
"""
for trim in range(min(7, len(raw) + 1)):
candidate = raw[: len(raw) - trim] if trim else raw
try:
result = json.loads('"' + candidate + '"')
if trim > 0 and not result and raw:
logger.warning(
"Dropped %d chars from partial JSON string value (trim=%d)",
len(raw),
trim,
)
return result
except (json.JSONDecodeError, ValueError):
continue
logger.warning(
"Failed to decode partial JSON string value; dropping %d chars", len(raw)
)
return ""
def _extract_delta_args(
pre: str, delta: str, scan_offset: int = 0
) -> tuple[dict[str, str], int]:
"""Extract decoded argument values contributed by ``delta``.
Walks ``pre + delta`` as a partial JSON object (``{"k": "v", ...}``),
and for each string value returns only the decoded content that falls
within the ``delta`` portion. Escape sequences that straddle the
boundary are handled correctly.
Returns ``(argument_deltas, next_scan_offset)`` where
``next_scan_offset`` should be passed to the next call to skip
completed key-value pairs, reducing cost from O(accumulated) to
O(delta) per call.
"""
full = pre + delta
delta_start = len(pre)
result: dict[str, str] = {}
if scan_offset > 0:
pos = scan_offset
else:
pos = full.find("{")
if pos == -1:
return result, 0
pos += 1
resume = pos
while pos < len(full):
pos = _skip(full, pos)
if pos >= len(full) or full[pos] == "}":
break
resume = pos # remember start of this key-value pair
# Key
if full[pos] != '"':
break
key = _parse_json_string(full, pos)
if not key.complete:
break
pos = key.end + 1
# Colon
pos = _skip(full, pos, " \t\n\r")
if pos >= len(full) or full[pos] != ":":
break
pos += 1
# Value
pos = _skip(full, pos, " \t\n\r")
if pos >= len(full):
break
if full[pos] != '"':
# Skip non-string values (number, boolean, null, array, object).
# They are available in the final tool-call kickoff packet;
# emitting them here as strings would be ambiguous for consumers
# (e.g. the number 30 vs the string "30").
pos = _skip_json_value(full, pos)
continue
val = _parse_json_string(full, pos)
# Only include the portion of this value that overlaps with delta
lo = max(val.start, delta_start)
hi = val.end
if lo < hi:
# Decode from value start through both boundaries so escape
# sequences straddling the delta edge are handled correctly.
decoded_before = _decode_partial_json_string(full[val.start : lo])
decoded_through = _decode_partial_json_string(full[val.start : hi])
new_content = decoded_through[len(decoded_before) :]
if new_content:
result[key.value] = new_content
if not val.complete:
break
pos = val.end + 1
return result, resume
def maybe_emit_argument_delta(
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
tool_call_delta: ChatCompletionDeltaToolCall,
placement: Placement,
scan_offsets: dict[int, int],
) -> Generator[Packet, None, None]:
"""Emit decoded tool-call argument deltas to the frontend.
NOTE: Currently skips non-string arguments
``scan_offsets`` is a mutable dict keyed by tool-call index that allows
each call to skip past already-processed key-value pairs, reducing
per-call cost from O(accumulated) to O(delta).
"""
tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta)
if not tool_cls or not tool_cls.do_emit_argument_deltas():
return
fn = tool_call_delta.function
delta_fragment = fn.arguments if fn else None
if not delta_fragment:
return
tc_data = tool_calls_in_progress[tool_call_delta.index]
accumulated_args = tc_data["arguments"]
prev_args = accumulated_args[: -len(delta_fragment)]
idx = tool_call_delta.index
offset = scan_offsets.get(idx, 0)
argument_deltas, new_offset = _extract_delta_args(prev_args, delta_fragment, offset)
scan_offsets[idx] = new_offset
if not argument_deltas:
return
yield Packet(
placement=placement,
obj=ToolCallArgumentDelta(
tool_type=tc_data.get("name", ""),
tool_id=tc_data.get("id", ""),
argument_deltas=argument_deltas,
),
)

View File

@@ -943,6 +943,9 @@ class ConfluenceConnector(
if include_permissions
else None
),
parent_hierarchy_raw_node_id=self._get_parent_hierarchy_raw_id(
page
),
)
)
@@ -992,6 +995,7 @@ class ConfluenceConnector(
if include_permissions
else None
),
parent_hierarchy_raw_node_id=page_id,
)
)

View File

@@ -781,4 +781,5 @@ def build_slim_document(
return SlimDocument(
id=onyx_document_id_from_drive_file(file),
external_access=external_access,
parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0],
)

View File

@@ -902,6 +902,11 @@ class JiraConnector(
external_access=self._get_project_permissions(
project_key, add_prefix=False
),
parent_hierarchy_raw_node_id=(
self._get_parent_hierarchy_raw_node_id(issue, project_key)
if project_key
else None
),
)
)
current_offset += 1

View File

@@ -385,6 +385,7 @@ class IndexingDocument(Document):
class SlimDocument(BaseModel):
id: str
external_access: ExternalAccess | None = None
parent_hierarchy_raw_node_id: str | None = None
class HierarchyNode(BaseModel):

View File

@@ -772,6 +772,7 @@ def _convert_driveitem_to_slim_document(
drive_name: str,
ctx: ClientContext,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
) -> SlimDocument:
if driveitem.id is None:
raise ValueError("DriveItem ID is required")
@@ -787,11 +788,15 @@ def _convert_driveitem_to_slim_document(
return SlimDocument(
id=driveitem.id,
external_access=external_access,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
def _convert_sitepage_to_slim_document(
site_page: dict[str, Any], ctx: ClientContext | None, graph_client: GraphClient
site_page: dict[str, Any],
ctx: ClientContext | None,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
) -> SlimDocument:
"""Convert a SharePoint site page to a SlimDocument object."""
if site_page.get("id") is None:
@@ -808,6 +813,7 @@ def _convert_sitepage_to_slim_document(
return SlimDocument(
id=id,
external_access=external_access,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
@@ -1594,12 +1600,22 @@ class SharepointConnector(
)
)
parent_hierarchy_url: str | None = None
if drive_web_url:
parent_hierarchy_url = self._get_parent_hierarchy_url(
site_url, drive_web_url, drive_name, driveitem
)
try:
logger.debug(f"Processing: {driveitem.web_url}")
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_driveitem_to_slim_document(
driveitem, drive_name, ctx, self.graph_client
driveitem,
drive_name,
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=parent_hierarchy_url,
)
)
except Exception as e:
@@ -1619,7 +1635,10 @@ class SharepointConnector(
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_sitepage_to_slim_document(
site_page, ctx, self.graph_client
site_page,
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=site_descriptor.url,
)
)
if len(doc_batch) >= SLIM_BATCH_SIZE:

View File

@@ -565,6 +565,7 @@ def _get_all_doc_ids(
channel_id=channel_id, thread_ts=message["ts"]
),
external_access=external_access,
parent_hierarchy_raw_node_id=channel_id,
)
)

View File

@@ -1,5 +1,7 @@
"""CRUD operations for HierarchyNode."""
from collections import defaultdict
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -525,6 +527,53 @@ def get_document_parent_hierarchy_node_ids(
return {doc_id: parent_id for doc_id, parent_id in results}
def update_document_parent_hierarchy_nodes(
db_session: Session,
doc_parent_map: dict[str, int | None],
commit: bool = True,
) -> int:
"""Bulk-update Document.parent_hierarchy_node_id for multiple documents.
Only updates rows whose current value differs from the desired value to
avoid unnecessary writes.
Args:
db_session: SQLAlchemy session
doc_parent_map: Mapping of document_id → desired parent_hierarchy_node_id
commit: Whether to commit the transaction
Returns:
Number of documents actually updated
"""
if not doc_parent_map:
return 0
doc_ids = list(doc_parent_map.keys())
existing = get_document_parent_hierarchy_node_ids(db_session, doc_ids)
by_parent: dict[int | None, list[str]] = defaultdict(list)
for doc_id, desired_parent_id in doc_parent_map.items():
current = existing.get(doc_id)
if current == desired_parent_id or doc_id not in existing:
continue
by_parent[desired_parent_id].append(doc_id)
updated = 0
for desired_parent_id, ids in by_parent.items():
db_session.query(Document).filter(Document.id.in_(ids)).update(
{Document.parent_hierarchy_node_id: desired_parent_id},
synchronize_session=False,
)
updated += len(ids)
if commit:
db_session.commit()
elif updated:
db_session.flush()
return updated
def update_hierarchy_node_permissions(
db_session: Session,
raw_node_id: str,

View File

@@ -129,7 +129,7 @@ def get_current_search_settings(db_session: Session) -> SearchSettings:
latest_settings = result.scalars().first()
if not latest_settings:
raise RuntimeError("No search settings specified, DB is not in a valid state")
raise RuntimeError("No search settings specified; DB is not in a valid state.")
return latest_settings

View File

@@ -32,9 +32,6 @@ def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
Determines whether to enable multipass and large chunks by examining
the current search settings and the embedder configuration.
"""
if not search_settings:
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
multipass = should_use_multipass(search_settings)
enable_large_chunks = SearchSettings.can_use_large_chunks(
multipass, search_settings.model_name, search_settings.provider_type

View File

@@ -26,11 +26,10 @@ def get_default_document_index(
To be used for retrieval only. Indexing should be done through both indices
until Vespa is deprecated.
Pre-existing docstring for this function, although secondary indices are not
currently supported:
Primary index is the index that is used for querying/updating etc. Secondary
index is for when both the currently used index and the upcoming index both
need to be updated, updates are applied to both indices.
need to be updated. Updates are applied to both indices.
WARNING: In that case, get_all_document_indices should be used.
"""
if DISABLE_VECTOR_DB:
return DisabledDocumentIndex(
@@ -51,11 +50,26 @@ def get_default_document_index(
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
if opensearch_retrieval_enabled:
indexing_setting = IndexingSetting.from_db_model(search_settings)
secondary_indexing_setting = (
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None
)
return OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=secondary_index_name,
secondary_embedding_dim=(
secondary_indexing_setting.final_embedding_dim
if secondary_indexing_setting
else None
),
secondary_embedding_precision=(
secondary_indexing_setting.embedding_precision
if secondary_indexing_setting
else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
multitenant=MULTI_TENANT,
@@ -86,8 +100,7 @@ def get_all_document_indices(
Used for indexing only. Until Vespa is deprecated we will index into both
document indices. Retrieval is done through only one index however.
Large chunks and secondary indices are not currently supported so we
hardcode appropriate values.
Large chunks are not currently supported so we hardcode appropriate values.
NOTE: Make sure the Vespa index object is returned first. In the rare event
that there is some conflict between indexing and the migration task, it is
@@ -123,13 +136,36 @@ def get_all_document_indices(
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
indexing_setting = IndexingSetting.from_db_model(search_settings)
secondary_indexing_setting = (
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None
)
opensearch_document_index = OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,
secondary_index_name=(
secondary_search_settings.index_name
if secondary_search_settings
else None
),
secondary_embedding_dim=(
secondary_indexing_setting.final_embedding_dim
if secondary_indexing_setting
else None
),
secondary_embedding_precision=(
secondary_indexing_setting.embedding_precision
if secondary_indexing_setting
else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=(
secondary_search_settings.large_chunks_enabled
if secondary_search_settings
else None
),
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)

View File

@@ -271,6 +271,9 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
secondary_index_name: str | None,
secondary_embedding_dim: int | None,
secondary_embedding_precision: EmbeddingPrecision | None,
# NOTE: We do not support large chunks right now.
large_chunks_enabled: bool, # noqa: ARG002
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
multitenant: bool = False,
@@ -286,12 +289,25 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
f"Expected {MULTI_TENANT}, got {multitenant}."
)
tenant_id = get_current_tenant_id()
tenant_state = TenantState(tenant_id=tenant_id, multitenant=multitenant)
self._real_index = OpenSearchDocumentIndex(
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
tenant_state=tenant_state,
index_name=index_name,
embedding_dim=embedding_dim,
embedding_precision=embedding_precision,
)
self._secondary_real_index: OpenSearchDocumentIndex | None = None
if self.secondary_index_name:
if secondary_embedding_dim is None or secondary_embedding_precision is None:
raise ValueError(
"Bug: Secondary index embedding dimension and precision are not set."
)
self._secondary_real_index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
index_name=self.secondary_index_name,
embedding_dim=secondary_embedding_dim,
embedding_precision=secondary_embedding_precision,
)
@staticmethod
def register_multitenant_indices(
@@ -307,19 +323,38 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
self,
primary_embedding_dim: int,
primary_embedding_precision: EmbeddingPrecision,
secondary_index_embedding_dim: int | None, # noqa: ARG002
secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002
secondary_index_embedding_dim: int | None,
secondary_index_embedding_precision: EmbeddingPrecision | None,
) -> None:
# Only handle primary index for now, ignore secondary.
return self._real_index.verify_and_create_index_if_necessary(
self._real_index.verify_and_create_index_if_necessary(
primary_embedding_dim, primary_embedding_precision
)
if self.secondary_index_name:
if (
secondary_index_embedding_dim is None
or secondary_index_embedding_precision is None
):
raise ValueError(
"Bug: Secondary index embedding dimension and precision are not set."
)
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
self._secondary_real_index.verify_and_create_index_if_necessary(
secondary_index_embedding_dim, secondary_index_embedding_precision
)
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
NOTE: Do NOT consider the secondary index here. A separate indexing
pipeline will be responsible for indexing to the secondary index. This
design is not ideal and we should reconsider this when revamping index
swapping.
"""
# Convert IndexBatchParams to IndexingMetadata.
chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {}
for doc_id in index_batch_params.doc_id_to_new_chunk_cnt:
@@ -351,7 +386,20 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
tenant_id: str, # noqa: ARG002
chunk_count: int | None,
) -> int:
return self._real_index.delete(doc_id, chunk_count)
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for deleting chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
total_chunks_deleted = self._real_index.delete(doc_id, chunk_count)
if self.secondary_index_name:
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
total_chunks_deleted += self._secondary_real_index.delete(
doc_id, chunk_count
)
return total_chunks_deleted
def update_single(
self,
@@ -362,6 +410,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
fields: VespaDocumentFields | None,
user_fields: VespaDocumentUserFields | None,
) -> None:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for updating chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
if fields is None and user_fields is None:
logger.warning(
f"Tried to update document {doc_id} with no updated fields or user fields."
@@ -392,6 +445,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
try:
self._real_index.update([update_request])
if self.secondary_index_name:
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
self._secondary_real_index.update([update_request])
except NotFoundError:
logger.exception(
f"Tried to update document {doc_id} but at least one of its chunks was not found in OpenSearch. "

View File

@@ -465,6 +465,12 @@ class VespaIndex(DocumentIndex):
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
NOTE: Do NOT consider the secondary index here. A separate indexing
pipeline will be responsible for indexing to the secondary index. This
design is not ideal and we should reconsider this when revamping index
swapping.
"""
if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
index_batch_params.doc_id_to_new_chunk_cnt
):
@@ -659,6 +665,10 @@ class VespaIndex(DocumentIndex):
"""Note: if the document id does not exist, the update will be a no-op and the
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for updating chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
if fields is None and user_fields is None:
logger.warning(
@@ -679,13 +689,6 @@ class VespaIndex(DocumentIndex):
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
)
vespa_document_index = VespaDocumentIndex(
index_name=self.index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.large_chunks_enabled,
httpx_client=self.httpx_client,
)
project_ids: set[int] | None = None
if user_fields is not None and user_fields.user_projects is not None:
project_ids = set(user_fields.user_projects)
@@ -705,7 +708,20 @@ class VespaIndex(DocumentIndex):
persona_ids=persona_ids,
)
vespa_document_index.update([update_request])
indices = [self.index_name]
if self.secondary_index_name:
indices.append(self.secondary_index_name)
for index_name in indices:
vespa_document_index = VespaDocumentIndex(
index_name=index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
index_name, False
),
httpx_client=self.httpx_client,
)
vespa_document_index.update([update_request])
def delete_single(
self,
@@ -714,6 +730,11 @@ class VespaIndex(DocumentIndex):
tenant_id: str,
chunk_count: int | None,
) -> int:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for deleting chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
tenant_state = TenantState(
tenant_id=get_current_tenant_id(),
multitenant=MULTI_TENANT,
@@ -726,13 +747,25 @@ class VespaIndex(DocumentIndex):
raise ValueError(
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
)
vespa_document_index = VespaDocumentIndex(
index_name=self.index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.large_chunks_enabled,
httpx_client=self.httpx_client,
)
return vespa_document_index.delete(document_id=doc_id, chunk_count=chunk_count)
indices = [self.index_name]
if self.secondary_index_name:
indices.append(self.secondary_index_name)
total_chunks_deleted = 0
for index_name in indices:
vespa_document_index = VespaDocumentIndex(
index_name=index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
index_name, False
),
httpx_client=self.httpx_client,
)
total_chunks_deleted += vespa_document_index.delete(
document_id=doc_id, chunk_count=chunk_count
)
return total_chunks_deleted
def id_based_retrieval(
self,

View File

@@ -3,7 +3,6 @@ from http import HTTPStatus
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import JSONResponse
from sqlalchemy import select
@@ -53,6 +52,8 @@ from onyx.db.permission_sync_attempt import (
from onyx.db.permission_sync_attempt import (
get_recent_doc_permission_sync_attempts_for_cc_pair,
)
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_utils import get_deletion_attempt_snapshot
from onyx.redis.redis_pool import get_redis_client
@@ -87,8 +88,9 @@ def get_cc_pair_index_attempts(
cc_pair_id, db_session, user, get_editable=False
)
if not user_has_access:
raise HTTPException(
status_code=400, detail="CC Pair not found for current user permissions"
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
"CC Pair not found for current user permissions",
)
total_count = count_index_attempts_for_cc_pair(
@@ -123,8 +125,9 @@ def get_cc_pair_permission_sync_attempts(
cc_pair_id, db_session, user, get_editable=False
)
if not user_has_access:
raise HTTPException(
status_code=400, detail="CC Pair not found for current user permissions"
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
"CC Pair not found for current user permissions",
)
# Get all permission sync attempts for this cc pair
@@ -160,8 +163,9 @@ def get_cc_pair_full_info(
cc_pair_id, db_session, user, get_editable=False
)
if not cc_pair:
raise HTTPException(
status_code=404, detail="CC Pair not found for current user permissions"
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
"CC Pair not found for current user permissions",
)
editable_cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id, db_session, user, get_editable=True
@@ -264,9 +268,9 @@ def update_cc_pair_status(
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
"Connection not found for current user's permissions",
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
@@ -339,8 +343,9 @@ def update_cc_pair_name(
get_editable=True,
)
if not cc_pair:
raise HTTPException(
status_code=400, detail="CC Pair not found for current user's permissions"
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
"CC Pair not found for current user's permissions",
)
try:
@@ -351,7 +356,7 @@ def update_cc_pair_name(
)
except IntegrityError:
db_session.rollback()
raise HTTPException(status_code=400, detail="Name must be unique")
raise OnyxError(OnyxErrorCode.CONFLICT, "Name must be unique")
@router.put("/admin/cc-pair/{cc_pair_id}/property")
@@ -368,8 +373,9 @@ def update_cc_pair_property(
get_editable=True,
)
if not cc_pair:
raise HTTPException(
status_code=400, detail="CC Pair not found for current user's permissions"
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
"CC Pair not found for current user's permissions",
)
# Can we centralize logic for updating connector properties
@@ -387,8 +393,9 @@ def update_cc_pair_property(
msg = "Pruning frequency updated successfully"
else:
raise HTTPException(
status_code=400, detail=f"Property name {update_request.name} is not valid."
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Property name {update_request.name} is not valid.",
)
return StatusResponse(success=True, message=msg, data=cc_pair_id)
@@ -407,9 +414,9 @@ def get_cc_pair_last_pruned(
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="cc_pair not found for current user's permissions",
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
"cc_pair not found for current user's permissions",
)
return cc_pair.last_pruned
@@ -431,19 +438,16 @@ def prune_cc_pair(
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
"Connection not found for current user's permissions",
)
r = get_redis_client()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if redis_connector.prune.fenced:
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Pruning task already in progress.",
)
raise OnyxError(OnyxErrorCode.CONFLICT, "Pruning task already in progress.")
logger.info(
f"Pruning cc_pair: cc_pair={cc_pair_id} "
@@ -455,10 +459,7 @@ def prune_cc_pair(
client_app, cc_pair, db_session, r, tenant_id
)
if not payload_id:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Pruning task creation failed.",
)
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Pruning task creation failed.")
logger.info(f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}")
@@ -588,20 +589,21 @@ def associate_credential_to_connector(
delete_connector(db_session, connector_id)
db_session.commit()
raise HTTPException(
status_code=400, detail="Connector validation error: " + str(e)
raise OnyxError(
OnyxErrorCode.CONNECTOR_VALIDATION_FAILED,
"Connector validation error: " + str(e),
)
except IntegrityError as e:
logger.error(f"IntegrityError: {e}")
delete_connector(db_session, connector_id)
db_session.commit()
raise HTTPException(status_code=400, detail="Name must be unique")
raise OnyxError(OnyxErrorCode.CONFLICT, "Name must be unique")
except Exception as e:
logger.exception(f"Unexpected error: {e}")
raise HTTPException(status_code=500, detail="Unexpected error")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "Unexpected error")
@router.delete(

View File

@@ -11,7 +11,6 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import Form
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from fastapi import Response
@@ -115,6 +114,8 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.file_types import PLAIN_TEXT_MIME_TYPE
from onyx.file_processing.file_types import WORD_PROCESSING_MIME_TYPE
from onyx.file_store.file_store import FileStore
@@ -180,7 +181,7 @@ def check_google_app_gmail_credentials_exist(
try:
return {"client_id": get_google_app_cred(DocumentSource.GMAIL).web.client_id}
except KvKeyNotFoundError:
raise HTTPException(status_code=404, detail="Google App Credentials not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Google App Credentials not found")
@router.put("/admin/connector/gmail/app-credential")
@@ -190,7 +191,7 @@ def upsert_google_app_gmail_credentials(
try:
upsert_google_app_cred(app_credentials, DocumentSource.GMAIL)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
return StatusResponse(
success=True, message="Successfully saved Google App Credentials"
@@ -206,7 +207,7 @@ def delete_google_app_gmail_credentials(
delete_google_app_cred(DocumentSource.GMAIL)
cleanup_gmail_credentials(db_session=db_session)
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
return StatusResponse(
success=True, message="Successfully deleted Google App Credentials"
@@ -222,7 +223,7 @@ def check_google_app_credentials_exist(
"client_id": get_google_app_cred(DocumentSource.GOOGLE_DRIVE).web.client_id
}
except KvKeyNotFoundError:
raise HTTPException(status_code=404, detail="Google App Credentials not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Google App Credentials not found")
@router.put("/admin/connector/google-drive/app-credential")
@@ -232,7 +233,7 @@ def upsert_google_app_credentials(
try:
upsert_google_app_cred(app_credentials, DocumentSource.GOOGLE_DRIVE)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
return StatusResponse(
success=True, message="Successfully saved Google App Credentials"
@@ -248,7 +249,7 @@ def delete_google_app_credentials(
delete_google_app_cred(DocumentSource.GOOGLE_DRIVE)
cleanup_google_drive_credentials(db_session=db_session)
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
return StatusResponse(
success=True, message="Successfully deleted Google App Credentials"
@@ -266,9 +267,7 @@ def check_google_service_gmail_account_key_exist(
).client_email
}
except KvKeyNotFoundError:
raise HTTPException(
status_code=404, detail="Google Service Account Key not found"
)
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Google Service Account Key not found")
@router.put("/admin/connector/gmail/service-account-key")
@@ -278,7 +277,7 @@ def upsert_google_service_gmail_account_key(
try:
upsert_service_account_key(service_account_key, DocumentSource.GMAIL)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
return StatusResponse(
success=True, message="Successfully saved Google Service Account Key"
@@ -294,7 +293,7 @@ def delete_google_service_gmail_account_key(
delete_service_account_key(DocumentSource.GMAIL)
cleanup_gmail_credentials(db_session=db_session)
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
return StatusResponse(
success=True, message="Successfully deleted Google Service Account Key"
@@ -312,9 +311,7 @@ def check_google_service_account_key_exist(
).client_email
}
except KvKeyNotFoundError:
raise HTTPException(
status_code=404, detail="Google Service Account Key not found"
)
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Google Service Account Key not found")
@router.put("/admin/connector/google-drive/service-account-key")
@@ -324,7 +321,7 @@ def upsert_google_service_account_key(
try:
upsert_service_account_key(service_account_key, DocumentSource.GOOGLE_DRIVE)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
return StatusResponse(
success=True, message="Successfully saved Google Service Account Key"
@@ -340,7 +337,7 @@ def delete_google_service_account_key(
delete_service_account_key(DocumentSource.GOOGLE_DRIVE)
cleanup_google_drive_credentials(db_session=db_session)
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
return StatusResponse(
success=True, message="Successfully deleted Google Service Account Key"
@@ -363,7 +360,7 @@ def upsert_service_account_credential(
name="Service Account (uploaded)",
)
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
# first delete all existing service account credentials
delete_service_account_credentials(user, db_session, DocumentSource.GOOGLE_DRIVE)
@@ -389,7 +386,7 @@ def upsert_gmail_service_account_credential(
primary_admin_email=service_account_credential_request.google_primary_admin,
)
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
# first delete all existing service account credentials
delete_service_account_credentials(user, db_session, DocumentSource.GMAIL)
@@ -440,9 +437,9 @@ def save_zip_metadata_to_file_store(
json.loads(metadata_bytes)
except json.JSONDecodeError as e:
logger.warning(f"Unable to load {ONYX_METADATA_FILENAME}: {e}")
raise HTTPException(
status_code=400,
detail=f"Unable to load {ONYX_METADATA_FILENAME}: {e}",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Unable to load {ONYX_METADATA_FILENAME}: {e}",
)
# Save to file store
@@ -500,7 +497,7 @@ def upload_files(
if is_zip_file(file):
if seen_zip:
raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL)
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, SEEN_ZIP_DETAIL)
seen_zip = True
with zipfile.ZipFile(file.file, "r") as zf:
zip_metadata_file_id = save_zip_metadata_to_file_store(
@@ -554,7 +551,7 @@ def upload_files(
deduped_file_names.append(file.filename)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
return FileUploadResponse(
file_paths=deduped_file_paths,
file_names=deduped_file_names,
@@ -581,9 +578,9 @@ def _fetch_and_check_file_connector_cc_pair_permissions(
) -> ConnectorCredentialPair:
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
if cc_pair is None:
raise HTTPException(
status_code=404,
detail="No Connector-Credential Pair found for this connector",
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
"No Connector-Credential Pair found for this connector",
)
has_requested_access = verify_user_has_access_to_cc_pair(
@@ -604,9 +601,9 @@ def _fetch_and_check_file_connector_cc_pair_permissions(
):
return cc_pair
raise HTTPException(
status_code=403,
detail="Access denied. User cannot manage files for this connector.",
raise OnyxError(
OnyxErrorCode.UNAUTHORIZED,
"Access denied. User cannot manage files for this connector.",
)
@@ -627,11 +624,12 @@ def list_connector_files(
"""List all files in a file connector."""
connector = fetch_connector_by_id(connector_id, db_session)
if connector is None:
raise HTTPException(status_code=404, detail="Connector not found")
raise OnyxError(OnyxErrorCode.CONNECTOR_NOT_FOUND, "Connector not found")
if connector.source != DocumentSource.FILE:
raise HTTPException(
status_code=400, detail="This endpoint only works with file connectors"
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"This endpoint only works with file connectors",
)
_ = _fetch_and_check_file_connector_cc_pair_permissions(
@@ -700,11 +698,12 @@ def update_connector_files(
files = files or []
connector = fetch_connector_by_id(connector_id, db_session)
if connector is None:
raise HTTPException(status_code=404, detail="Connector not found")
raise OnyxError(OnyxErrorCode.CONNECTOR_NOT_FOUND, "Connector not found")
if connector.source != DocumentSource.FILE:
raise HTTPException(
status_code=400, detail="This endpoint only works with file connectors"
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"This endpoint only works with file connectors",
)
# Get the connector-credential pair for indexing/pruning triggers
@@ -720,12 +719,14 @@ def update_connector_files(
try:
file_ids_list = json.loads(file_ids_to_remove)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid file_ids_to_remove format")
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR, "Invalid file_ids_to_remove format"
)
if not isinstance(file_ids_list, list):
raise HTTPException(
status_code=400,
detail="file_ids_to_remove must be a JSON-encoded list",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"file_ids_to_remove must be a JSON-encoded list",
)
# Get current connector config
@@ -750,9 +751,9 @@ def update_connector_files(
current_zip_metadata = loaded_metadata
except Exception as e:
logger.warning(f"Failed to load existing metadata file: {e}")
raise HTTPException(
status_code=500,
detail="Failed to load existing connector metadata file",
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Failed to load existing connector metadata file",
)
# Upload new files if any
@@ -807,9 +808,9 @@ def update_connector_files(
# Validate that at least one file remains
if not final_file_locations:
raise HTTPException(
status_code=400,
detail="Cannot remove all files from connector. At least one file must remain.",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Cannot remove all files from connector. At least one file must remain.",
)
# Merge and filter metadata (remove metadata for deleted files)
@@ -852,8 +853,8 @@ def update_connector_files(
updated_connector = update_connector(connector_id, connector_base, db_session)
if updated_connector is None:
raise HTTPException(
status_code=500, detail="Failed to update connector configuration"
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR, "Failed to update connector configuration"
)
# Trigger re-indexing for new files and pruning for removed files
@@ -1541,7 +1542,7 @@ def create_connector_from_model(
return connector_response
except ValueError as e:
logger.error(f"Error creating connector: {e}")
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
@router.post("/admin/connector-with-mock-credential")
@@ -1619,11 +1620,12 @@ def create_connector_with_mock_credential(
return response
except ConnectorValidationError as e:
raise HTTPException(
status_code=400, detail="Connector validation error: " + str(e)
raise OnyxError(
OnyxErrorCode.CONNECTOR_VALIDATION_FAILED,
"Connector validation error: " + str(e),
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
@router.patch("/admin/connector/{connector_id}", tags=PUBLIC_API_TAGS)
@@ -1648,12 +1650,13 @@ def update_connector_from_model(
)
connector_base = connector_data.to_connector_base()
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
updated_connector = update_connector(connector_id, connector_base, db_session)
if updated_connector is None:
raise HTTPException(
status_code=404, detail=f"Connector {connector_id} does not exist"
raise OnyxError(
OnyxErrorCode.CONNECTOR_NOT_FOUND,
f"Connector {connector_id} does not exist",
)
return ConnectorSnapshot(
@@ -1690,7 +1693,7 @@ def delete_connector_by_id(
connector_id=connector_id,
)
except AssertionError:
raise HTTPException(status_code=400, detail="Connector is not deletable")
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Connector is not deletable")
@router.post("/admin/connector/run-once", tags=PUBLIC_API_TAGS)
@@ -1711,9 +1714,9 @@ def connector_run_once(
run_info.connector_id, db_session
)
except ValueError:
raise HTTPException(
status_code=404,
detail=f"Connector by id {connector_id} does not exist.",
raise OnyxError(
OnyxErrorCode.CONNECTOR_NOT_FOUND,
f"Connector by id {connector_id} does not exist.",
)
if not specified_credential_ids:
@@ -1722,15 +1725,15 @@ def connector_run_once(
if set(specified_credential_ids).issubset(set(possible_credential_ids)):
credential_ids = specified_credential_ids
else:
raise HTTPException(
status_code=400,
detail="Not all specified credentials are associated with connector",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Not all specified credentials are associated with connector",
)
if not credential_ids:
raise HTTPException(
status_code=400,
detail="Connector has no valid credentials, cannot create index attempts.",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Connector has no valid credentials, cannot create index attempts.",
)
try:
num_triggers = trigger_indexing_for_cc_pair(
@@ -1741,7 +1744,7 @@ def connector_run_once(
db_session,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
logger.info("connector_run_once - running check_for_indexing")
@@ -1795,8 +1798,8 @@ def gmail_callback(
) -> StatusResponse:
credential_id_cookie = request.cookies.get(_GMAIL_CREDENTIAL_ID_COOKIE_NAME)
if credential_id_cookie is None or not credential_id_cookie.isdigit():
raise HTTPException(
status_code=401, detail="Request did not pass CSRF verification."
raise OnyxError(
OnyxErrorCode.CSRF_FAILURE, "Request did not pass CSRF verification."
)
credential_id = int(credential_id_cookie)
verify_csrf(credential_id, callback.state)
@@ -1809,8 +1812,8 @@ def gmail_callback(
GoogleOAuthAuthenticationMethod.UPLOADED,
)
if credentials is None:
raise HTTPException(
status_code=500, detail="Unable to fetch Gmail access tokens"
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR, "Unable to fetch Gmail access tokens"
)
return StatusResponse(success=True, message="Updated Gmail access tokens")
@@ -1825,8 +1828,8 @@ def google_drive_callback(
) -> StatusResponse:
credential_id_cookie = request.cookies.get(_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME)
if credential_id_cookie is None or not credential_id_cookie.isdigit():
raise HTTPException(
status_code=401, detail="Request did not pass CSRF verification."
raise OnyxError(
OnyxErrorCode.CSRF_FAILURE, "Request did not pass CSRF verification."
)
credential_id = int(credential_id_cookie)
verify_csrf(credential_id, callback.state)
@@ -1840,8 +1843,9 @@ def google_drive_callback(
GoogleOAuthAuthenticationMethod.UPLOADED,
)
if credentials is None:
raise HTTPException(
status_code=500, detail="Unable to fetch Google Drive access tokens"
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Unable to fetch Google Drive access tokens",
)
return StatusResponse(success=True, message="Updated Google Drive access tokens")
@@ -1881,8 +1885,9 @@ def get_connector_by_id(
) -> ConnectorSnapshot | StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
if connector is None:
raise HTTPException(
status_code=404, detail=f"Connector {connector_id} does not exist"
raise OnyxError(
OnyxErrorCode.CONNECTOR_NOT_FOUND,
f"Connector {connector_id} does not exist",
)
return ConnectorSnapshot(
@@ -1915,7 +1920,9 @@ def submit_connector_request(
connector_name = request_data.connector_name.strip()
if not connector_name:
raise HTTPException(status_code=400, detail="Connector name cannot be empty")
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR, "Connector name cannot be empty"
)
# Get user identifier for telemetry
user_email = user.email if user else None

View File

@@ -4,7 +4,6 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import Form
from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
from sqlalchemy.orm import Session
@@ -28,6 +27,8 @@ from onyx.db.credentials import update_credential
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import DocumentSource
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.documents.models import CredentialBase
from onyx.server.documents.models import CredentialDataUpdateRequest
from onyx.server.documents.models import CredentialSnapshot
@@ -176,18 +177,18 @@ def create_credential_with_private_key(
try:
credential_data = json.loads(credential_json)
except json.JSONDecodeError as e:
raise HTTPException(
status_code=400,
detail=f"Invalid JSON in credential_json: {str(e)}",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Invalid JSON in credential_json: {str(e)}",
)
private_key_processor: ProcessPrivateKeyFileProtocol | None = (
FILE_TYPE_TO_FILE_PROCESSOR.get(PrivateKeyFileTypes(type_definition_key))
)
if private_key_processor is None:
raise HTTPException(
status_code=400,
detail="Invalid type definition key for private key file",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Invalid type definition key for private key file",
)
private_key_content: str = private_key_processor(uploaded_file)
@@ -251,9 +252,9 @@ def get_credential_by_id(
get_editable=False,
)
if credential is None:
raise HTTPException(
status_code=401,
detail=f"Credential {credential_id} does not exist or does not belong to user",
raise OnyxError(
OnyxErrorCode.CREDENTIAL_NOT_FOUND,
f"Credential {credential_id} does not exist or does not belong to user",
)
return CredentialSnapshot.from_credential_db_model(credential)
@@ -275,9 +276,9 @@ def update_credential_data(
)
if credential is None:
raise HTTPException(
status_code=401,
detail=f"Credential {credential_id} does not exist or does not belong to user",
raise OnyxError(
OnyxErrorCode.CREDENTIAL_NOT_FOUND,
f"Credential {credential_id} does not exist or does not belong to user",
)
return CredentialSnapshot.from_credential_db_model(credential)
@@ -297,18 +298,18 @@ def update_credential_private_key(
try:
credential_data = json.loads(credential_json)
except json.JSONDecodeError as e:
raise HTTPException(
status_code=400,
detail=f"Invalid JSON in credential_json: {str(e)}",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Invalid JSON in credential_json: {str(e)}",
)
private_key_processor: ProcessPrivateKeyFileProtocol | None = (
FILE_TYPE_TO_FILE_PROCESSOR.get(PrivateKeyFileTypes(type_definition_key))
)
if private_key_processor is None:
raise HTTPException(
status_code=400,
detail="Invalid type definition key for private key file",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Invalid type definition key for private key file",
)
private_key_content: str = private_key_processor(uploaded_file)
credential_data[field_key] = private_key_content
@@ -322,9 +323,9 @@ def update_credential_private_key(
)
if credential is None:
raise HTTPException(
status_code=401,
detail=f"Credential {credential_id} does not exist or does not belong to user",
raise OnyxError(
OnyxErrorCode.CREDENTIAL_NOT_FOUND,
f"Credential {credential_id} does not exist or does not belong to user",
)
return CredentialSnapshot.from_credential_db_model(credential)
@@ -341,9 +342,9 @@ def update_credential_from_model(
credential_id, credential_data, user, db_session
)
if updated_credential is None:
raise HTTPException(
status_code=401,
detail=f"Credential {credential_id} does not exist or does not belong to user",
raise OnyxError(
OnyxErrorCode.CREDENTIAL_NOT_FOUND,
f"Credential {credential_id} does not exist or does not belong to user",
)
# Get credential_json value - use masking for API responses

View File

@@ -1,6 +1,5 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from sqlalchemy.orm import Session
@@ -14,6 +13,8 @@ from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.prompt_utils import build_doc_context_str
from onyx.server.documents.models import ChunkInfo
@@ -43,7 +44,7 @@ def get_document_info(
)
if not inference_chunks:
raise HTTPException(status_code=404, detail="Document not found")
raise OnyxError(OnyxErrorCode.DOCUMENT_NOT_FOUND, "Document not found")
contents = [chunk.content for chunk in inference_chunks]
@@ -95,7 +96,7 @@ def get_chunk_info(
)
if not inference_chunks:
raise HTTPException(status_code=404, detail="Chunk not found")
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Chunk not found")
chunk_content = inference_chunks[0].content

View File

@@ -2,9 +2,10 @@ import base64
from enum import Enum
from typing import Protocol
from fastapi import HTTPException
from fastapi import UploadFile
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.documents.document_utils import validate_pkcs12_content
@@ -31,8 +32,9 @@ def process_sharepoint_private_key_file(file: UploadFile) -> str:
"""
# First check file extension (basic filter)
if not (file.filename and file.filename.lower().endswith(".pfx")):
raise HTTPException(
status_code=400, detail="Invalid file type. Only .pfx files are supported."
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Invalid file type. Only .pfx files are supported.",
)
# Read file content for validation and processing
@@ -40,9 +42,9 @@ def process_sharepoint_private_key_file(file: UploadFile) -> str:
# Validate file content to prevent extension spoofing attacks
if not validate_pkcs12_content(private_key_bytes):
raise HTTPException(
status_code=400,
detail="Invalid file content. The uploaded file does not appear to be a valid PKCS#12 (.pfx) file.",
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Invalid file content. The uploaded file does not appear to be a valid PKCS#12 (.pfx) file.",
)
# Convert to base64 if validation passes

View File

@@ -5,7 +5,6 @@ from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from pydantic import BaseModel
@@ -19,6 +18,8 @@ from onyx.connectors.interfaces import OAuthConnector
from onyx.db.credentials import create_credential
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
@@ -69,12 +70,10 @@ def _get_additional_kwargs(
# validate
connector_cls.AdditionalOauthKwargs(**additional_kwargs_dict)
except ValidationError:
raise HTTPException(
status_code=400,
detail=(
f"Invalid additional kwargs. Got {additional_kwargs_dict}, expected "
f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}"
),
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Invalid additional kwargs. Got {additional_kwargs_dict}, expected "
f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}",
)
return additional_kwargs_dict
@@ -97,7 +96,9 @@ def oauth_authorize(
oauth_connectors = _discover_oauth_connectors()
if source not in oauth_connectors:
raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR, f"Unknown OAuth source: {source}"
)
connector_cls = oauth_connectors[source]
base_url = WEB_DOMAIN
@@ -147,7 +148,9 @@ def oauth_callback(
oauth_connectors = _discover_oauth_connectors()
if source not in oauth_connectors:
raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR, f"Unknown OAuth source: {source}"
)
connector_cls = oauth_connectors[source]
@@ -157,7 +160,7 @@ def oauth_callback(
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
)
if not oauth_state_bytes:
raise HTTPException(status_code=400, detail="Invalid OAuth state")
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Invalid OAuth state")
oauth_state = json.loads(oauth_state_bytes.decode("utf-8"))
desired_return_url = cast(str, oauth_state[_DESIRED_RETURN_URL_KEY])

View File

@@ -6,8 +6,11 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.context.search.models import SavedSearchSettings
from onyx.context.search.models import SearchSettingsCreationRequest
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.connector_credential_pair import resync_cc_pair
from onyx.db.engine.sql_engine import get_session
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_existing_llm_provider
@@ -15,20 +18,25 @@ from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import update_no_default_contextual_rag_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import User
from onyx.db.search_settings import create_search_settings
from onyx.db.search_settings import delete_search_settings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_embedding_provider_from_provider_type
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_current_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.factory import get_default_document_index
from onyx.file_processing.unstructured import delete_unstructured_api_key
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import update_unstructured_api_key
from onyx.natural_language_processing.search_nlp_models import clean_model_name
from onyx.server.manage.embedding.models import SearchSettingsDeleteRequest
from onyx.server.manage.models import FullModelVersionResponse
from onyx.server.models import IdReturn
from onyx.server.utils_vector_db import require_vector_db
from onyx.utils.logger import setup_logger
from shared_configs.configs import ALT_INDEX_SUFFIX
from shared_configs.configs import MULTI_TENANT
router = APIRouter(prefix="/search-settings")
@@ -41,110 +49,99 @@ def set_new_search_settings(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session), # noqa: ARG001
) -> IdReturn:
"""Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
Gives an error if the same model name is used as the current or secondary index
"""
# TODO(andrei): Re-enable.
# NOTE Enable integration external dependency tests in test_search_settings.py
# when this is reenabled. They are currently skipped
logger.error("Setting new search settings is temporarily disabled.")
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Setting new search settings is temporarily disabled.",
Creates a new SearchSettings row and cancels the previous secondary indexing
if any exists.
"""
if search_settings_new.index_name:
logger.warning("Index name was specified by request, this is not suggested")
# Disallow contextual RAG for cloud deployments.
if MULTI_TENANT and search_settings_new.enable_contextual_rag:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Contextual RAG disabled in Onyx Cloud",
)
# Validate cloud provider exists or create new LiteLLM provider.
if search_settings_new.provider_type is not None:
cloud_provider = get_embedding_provider_from_provider_type(
db_session, provider_type=search_settings_new.provider_type
)
if cloud_provider is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
)
validate_contextual_rag_model(
provider_name=search_settings_new.contextual_rag_llm_provider,
model_name=search_settings_new.contextual_rag_llm_name,
db_session=db_session,
)
# if search_settings_new.index_name:
# logger.warning("Index name was specified by request, this is not suggested")
# # Disallow contextual RAG for cloud deployments
# if MULTI_TENANT and search_settings_new.enable_contextual_rag:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Contextual RAG disabled in Onyx Cloud",
# )
search_settings = get_current_search_settings(db_session)
# # Validate cloud provider exists or create new LiteLLM provider
# if search_settings_new.provider_type is not None:
# cloud_provider = get_embedding_provider_from_provider_type(
# db_session, provider_type=search_settings_new.provider_type
# )
if search_settings_new.index_name is None:
# We define index name here.
index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
if (
search_settings_new.model_name == search_settings.model_name
and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
):
index_name += ALT_INDEX_SUFFIX
search_values = search_settings_new.model_dump()
search_values["index_name"] = index_name
new_search_settings_request = SavedSearchSettings(**search_values)
else:
new_search_settings_request = SavedSearchSettings(
**search_settings_new.model_dump()
)
# if cloud_provider is None:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
# )
secondary_search_settings = get_secondary_search_settings(db_session)
# validate_contextual_rag_model(
# provider_name=search_settings_new.contextual_rag_llm_provider,
# model_name=search_settings_new.contextual_rag_llm_name,
# db_session=db_session,
# )
if secondary_search_settings:
# Cancel any background indexing jobs.
expire_index_attempts(
search_settings_id=secondary_search_settings.id, db_session=db_session
)
# search_settings = get_current_search_settings(db_session)
# Mark previous model as a past model directly.
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
# if search_settings_new.index_name is None:
# # We define index name here
# index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
# if (
# search_settings_new.model_name == search_settings.model_name
# and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
# ):
# index_name += ALT_INDEX_SUFFIX
# search_values = search_settings_new.model_dump()
# search_values["index_name"] = index_name
# new_search_settings_request = SavedSearchSettings(**search_values)
# else:
# new_search_settings_request = SavedSearchSettings(
# **search_settings_new.model_dump()
# )
new_search_settings = create_search_settings(
search_settings=new_search_settings_request, db_session=db_session
)
# secondary_search_settings = get_secondary_search_settings(db_session)
# Ensure the document indices have the new index immediately.
document_indices = get_all_document_indices(search_settings, new_search_settings)
for document_index in document_indices:
document_index.ensure_indices_exist(
primary_embedding_dim=search_settings.final_embedding_dim,
primary_embedding_precision=search_settings.embedding_precision,
secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
secondary_index_embedding_precision=new_search_settings.embedding_precision,
)
# if secondary_search_settings:
# # Cancel any background indexing jobs
# expire_index_attempts(
# search_settings_id=secondary_search_settings.id, db_session=db_session
# )
# Pause index attempts for the currently in-use index to preserve resources.
if DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(
cc_pair=cc_pair,
search_settings_id=new_search_settings.id,
db_session=db_session,
)
# # Mark previous model as a past model directly
# update_search_settings_status(
# search_settings=secondary_search_settings,
# new_status=IndexModelStatus.PAST,
# db_session=db_session,
# )
# new_search_settings = create_search_settings(
# search_settings=new_search_settings_request, db_session=db_session
# )
# # Ensure Vespa has the new index immediately
# get_multipass_config(search_settings)
# get_multipass_config(new_search_settings)
# document_index = get_default_document_index(
# search_settings, new_search_settings, db_session
# )
# document_index.ensure_indices_exist(
# primary_embedding_dim=search_settings.final_embedding_dim,
# primary_embedding_precision=search_settings.embedding_precision,
# secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
# secondary_index_embedding_precision=new_search_settings.embedding_precision,
# )
# # Pause index attempts for the currently in use index to preserve resources
# if DISABLE_INDEX_UPDATE_ON_SWAP:
# expire_index_attempts(
# search_settings_id=search_settings.id, db_session=db_session
# )
# for cc_pair in get_connector_credential_pairs(db_session):
# resync_cc_pair(
# cc_pair=cc_pair,
# search_settings_id=new_search_settings.id,
# db_session=db_session,
# )
# db_session.commit()
# return IdReturn(id=new_search_settings.id)
db_session.commit()
return IdReturn(id=new_search_settings.id)
@router.post("/cancel-new-embedding", dependencies=[Depends(require_vector_db)])

View File

@@ -1,6 +1,5 @@
import datetime
import json
import os
from collections.abc import Generator
from datetime import timedelta
from uuid import UUID
@@ -61,7 +60,6 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.file_processing.extract_file_text import docx_to_txt_filename
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -812,18 +810,6 @@ def fetch_chat_file(
if not file_record:
raise HTTPException(status_code=404, detail="File not found")
original_file_name = file_record.display_name
if file_record.file_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
# Check if a converted text file exists for .docx files
txt_file_name = docx_to_txt_filename(original_file_name)
txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name)
txt_file_record = file_store.read_file_record(txt_file_id)
if txt_file_record:
file_record = txt_file_record
file_id = txt_file_id
media_type = file_record.file_type
file_io = file_store.read_file(file_id, mode="b")

View File

@@ -41,7 +41,6 @@ class StreamingType(Enum):
REASONING_DONE = "reasoning_done"
CITATION_INFO = "citation_info"
TOOL_CALL_DEBUG = "tool_call_debug"
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta"
MEMORY_TOOL_START = "memory_tool_start"
MEMORY_TOOL_DELTA = "memory_tool_delta"
@@ -260,16 +259,6 @@ class CustomToolDelta(BaseObj):
file_ids: list[str] | None = None
class ToolCallArgumentDelta(BaseObj):
type: Literal["tool_call_argument_delta"] = (
StreamingType.TOOL_CALL_ARGUMENT_DELTA.value
)
tool_type: str
tool_id: str
argument_deltas: dict[str, Any]
################################################
# File Reader Packets
################################################
@@ -390,7 +379,6 @@ PacketObj = Union[
# Citation Packets
CitationInfo,
ToolCallDebug,
ToolCallArgumentDelta,
# Deep Research Packets
DeepResearchPlanStart,
DeepResearchPlanDelta,

View File

@@ -60,9 +60,11 @@ class Settings(BaseModel):
deep_research_enabled: bool | None = None
search_ui_enabled: bool | None = None
# Enterprise features flag - set by license enforcement at runtime
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
# Whether EE features are unlocked for use.
# Depends on license status: True when the user has a valid license
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
# or the license is expired (GATED_ACCESS).
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
ee_features_enabled: bool = False
temperature_override_enabled: bool | None = False

View File

@@ -56,23 +56,3 @@ def get_built_in_tool_ids() -> list[str]:
def get_built_in_tool_by_id(in_code_tool_id: str) -> Type[BUILT_IN_TOOL_TYPES]:
return BUILT_IN_TOOL_MAP[in_code_tool_id]
def _build_tool_name_to_class() -> dict[str, Type[BUILT_IN_TOOL_TYPES]]:
"""Build a mapping from LLM-facing tool name to tool class."""
result: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {}
for cls in BUILT_IN_TOOL_MAP.values():
name_attr = cls.__dict__.get("name")
if isinstance(name_attr, property) and name_attr.fget is not None:
tool_name = name_attr.fget(cls)
elif isinstance(name_attr, str):
tool_name = name_attr
else:
raise ValueError(
f"Built-in tool {cls.__name__} must define a valid LLM-facing tool name"
)
result[tool_name] = cls
return result
TOOL_NAME_TO_CLASS: dict[str, Type[BUILT_IN_TOOL_TYPES]] = _build_tool_name_to_class()

View File

@@ -92,7 +92,3 @@ class Tool(abc.ABC, Generic[TOverride]):
**llm_kwargs: Any,
) -> ToolResponse:
raise NotImplementedError
@classmethod
def do_emit_argument_deltas(cls) -> bool:
return False

View File

@@ -376,8 +376,3 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
rich_response=None,
llm_facing_response=llm_response,
)
@classmethod
@override
def do_emit_argument_deltas(cls) -> bool:
return True

View File

@@ -1,10 +1,20 @@
#!/bin/bash
set -e
cleanup() {
echo "Error occurred. Cleaning up..."
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
COMPOSE_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.yml"
COMPOSE_DEV_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.dev.yml"
stop_and_remove_containers() {
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled stop opensearch 2>/dev/null || true
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled rm -f opensearch 2>/dev/null || true
}
cleanup() {
echo "Error occurred. Cleaning up..."
stop_and_remove_containers
}
# Trap errors and output a message, then cleanup
@@ -12,16 +22,26 @@ trap 'echo "Error occurred on line $LINENO. Exiting script." >&2; cleanup' ERR
# Usage of the script with optional volume arguments
# ./restart_containers.sh [vespa_volume] [postgres_volume] [redis_volume]
# [minio_volume] [--keep-opensearch-data]
VESPA_VOLUME=${1:-""} # Default is empty if not provided
POSTGRES_VOLUME=${2:-""} # Default is empty if not provided
REDIS_VOLUME=${3:-""} # Default is empty if not provided
MINIO_VOLUME=${4:-""} # Default is empty if not provided
KEEP_OPENSEARCH_DATA=false
POSITIONAL_ARGS=()
for arg in "$@"; do
if [[ "$arg" == "--keep-opensearch-data" ]]; then
KEEP_OPENSEARCH_DATA=true
else
POSITIONAL_ARGS+=("$arg")
fi
done
VESPA_VOLUME=${POSITIONAL_ARGS[0]:-""}
POSTGRES_VOLUME=${POSITIONAL_ARGS[1]:-""}
REDIS_VOLUME=${POSITIONAL_ARGS[2]:-""}
MINIO_VOLUME=${POSITIONAL_ARGS[3]:-""}
# Stop and remove the existing containers
echo "Stopping and removing existing containers..."
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
stop_and_remove_containers
# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
@@ -39,6 +59,29 @@ else
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8
fi
# If OPENSEARCH_ADMIN_PASSWORD is not already set, try loading it from
# .vscode/.env so existing dev setups that stored it there aren't silently
# broken.
VSCODE_ENV="$SCRIPT_DIR/../../.vscode/.env"
if [[ -z "${OPENSEARCH_ADMIN_PASSWORD:-}" && -f "$VSCODE_ENV" ]]; then
set -a
# shellcheck source=/dev/null
source "$VSCODE_ENV"
set +a
fi
# Start the OpenSearch container using the same service from docker-compose that
# our users use, setting OPENSEARCH_INITIAL_ADMIN_PASSWORD from the env's
# OPENSEARCH_ADMIN_PASSWORD if it exists, else defaulting to StrongPassword123!.
# Pass --keep-opensearch-data to preserve the opensearch-data volume across
# restarts, else the volume is deleted so the container starts fresh.
if [[ "$KEEP_OPENSEARCH_DATA" == "false" ]]; then
echo "Deleting opensearch-data volume..."
docker volume rm onyx_opensearch-data 2>/dev/null || true
fi
echo "Starting OpenSearch container..."
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled up --force-recreate -d opensearch
# Start the Redis container with optional volume
echo "Starting Redis container..."
if [[ -n "$REDIS_VOLUME" ]]; then
@@ -60,7 +103,6 @@ echo "Starting Code Interpreter container..."
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
# Ensure alembic runs in the correct directory (backend/)
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
cd "$PARENT_DIR"

View File

@@ -1,10 +0,0 @@
#!/bin/bash
# We get OPENSEARCH_ADMIN_PASSWORD from the repo .env file.
source "$(dirname "$0")/../../.vscode/.env"
cd "$(dirname "$0")/../../deployment/docker_compose"
# Start OpenSearch.
echo "Forcefully starting fresh OpenSearch container..."
docker compose -f docker-compose.opensearch.yml up --force-recreate -d opensearch

View File

@@ -5,6 +5,8 @@ Verifies that:
1. extract_ids_from_runnable_connector correctly separates hierarchy nodes from doc IDs
2. Extracted hierarchy nodes are correctly upserted to Postgres via upsert_hierarchy_nodes_batch
3. Upserting is idempotent (running twice doesn't duplicate nodes)
4. Document-to-hierarchy-node linkage is updated during pruning
5. link_hierarchy_nodes_to_documents links nodes that are also documents
Uses a mock SlimConnectorWithPermSync that yields known hierarchy nodes and slim documents,
combined with a real PostgreSQL database for verifying persistence.
@@ -27,9 +29,13 @@ from onyx.db.enums import HierarchyNodeType
from onyx.db.hierarchy import ensure_source_node_exists
from onyx.db.hierarchy import get_all_hierarchy_nodes_for_source
from onyx.db.hierarchy import get_hierarchy_node_by_raw_id
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import Document as DbDocument
from onyx.db.models import HierarchyNode as DBHierarchyNode
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.kg.models import KGStage
# ---------------------------------------------------------------------------
# Constants
@@ -89,8 +95,18 @@ def _make_hierarchy_nodes() -> list[PydanticHierarchyNode]:
]
DOC_PARENT_MAP = {
"msg-001": CHANNEL_A_ID,
"msg-002": CHANNEL_A_ID,
"msg-003": CHANNEL_B_ID,
}
def _make_slim_docs() -> list[SlimDocument | PydanticHierarchyNode]:
return [SlimDocument(id=doc_id) for doc_id in SLIM_DOC_IDS]
return [
SlimDocument(id=doc_id, parent_hierarchy_raw_node_id=DOC_PARENT_MAP.get(doc_id))
for doc_id in SLIM_DOC_IDS
]
class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
@@ -126,14 +142,31 @@ class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
# ---------------------------------------------------------------------------
def _cleanup_test_hierarchy_nodes(db_session: Session) -> None:
"""Remove all hierarchy nodes for TEST_SOURCE to isolate tests."""
def _cleanup_test_data(db_session: Session) -> None:
"""Remove all test hierarchy nodes and documents to isolate tests."""
for doc_id in SLIM_DOC_IDS:
db_session.query(DbDocument).filter(DbDocument.id == doc_id).delete()
db_session.query(DBHierarchyNode).filter(
DBHierarchyNode.source == TEST_SOURCE
).delete()
db_session.commit()
def _create_test_documents(db_session: Session) -> list[DbDocument]:
"""Insert minimal Document rows for our test doc IDs."""
docs = []
for doc_id in SLIM_DOC_IDS:
doc = DbDocument(
id=doc_id,
semantic_id=doc_id,
kg_stage=KGStage.NOT_STARTED,
)
db_session.add(doc)
docs.append(doc)
db_session.commit()
return docs
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@@ -147,14 +180,14 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
result = extract_ids_from_runnable_connector(connector, callback=None)
# Doc IDs should include both slim doc IDs and hierarchy node raw_node_ids
# (hierarchy node IDs are added to doc_ids so they aren't pruned)
# (hierarchy node IDs are added to raw_id_to_parent so they aren't pruned)
expected_ids = {
CHANNEL_A_ID,
CHANNEL_B_ID,
CHANNEL_C_ID,
*SLIM_DOC_IDS,
}
assert result.doc_ids == expected_ids
assert result.raw_id_to_parent.keys() == expected_ids
# Hierarchy nodes should be the 3 channels
assert len(result.hierarchy_nodes) == 3
@@ -165,7 +198,7 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
def test_pruning_upserts_hierarchy_nodes_to_db(db_session: Session) -> None:
"""Full flow: extract hierarchy nodes from mock connector, upsert to Postgres,
then verify the DB state (node count, parent relationships, permissions)."""
_cleanup_test_hierarchy_nodes(db_session)
_cleanup_test_data(db_session)
# Step 1: ensure the SOURCE node exists (mirrors what the pruning task does)
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -230,7 +263,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector(
) -> None:
"""When the connector's access type is PUBLIC, all hierarchy nodes must be
marked is_public=True regardless of their external_access settings."""
_cleanup_test_hierarchy_nodes(db_session)
_cleanup_test_data(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -257,7 +290,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector(
def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None:
"""Upserting the same hierarchy nodes twice must not create duplicates.
The second call should update existing rows in place."""
_cleanup_test_hierarchy_nodes(db_session)
_cleanup_test_data(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -295,7 +328,7 @@ def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None:
def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> None:
"""Upserting a hierarchy node with changed fields should update the existing row."""
_cleanup_test_hierarchy_nodes(db_session)
_cleanup_test_data(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -342,3 +375,193 @@ def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> No
assert db_node.is_public is True
assert db_node.external_user_emails is not None
assert set(db_node.external_user_emails) == {"new_user@example.com"}
# ---------------------------------------------------------------------------
# Document-to-hierarchy-node linkage tests
# ---------------------------------------------------------------------------
def test_extraction_preserves_parent_hierarchy_raw_node_id(
db_session: Session, # noqa: ARG001
) -> None:
"""extract_ids_from_runnable_connector should carry the
parent_hierarchy_raw_node_id from SlimDocument into the raw_id_to_parent dict."""
connector = MockSlimConnectorWithPermSync()
result = extract_ids_from_runnable_connector(connector, callback=None)
for doc_id, expected_parent in DOC_PARENT_MAP.items():
assert (
result.raw_id_to_parent[doc_id] == expected_parent
), f"raw_id_to_parent[{doc_id}] should be {expected_parent}"
# Hierarchy node entries have None parent (they aren't documents)
for channel_id in [CHANNEL_A_ID, CHANNEL_B_ID, CHANNEL_C_ID]:
assert result.raw_id_to_parent[channel_id] is None
def test_update_document_parent_hierarchy_nodes(db_session: Session) -> None:
"""update_document_parent_hierarchy_nodes should set
Document.parent_hierarchy_node_id for each document in the mapping."""
_cleanup_test_data(db_session)
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
upserted = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=_make_hierarchy_nodes(),
source=TEST_SOURCE,
commit=True,
is_connector_public=False,
)
node_id_by_raw = {n.raw_node_id: n.id for n in upserted}
# Create documents with no parent set
docs = _create_test_documents(db_session)
for doc in docs:
assert doc.parent_hierarchy_node_id is None
# Build resolved map (same logic as _resolve_and_update_document_parents)
resolved: dict[str, int | None] = {}
for doc_id, raw_parent in DOC_PARENT_MAP.items():
resolved[doc_id] = node_id_by_raw.get(raw_parent, source_node.id)
updated = update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
assert updated == len(SLIM_DOC_IDS)
# Verify each document now points to the correct hierarchy node
db_session.expire_all()
for doc_id, raw_parent in DOC_PARENT_MAP.items():
tmp_doc = db_session.get(DbDocument, doc_id)
assert tmp_doc is not None
doc = tmp_doc
expected_node_id = node_id_by_raw[raw_parent]
assert (
doc.parent_hierarchy_node_id == expected_node_id
), f"Document {doc_id} should point to node for {raw_parent}"
def test_update_document_parent_is_idempotent(db_session: Session) -> None:
"""Running update_document_parent_hierarchy_nodes a second time with the
same mapping should update zero rows."""
_cleanup_test_data(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
upserted = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=_make_hierarchy_nodes(),
source=TEST_SOURCE,
commit=True,
is_connector_public=False,
)
node_id_by_raw = {n.raw_node_id: n.id for n in upserted}
_create_test_documents(db_session)
resolved: dict[str, int | None] = {
doc_id: node_id_by_raw[raw_parent]
for doc_id, raw_parent in DOC_PARENT_MAP.items()
}
first_updated = update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
assert first_updated == len(SLIM_DOC_IDS)
second_updated = update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
assert second_updated == 0
def test_link_hierarchy_nodes_to_documents_for_confluence(
db_session: Session,
) -> None:
"""For sources in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS (e.g. Confluence),
link_hierarchy_nodes_to_documents should set HierarchyNode.document_id
when a hierarchy node's raw_node_id matches a document ID."""
_cleanup_test_data(db_session)
confluence_source = DocumentSource.CONFLUENCE
# Clean up any existing Confluence hierarchy nodes
db_session.query(DBHierarchyNode).filter(
DBHierarchyNode.source == confluence_source
).delete()
db_session.commit()
ensure_source_node_exists(db_session, confluence_source, commit=True)
# Create a hierarchy node whose raw_node_id matches a document ID
page_node_id = "confluence-page-123"
nodes = [
PydanticHierarchyNode(
raw_node_id=page_node_id,
raw_parent_id=None,
display_name="Test Page",
link="https://wiki.example.com/page/123",
node_type=HierarchyNodeType.PAGE,
),
]
upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=nodes,
source=confluence_source,
commit=True,
is_connector_public=False,
)
# Verify the node exists but has no document_id yet
db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source)
assert db_node is not None
assert db_node.document_id is None
# Create a document with the same ID as the hierarchy node
doc = DbDocument(
id=page_node_id,
semantic_id="Test Page",
kg_stage=KGStage.NOT_STARTED,
)
db_session.add(doc)
db_session.commit()
# Link nodes to documents
linked = link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=[page_node_id],
source=confluence_source,
commit=True,
)
assert linked == 1
# Verify the hierarchy node now has document_id set
db_session.expire_all()
db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source)
assert db_node is not None
assert db_node.document_id == page_node_id
# Cleanup
db_session.query(DbDocument).filter(DbDocument.id == page_node_id).delete()
db_session.query(DBHierarchyNode).filter(
DBHierarchyNode.source == confluence_source
).delete()
db_session.commit()
def test_link_hierarchy_nodes_skips_non_hierarchy_sources(
db_session: Session,
) -> None:
"""link_hierarchy_nodes_to_documents should return 0 for sources that
don't support hierarchy-node-as-document (e.g. Slack, Google Drive)."""
linked = link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=SLIM_DOC_IDS,
source=TEST_SOURCE, # Slack — not in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS
commit=False,
)
assert linked == 0

View File

@@ -11,6 +11,7 @@ from onyx.context.search.models import SavedSearchSettings
from onyx.context.search.models import SearchSettingsCreationRequest
from onyx.db.enums import EmbeddingPrecision
from onyx.db.llm import fetch_default_contextual_rag_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import IndexModelStatus
@@ -37,6 +38,8 @@ def _create_llm_provider_and_model(
model_name: str,
) -> None:
"""Insert an LLM provider with a single visible model configuration."""
if fetch_existing_llm_provider(name=provider_name, db_session=db_session):
return
upsert_llm_provider(
LLMProviderUpsertRequest(
name=provider_name,
@@ -146,8 +149,8 @@ def baseline_search_settings(
)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.db.swap_index.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
@@ -155,6 +158,7 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
mock_index_handler: MagicMock,
mock_get_llm: MagicMock,
mock_get_doc_index: MagicMock, # noqa: ARG001
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
mock_get_all_doc_indices: MagicMock,
baseline_search_settings: None, # noqa: ARG001
db_session: Session,
@@ -196,8 +200,8 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.db.swap_index.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
@@ -205,6 +209,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
mock_index_handler: MagicMock,
mock_get_llm: MagicMock,
mock_get_doc_index: MagicMock, # noqa: ARG001
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
mock_get_all_doc_indices: MagicMock,
baseline_search_settings: None, # noqa: ARG001
db_session: Session,
@@ -266,7 +271,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.server.manage.search_settings.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
@@ -274,6 +279,7 @@ def test_indexing_pipeline_skips_llm_when_contextual_rag_disabled(
mock_index_handler: MagicMock,
mock_get_llm: MagicMock,
mock_get_doc_index: MagicMock, # noqa: ARG001
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
baseline_search_settings: None, # noqa: ARG001
db_session: Session,
) -> None:

View File

@@ -950,7 +950,6 @@ from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from tests.external_dependency_unit.answer.stream_test_builder import StreamTestBuilder
from tests.external_dependency_unit.answer.stream_test_utils import create_chat_session
@@ -1292,22 +1291,12 @@ def test_code_interpreter_replay_packets_include_code_and_output(
tool_call_id="call_replay_test",
tool_call_argument_tokens=[json.dumps({"code": code})],
)
).expect(
Packet(
placement=create_placement(0),
obj=ToolCallArgumentDelta(
tool_type="python",
tool_id="call_replay_test",
argument_deltas={"code": code},
),
),
forward=2,
).expect(
Packet(
placement=create_placement(0),
obj=PythonToolStart(code=code),
),
forward=False,
forward=2,
).expect(
Packet(
placement=create_placement(0),

View File

@@ -1,4 +1,3 @@
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
@@ -365,7 +364,6 @@ def test_update_contextual_rag_missing_model_name(
assert "Provider name and model name are required" in response.json()["detail"]
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_search_settings_with_contextual_rag(
reset: None, # noqa: ARG001
admin_user: DATestUser,
@@ -394,7 +392,6 @@ def test_set_new_search_settings_with_contextual_rag(
_cancel_new_embedding(admin_user)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_search_settings_without_contextual_rag(
reset: None, # noqa: ARG001
admin_user: DATestUser,
@@ -419,7 +416,6 @@ def test_set_new_search_settings_without_contextual_rag(
_cancel_new_embedding(admin_user)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_then_update_inference_settings(
reset: None, # noqa: ARG001
admin_user: DATestUser,
@@ -457,7 +453,6 @@ def test_set_new_then_update_inference_settings(
_cancel_new_embedding(admin_user)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_search_settings_replaces_previous_secondary(
reset: None, # noqa: ARG001
admin_user: DATestUser,

View File

@@ -281,9 +281,10 @@ class TestApplyLicenseStatusToSettings:
}
class TestSettingsDefaultEEDisabled:
"""Verify the Settings model defaults ee_features_enabled to False."""
class TestSettingsDefaults:
"""Verify Settings model defaults for CE deployments."""
def test_default_ee_features_disabled(self) -> None:
"""CE default: ee_features_enabled is False."""
settings = Settings()
assert settings.ee_features_enabled is False

View File

@@ -1,584 +0,0 @@
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
def _make_tool_call_delta(
index: int = 0,
tool_id: str | None = None,
name: str | None = None,
arguments: str | None = None,
function_is_none: bool = False,
) -> MagicMock:
"""Create a mock tool_call_delta matching the LiteLLM streaming shape."""
delta = MagicMock()
delta.index = index
delta.id = tool_id
if function_is_none:
delta.function = None
else:
delta.function = MagicMock()
delta.function.name = name
delta.function.arguments = arguments
return delta
def _make_placement() -> Placement:
return Placement(turn_index=0, tab_index=0)
def _mock_tool_class(emit: bool = True) -> MagicMock:
cls = MagicMock()
cls.do_emit_argument_deltas.return_value = emit
return cls
def _collect(
tc_map: dict[int, dict[str, Any]],
delta: MagicMock,
placement: Placement | None = None,
scan_offsets: dict[int, int] | None = None,
) -> list[Any]:
"""Run maybe_emit_argument_delta and return the yielded packets."""
return list(
maybe_emit_argument_delta(
tc_map,
delta,
placement or _make_placement(),
scan_offsets if scan_offsets is not None else {},
)
)
def _stream_fragments(
fragments: list[str],
tc_map: dict[int, dict[str, Any]],
placement: Placement | None = None,
) -> list[str]:
"""Feed fragments into maybe_emit_argument_delta one by one, returning
all emitted content values concatenated per-key as a flat list."""
pl = placement or _make_placement()
scan_offsets: dict[int, int] = {}
emitted: list[str] = []
for frag in fragments:
tc_map[0]["arguments"] += frag
delta = _make_tool_call_delta(arguments=frag)
for packet in maybe_emit_argument_delta(
tc_map, delta, pl, scan_offsets=scan_offsets
):
obj = packet.obj
assert isinstance(obj, ToolCallArgumentDelta)
for value in obj.argument_deltas.values():
emitted.append(value)
return emitted
class TestMaybeEmitArgumentDeltaGuards:
"""Tests for conditions that cause no packet to be emitted."""
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_no_emission_when_tool_does_not_opt_in(
self, mock_get_tool: MagicMock
) -> None:
"""Tools that return False from do_emit_argument_deltas emit nothing."""
mock_get_tool.return_value = _mock_tool_class(emit=False)
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
}
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_no_emission_when_tool_class_unknown(
self, mock_get_tool: MagicMock
) -> None:
mock_get_tool.return_value = None
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "unknown", "arguments": '{"code": "x'}
}
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_no_emission_when_no_argument_fragment(
self, mock_get_tool: MagicMock
) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
}
assert _collect(tc_map, _make_tool_call_delta(arguments=None)) == []
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_no_emission_when_key_value_not_started(
self, mock_get_tool: MagicMock
) -> None:
"""Key exists in JSON but its string value hasn't begun yet."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code":'}
}
assert _collect(tc_map, _make_tool_call_delta(arguments=":")) == []
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_no_emission_before_any_key(self, mock_get_tool: MagicMock) -> None:
"""Only the opening brace has arrived — no key to stream yet."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": "{"}
}
assert _collect(tc_map, _make_tool_call_delta(arguments="{")) == []
class TestMaybeEmitArgumentDeltaBasic:
"""Tests for correct packet content and incremental emission."""
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_emits_packet_with_correct_fields(self, mock_get_tool: MagicMock) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {
"id": "tc_1",
"name": "python",
"arguments": '{"code": "print(1)',
}
}
packets = _collect(tc_map, _make_tool_call_delta(arguments="print(1)"))
assert len(packets) == 1
obj = packets[0].obj
assert isinstance(obj, ToolCallArgumentDelta)
assert obj.tool_type == "python"
assert obj.tool_id == "tc_1"
assert obj.argument_deltas == {"code": "print(1)"}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_emits_only_new_content_on_subsequent_call(
self, mock_get_tool: MagicMock
) -> None:
"""After a first emission, subsequent calls emit only the diff."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "abc'}
}
packets_1 = _collect(tc_map, _make_tool_call_delta(arguments="abc"))
assert packets_1[0].obj.argument_deltas == {"code": "abc"}
tc_map[0]["arguments"] = '{"code": "abcdef'
packets_2 = _collect(tc_map, _make_tool_call_delta(arguments="def"))
assert packets_2[0].obj.argument_deltas == {"code": "def"}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_handles_multiple_keys_sequentially(self, mock_get_tool: MagicMock) -> None:
"""When a second key starts, emissions switch to that key."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
}
packets_1 = _collect(tc_map, _make_tool_call_delta(arguments="x"))
assert packets_1[0].obj.argument_deltas == {"code": "x"}
tc_map[0]["arguments"] = '{"code": "x", "output": "hello'
packets_2 = _collect(tc_map, _make_tool_call_delta(arguments="hello"))
assert packets_2[0].obj.argument_deltas == {"output": "hello"}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_delta_spans_key_boundary(self, mock_get_tool: MagicMock) -> None:
"""A single delta contains the end of one value and the start of the next key."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
}
packets_1 = _collect(tc_map, _make_tool_call_delta(arguments="x"))
assert packets_1[0].obj.argument_deltas == {"code": "x"}
# Delta carries closing of "code" value + opening of "lang" key + start of value
tc_map[0]["arguments"] = '{"code": "xy", "lang": "py'
packets_2 = _collect(tc_map, _make_tool_call_delta(arguments='y", "lang": "py'))
assert len(packets_2) == 1
assert packets_2[0].obj.argument_deltas == {"code": "y", "lang": "py"}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_empty_value_emits_nothing(self, mock_get_tool: MagicMock) -> None:
"""An empty string value has nothing to emit."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "'}
}
# Opening quote just arrived, value is empty
assert _collect(tc_map, _make_tool_call_delta(arguments='"')) == []
class TestMaybeEmitArgumentDeltaDecoding:
"""Tests verifying that JSON escape sequences are properly decoded."""
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_decodes_newlines(self, mock_get_tool: MagicMock) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {
"id": "tc_1",
"name": "python",
"arguments": '{"code": "line1\\nline2',
}
}
packets = _collect(tc_map, _make_tool_call_delta(arguments="line1\\nline2"))
assert packets[0].obj.argument_deltas == {"code": "line1\nline2"}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_decodes_tabs(self, mock_get_tool: MagicMock) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {
"id": "tc_1",
"name": "python",
"arguments": '{"code": "\\tindented',
}
}
packets = _collect(tc_map, _make_tool_call_delta(arguments="\\tindented"))
assert packets[0].obj.argument_deltas == {"code": "\tindented"}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_decodes_escaped_quotes(self, mock_get_tool: MagicMock) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {
"id": "tc_1",
"name": "python",
"arguments": '{"code": "say \\"hi\\"',
}
}
packets = _collect(tc_map, _make_tool_call_delta(arguments='say \\"hi\\"'))
assert packets[0].obj.argument_deltas == {"code": 'say "hi"'}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_decodes_escaped_backslashes(self, mock_get_tool: MagicMock) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {
"id": "tc_1",
"name": "python",
"arguments": '{"code": "path\\\\dir',
}
}
packets = _collect(tc_map, _make_tool_call_delta(arguments="path\\\\dir"))
assert packets[0].obj.argument_deltas == {"code": "path\\dir"}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_decodes_unicode_escape(self, mock_get_tool: MagicMock) -> None:
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {
"id": "tc_1",
"name": "python",
"arguments": '{"code": "\\u0041',
}
}
packets = _collect(tc_map, _make_tool_call_delta(arguments="\\u0041"))
assert packets[0].obj.argument_deltas == {"code": "A"}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_incomplete_escape_at_end_trims_safely(
self, mock_get_tool: MagicMock
) -> None:
"""A trailing backslash (incomplete escape) is handled gracefully."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {
"id": "tc_1",
"name": "python",
"arguments": '{"code": "hello\\',
}
}
packets = _collect(tc_map, _make_tool_call_delta(arguments="hello\\"))
# "hello" can be decoded; the trailing backslash is trimmed
assert packets[0].obj.argument_deltas == {"code": "hello"}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_incomplete_unicode_escape_trims_safely(
self, mock_get_tool: MagicMock
) -> None:
"""A partial \\uXX sequence is trimmed, emitting what can be decoded."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {
"id": "tc_1",
"name": "python",
"arguments": '{"code": "hello\\u00',
}
}
packets = _collect(tc_map, _make_tool_call_delta(arguments="hello\\u00"))
assert packets[0].obj.argument_deltas == {"code": "hello"}
class TestArgumentDeltaStreamingE2E:
"""Simulates realistic sequences of LLM argument deltas to verify
the full pipeline produces correct decoded output."""
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_realistic_python_code_streaming(self, mock_get_tool: MagicMock) -> None:
"""Streams: {"code": "print('hello')\\nprint('world')"}"""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"',
"code",
'": "',
"print(",
"'hello')",
"\\n",
"print(",
"'world')",
'"}',
]
full = "".join(_stream_fragments(fragments, tc_map))
assert full == "print('hello')\nprint('world')"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_streaming_with_tabs_and_newlines(self, mock_get_tool: MagicMock) -> None:
"""Streams code with tabs and newlines."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "',
"if True:",
"\\n",
"\\t",
"pass",
'"}',
]
full = "".join(_stream_fragments(fragments, tc_map))
assert full == "if True:\n\tpass"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_split_escape_sequence(self, mock_get_tool: MagicMock) -> None:
"""An escape sequence split across two fragments (backslash in one,
'n' in the next) should still decode correctly."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "hello',
"\\",
"n",
'world"}',
]
full = "".join(_stream_fragments(fragments, tc_map))
assert full == "hello\nworld"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_multiple_newlines_and_indentation(self, mock_get_tool: MagicMock) -> None:
"""Streams a multi-line function with multiple escape sequences."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "',
"def foo():",
"\\n",
"\\t",
"x = 1",
"\\n",
"\\t",
"return x",
'"}',
]
full = "".join(_stream_fragments(fragments, tc_map))
assert full == "def foo():\n\tx = 1\n\treturn x"
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_two_keys_streamed_sequentially(self, mock_get_tool: MagicMock) -> None:
"""Streams code first, then a second key (language) — both decoded."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "',
"x = 1",
'", "language": "',
"python",
'"}',
]
emitted = _stream_fragments(fragments, tc_map)
# Should have emissions for both keys
full = "".join(emitted)
assert "x = 1" in full
assert "python" in full
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_code_containing_dict_literal(self, mock_get_tool: MagicMock) -> None:
"""Python code like `x = {"key": "val"}` contains JSON-like patterns.
The escaped quotes inside the *outer* JSON value should prevent the
inner `"key":` from being mistaken for a top-level JSON key."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
# The LLM sends: {"code": "x = {\"key\": \"val\"}"}
# The inner quotes are escaped as \" in the JSON value.
fragments = [
'{"code": "',
"x = {",
'\\"key\\"',
": ",
'\\"val\\"',
"}",
'"}',
]
full = "".join(_stream_fragments(fragments, tc_map))
assert full == 'x = {"key": "val"}'
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_code_with_colon_in_value(self, mock_get_tool: MagicMock) -> None:
"""Colons inside the string value should not confuse key detection."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": ""}
}
fragments = [
'{"code": "',
"url = ",
'\\"https://example.com\\"',
'"}',
]
full = "".join(_stream_fragments(fragments, tc_map))
assert full == 'url = "https://example.com"'
class TestMaybeEmitArgumentDeltaEdgeCases:
"""Edge cases not covered by the standard test classes."""
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_no_emission_when_function_is_none(self, mock_get_tool: MagicMock) -> None:
"""Some delta chunks have function=None (e.g. role-only deltas)."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
}
delta = _make_tool_call_delta(arguments=None, function_is_none=True)
assert _collect(tc_map, delta) == []
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_multiple_concurrent_tool_calls(self, mock_get_tool: MagicMock) -> None:
"""Two tool calls streaming at different indices in parallel."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "aaa'},
1: {"id": "tc_2", "name": "python", "arguments": '{"code": "bbb'},
}
# Delta for index 0
packets_0 = _collect(tc_map, _make_tool_call_delta(index=0, arguments="aaa"))
assert len(packets_0) == 1
assert packets_0[0].obj.tool_id == "tc_1"
assert packets_0[0].obj.argument_deltas == {"code": "aaa"}
# Delta for index 1
packets_1 = _collect(tc_map, _make_tool_call_delta(index=1, arguments="bbb"))
assert len(packets_1) == 1
assert packets_1[0].obj.tool_id == "tc_2"
assert packets_1[0].obj.argument_deltas == {"code": "bbb"}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_delta_with_four_arguments(self, mock_get_tool: MagicMock) -> None:
"""A single delta contains four complete key-value pairs."""
mock_get_tool.return_value = _mock_tool_class()
accumulated = '{"a": "one", "b": "two", "c": "three", "d": "four'
tc_map: dict[int, dict[str, Any]] = {
0: {"id": "tc_1", "name": "python", "arguments": accumulated}
}
packets = _collect(tc_map, _make_tool_call_delta(arguments=accumulated))
assert len(packets) == 1
assert packets[0].obj.argument_deltas == {
"a": "one",
"b": "two",
"c": "three",
"d": "four",
}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_delta_on_second_arg_after_first_complete(
self, mock_get_tool: MagicMock
) -> None:
"""First argument is fully complete; delta only adds to the second."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {
"id": "tc_1",
"name": "python",
"arguments": '{"code": "print(1)", "lang": "py',
}
}
packets = _collect(tc_map, _make_tool_call_delta(arguments="py"))
assert len(packets) == 1
assert packets[0].obj.argument_deltas == {"lang": "py"}
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
def test_non_string_values_skipped(self, mock_get_tool: MagicMock) -> None:
"""Non-string values (numbers, booleans, null) are skipped — they are
available in the final tool-call kickoff packet. String arguments
following them are still emitted."""
mock_get_tool.return_value = _mock_tool_class()
tc_map: dict[int, dict[str, Any]] = {
0: {
"id": "tc_1",
"name": "python",
"arguments": '{"timeout": 30, "code": "hello',
}
}
packets = _collect(
tc_map, _make_tool_call_delta(arguments='30, "code": "hello')
)
assert len(packets) == 1
assert packets[0].obj.argument_deltas == {"code": "hello"}

View File

@@ -1,11 +1,12 @@
"use client";
import React, { useEffect, useRef, useState } from "react";
import React, { useEffect, useMemo, useRef, useState } from "react";
import { FileDescriptor } from "@/app/app/interfaces";
import "katex/dist/katex.min.css";
import MessageSwitcher from "@/app/app/message/MessageSwitcher";
import Text from "@/refresh-components/texts/Text";
import { cn } from "@/lib/utils";
import useScreenSize from "@/hooks/useScreenSize";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
import { Button } from "@opal/components";
import { SvgEdit } from "@opal/icons";
@@ -137,6 +138,7 @@ const HumanMessage = React.memo(function HumanMessage({
const [content, setContent] = useState(initialContent);
const [isEditing, setIsEditing] = useState(false);
const { isMobile } = useScreenSize();
// Use nodeId for switching (finding position in siblings)
const indexInSiblings = otherMessagesCanSwitchTo?.indexOf(nodeId);
@@ -168,119 +170,104 @@ const HumanMessage = React.memo(function HumanMessage({
return undefined;
};
const copyEditButton = useMemo(
() => (
<div className="flex flex-row flex-shrink px-1 opacity-0 group-hover:opacity-100 transition-opacity">
<CopyIconButton
getCopyText={() => content}
prominence="tertiary"
data-testid="HumanMessage/copy-button"
/>
<Button
icon={SvgEdit}
prominence="tertiary"
tooltip="Edit"
onClick={() => setIsEditing(true)}
data-testid="HumanMessage/edit-button"
/>
</div>
),
[content]
);
return (
<div
id="onyx-human-message"
className="group flex flex-col justify-end w-full relative"
>
<FileDisplay alignBubble files={files || []} />
<div className="md:flex md:flex-wrap relative justify-end break-words">
{isEditing ? (
<MessageEditing
content={content}
onSubmitEdit={(editedContent) => {
// Don't update UI for edits that can't be persisted
if (messageId === undefined || messageId === null) {
setIsEditing(false);
return;
}
onEdit?.(editedContent, messageId);
setContent(editedContent);
{isEditing ? (
<MessageEditing
content={content}
onSubmitEdit={(editedContent) => {
// Don't update UI for edits that can't be persisted
if (messageId === undefined || messageId === null) {
setIsEditing(false);
}}
onCancelEdit={() => setIsEditing(false)}
/>
) : typeof content === "string" ? (
<>
<div className="md:max-w-[37.5rem] flex basis-[100%] md:basis-auto justify-end md:order-1">
<div
className={
"max-w-[30rem] md:max-w-[37.5rem] whitespace-break-spaces break-anywhere rounded-t-16 rounded-bl-16 bg-background-tint-02 py-2 px-3"
}
onCopy={(e) => {
const selection = window.getSelection();
if (selection) {
e.preventDefault();
const text = selection
.toString()
.replace(/\n{2,}/g, "\n")
.trim();
e.clipboardData.setData("text/plain", text);
}
}}
>
<Text
as="p"
className="inline-block align-middle"
mainContentBody
>
{content}
</Text>
</div>
</div>
{onEdit && !isEditing && (
<div className="absolute md:relative right-0 z-content flex flex-row p-1 opacity-0 group-hover:opacity-100 transition-opacity">
<CopyIconButton
getCopyText={() => content}
prominence="tertiary"
data-testid="HumanMessage/copy-button"
/>
<Button
icon={SvgEdit}
prominence="tertiary"
tooltip="Edit"
onClick={() => setIsEditing(true)}
data-testid="HumanMessage/edit-button"
/>
</div>
)}
</>
) : (
<>
return;
}
onEdit?.(editedContent, messageId);
setContent(editedContent);
setIsEditing(false);
}}
onCancelEdit={() => setIsEditing(false)}
/>
) : (
<div className="flex justify-end">
{onEdit && !isMobile && copyEditButton}
<div className="md:max-w-[37.5rem]">
<div
className={cn(
"my-auto",
onEdit && !isEditing
? "opacity-0 group-hover:opacity-100 transition-opacity"
: "invisible"
)}
className={
"max-w-[30rem] md:max-w-[37.5rem] whitespace-break-spaces break-anywhere rounded-t-16 rounded-bl-16 bg-background-tint-02 py-2 px-3"
}
onCopy={(e) => {
const selection = window.getSelection();
if (selection) {
e.preventDefault();
const text = selection
.toString()
.replace(/\n{2,}/g, "\n")
.trim();
e.clipboardData.setData("text/plain", text);
}
}}
>
<Button
icon={SvgEdit}
onClick={() => setIsEditing(true)}
prominence="tertiary"
tooltip="Edit"
/>
<Text
as="p"
className="inline-block align-middle"
mainContentBody
>
{content}
</Text>
</div>
<div className="ml-auto rounded-lg p-1">{content}</div>
</>
)}
<div className="md:min-w-[100%] flex justify-end order-1 mt-1">
{currentMessageInd !== undefined &&
onMessageSelection &&
otherMessagesCanSwitchTo &&
otherMessagesCanSwitchTo.length > 1 && (
<MessageSwitcher
disableForStreaming={disableSwitchingForStreaming}
currentPage={currentMessageInd + 1}
totalPages={otherMessagesCanSwitchTo.length}
handlePrevious={() => {
stopGenerating();
const prevMessage = getPreviousMessage();
if (prevMessage !== undefined) {
onMessageSelection(prevMessage);
}
}}
handleNext={() => {
stopGenerating();
const nextMessage = getNextMessage();
if (nextMessage !== undefined) {
onMessageSelection(nextMessage);
}
}}
/>
)}
</div>
</div>
)}
<div className="flex justify-end pt-1">
{!isEditing && onEdit && isMobile && copyEditButton}
{currentMessageInd !== undefined &&
onMessageSelection &&
otherMessagesCanSwitchTo &&
otherMessagesCanSwitchTo.length > 1 && (
<MessageSwitcher
disableForStreaming={disableSwitchingForStreaming}
currentPage={currentMessageInd + 1}
totalPages={otherMessagesCanSwitchTo.length}
handlePrevious={() => {
stopGenerating();
const prevMessage = getPreviousMessage();
if (prevMessage !== undefined) {
onMessageSelection(prevMessage);
}
}}
handleNext={() => {
stopGenerating();
const nextMessage = getNextMessage();
if (nextMessage !== undefined) {
onMessageSelection(nextMessage);
}
}}
/>
)}
</div>
</div>
);

View File

@@ -0,0 +1,18 @@
"use client";
import { useEffect } from "react";
import { useRouter } from "next/navigation";
import { toast } from "@/hooks/useToast";
export default function EEFeatureRedirect() {
const router = useRouter();
useEffect(() => {
toast.error(
"This feature requires a license. Please upgrade your plan to access."
);
router.replace("/app");
}, [router]);
return null;
}

View File

@@ -1,5 +1,6 @@
import { SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED } from "@/lib/constants";
import { fetchStandardSettingsSS } from "@/components/settings/lib";
import EEFeatureRedirect from "@/app/ee/EEFeatureRedirect";
export default async function AdminLayout({
children,
@@ -8,13 +9,7 @@ export default async function AdminLayout({
}) {
// First check build-time constant (fast path)
if (!SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED) {
return (
<div className="flex h-screen">
<div className="mx-auto my-auto text-lg font-bold text-red-500">
This functionality is only available in the Enterprise Edition :(
</div>
</div>
);
return <EEFeatureRedirect />;
}
// Then check runtime license status (for license enforcement mode)
@@ -31,13 +26,7 @@ export default async function AdminLayout({
return children;
}
return (
<div className="flex h-screen">
<div className="mx-auto my-auto text-lg font-bold text-red-500">
This functionality requires an active Enterprise license.
</div>
</div>
);
return <EEFeatureRedirect />;
}
}
} catch (error) {

View File

@@ -484,12 +484,8 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
ref={chatInputBarRef}
deepResearchEnabled={deepResearchEnabled}
toggleDeepResearch={toggleDeepResearch}
toggleDocumentSidebar={() => {}}
filterManager={filterManager}
llmManager={llmManager}
removeDocs={() => {}}
retrievalEnabled={retrievalEnabled}
selectedDocuments={[]}
initialMessage={message}
stopGenerating={stopGenerating}
onSubmit={handleChatInputSubmit}

View File

@@ -23,8 +23,7 @@ export interface AppModeProviderProps {
export function AppModeProvider({ children }: AppModeProviderProps) {
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const { user } = useUser();
const settings = useSettingsContext();
const { isSearchModeAvailable } = settings;
const { isSearchModeAvailable } = useSettingsContext();
const persistedMode = user?.preferences?.default_app_mode;
const [appMode, setAppModeState] = useState<AppMode>("chat");

View File

@@ -11,21 +11,8 @@ import {
* Hook to fetch billing information from Stripe.
*
* Works for both cloud and self-hosted deployments:
* - Cloud: fetches from /api/tenants/billing-information (legacy endpoint)
* - Cloud: fetches from /api/tenants/billing-information
* - Self-hosted: fetches from /api/admin/billing/billing-information
*
* Returns subscription status, seats, billing period, etc.
*
* @example
* ```tsx
* const { data, isLoading, error, refresh } = useBillingInformation();
*
* if (isLoading) return <Loading />;
* if (error) return <Error />;
* if (!data || !hasActiveSubscription(data)) return <NoSubscription />;
*
* return <BillingDetails billing={data} />;
* ```
*/
export function useBillingInformation() {
const url = NEXT_PUBLIC_CLOUD_ENABLED
@@ -38,16 +25,9 @@ export function useBillingInformation() {
revalidateOnFocus: false,
revalidateOnReconnect: false,
dedupingInterval: 30000,
// Don't auto-retry on errors (circuit breaker will block requests anyway)
shouldRetryOnError: false,
// Keep previous data while revalidating to prevent UI flashing
keepPreviousData: true,
});
return {
data,
isLoading,
error,
refresh: mutate,
};
return { data, isLoading, error, refresh: mutate };
}

View File

@@ -7,23 +7,9 @@ import { LicenseStatus } from "@/lib/billing/interfaces";
/**
* Hook to fetch license status for self-hosted deployments.
*
* Returns license information including seats, expiry, and status.
* Only fetches for self-hosted deployments (cloud uses tenant auth instead).
*
* @example
* ```tsx
* const { data, isLoading, error, refresh } = useLicense();
*
* if (isLoading) return <Loading />;
* if (error) return <Error />;
* if (!data?.has_license) return <NoLicense />;
*
* return <LicenseDetails license={data} />;
* ```
* Skips the fetch on cloud deployments (uses tenant auth instead).
*/
export function useLicense() {
// Only fetch license for self-hosted deployments
// Cloud deployments use tenant-based auth, not license files
const url = NEXT_PUBLIC_CLOUD_ENABLED ? null : "/api/license";
const { data, error, mutate, isLoading } = useSWR<LicenseStatus>(
@@ -38,20 +24,14 @@ export function useLicense() {
}
);
// Return empty state for cloud deployments
if (NEXT_PUBLIC_CLOUD_ENABLED) {
if (!url) {
return {
data: null,
data: undefined,
isLoading: false,
error: undefined,
refresh: () => Promise.resolve(undefined),
};
}
return {
data,
isLoading,
error,
refresh: mutate,
};
return { data, isLoading, error, refresh: mutate };
}

View File

@@ -46,8 +46,8 @@ export interface Settings {
// Onyx Craft (Build Mode) feature flag
onyx_craft_enabled?: boolean;
// Enterprise features flag - controlled by license enforcement at runtime
// True when user has a valid license, False for community edition
// Whether EE features are unlocked (user has a valid enterprise license).
// Controls UI visibility of EE features like user groups, analytics, RBAC.
ee_features_enabled?: boolean;
// Seat usage - populated when seat limit is exceeded

View File

@@ -190,14 +190,17 @@ function AttachmentItemLayout({
alignItems="center"
gap={1.5}
>
<Content
title={title}
description={description}
sizePreset="main-ui"
variant="section"
/>
<div className="flex-1 min-w-0">
<Content
title={title}
description={description}
sizePreset="main-ui"
variant="section"
widthVariant="full"
/>
</div>
{middleText && (
<div className="flex-1">
<div className="flex-1 min-w-0">
<Truncated text03 secondaryBody>
{middleText}
</Truncated>

View File

@@ -42,8 +42,13 @@ export const NEXT_PUBLIC_CUSTOM_REFRESH_URL =
// NOTE: this should ONLY be used on the server-side. If used client side,
// it will not be accurate (will always be false).
// Mirrors backend logic: EE is enabled if EITHER the legacy flag OR license
// enforcement is active. LICENSE_ENFORCEMENT_ENABLED defaults to true on the
// backend, so we treat undefined as enabled here to match.
export const SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED =
process.env.ENABLE_PAID_ENTERPRISE_EDITION_FEATURES?.toLowerCase() === "true";
process.env.ENABLE_PAID_ENTERPRISE_EDITION_FEATURES?.toLowerCase() ===
"true" ||
process.env.LICENSE_ENFORCEMENT_ENABLED?.toLowerCase() !== "false";
// NOTE: since this is a `NEXT_PUBLIC_` variable, it will be set at
// build-time
// TODO: consider moving this to an API call so that the api_server

View File

@@ -51,16 +51,6 @@ function ToastContainer() {
}, ANIMATION_DURATION);
}, []);
// NOTE (@raunakab):
//
// Keep this here for debugging purposes.
// useOnMount(() => {
// toast.success("Test success toast", { duration: Infinity });
// toast.error("Test error toast", { duration: Infinity });
// toast.warning("Test warning toast", { duration: Infinity });
// toast.info("Test info toast", { duration: Infinity });
// });
if (visible.length === 0) return null;
return (

View File

@@ -0,0 +1,106 @@
"use client";
import { useState } from "react";
import {
type Table,
type ColumnDef,
type RowData,
type VisibilityState,
} from "@tanstack/react-table";
import { Button } from "@opal/components";
import { SvgColumn, SvgCheck } from "@opal/icons";
import Popover from "@/refresh-components/Popover";
import LineItem from "@/refresh-components/buttons/LineItem";
import Divider from "@/refresh-components/Divider";
// ---------------------------------------------------------------------------
// Popover UI
// ---------------------------------------------------------------------------
interface ColumnVisibilityPopoverProps<TData extends RowData = RowData> {
table: Table<TData>;
columnVisibility: VisibilityState;
size?: "regular" | "small";
}
function ColumnVisibilityPopover<TData extends RowData>({
table,
columnVisibility,
size = "regular",
}: ColumnVisibilityPopoverProps<TData>) {
const [open, setOpen] = useState(false);
const hideableColumns = table
.getAllLeafColumns()
.filter((col) => col.getCanHide());
return (
<Popover open={open} onOpenChange={setOpen}>
<Popover.Trigger asChild>
<Button
icon={SvgColumn}
transient={open}
size={size === "small" ? "sm" : "md"}
prominence="internal"
tooltip="Columns"
/>
</Popover.Trigger>
<Popover.Content width="lg" align="end" side="bottom">
<Divider showTitle text="Shown Columns" />
<Popover.Menu>
{hideableColumns.map((column) => {
const isVisible = columnVisibility[column.id] !== false;
const label =
typeof column.columnDef.header === "string"
? column.columnDef.header
: column.id;
return (
<LineItem
key={column.id}
selected={isVisible}
emphasized
rightChildren={isVisible ? <SvgCheck size={16} /> : undefined}
onClick={() => {
column.toggleVisibility();
}}
>
{label}
</LineItem>
);
})}
</Popover.Menu>
</Popover.Content>
</Popover>
);
}
// ---------------------------------------------------------------------------
// Column definition factory
// ---------------------------------------------------------------------------
interface CreateColumnVisibilityColumnOptions {
size?: "regular" | "small";
}
function createColumnVisibilityColumn<TData>(
options?: CreateColumnVisibilityColumnOptions
): ColumnDef<TData, unknown> {
return {
id: "__columnVisibility",
size: 44,
enableHiding: false,
enableSorting: false,
enableResizing: false,
header: ({ table }) => (
<ColumnVisibilityPopover
table={table}
columnVisibility={table.getState().columnVisibility}
size={options?.size}
/>
),
cell: () => null,
};
}
export { ColumnVisibilityPopover, createColumnVisibilityColumn };

View File

@@ -0,0 +1,455 @@
"use client";
"use no memo";
import { useMemo } from "react";
import { flexRender } from "@tanstack/react-table";
import useDataTable, {
toOnyxSortDirection,
} from "@/refresh-components/table/hooks/useDataTable";
import useColumnWidths from "@/refresh-components/table/hooks/useColumnWidths";
import useDraggableRows from "@/refresh-components/table/hooks/useDraggableRows";
import Table from "@/refresh-components/table/Table";
import TableHeader from "@/refresh-components/table/TableHeader";
import TableBody from "@/refresh-components/table/TableBody";
import TableRow from "@/refresh-components/table/TableRow";
import TableHead from "@/refresh-components/table/TableHead";
import TableCell from "@/refresh-components/table/TableCell";
import TableQualifier from "@/refresh-components/table/TableQualifier";
import QualifierContainer from "@/refresh-components/table/QualifierContainer";
import ActionsContainer from "@/refresh-components/table/ActionsContainer";
import DragOverlayRow from "@/refresh-components/table/DragOverlayRow";
import Footer from "@/refresh-components/table/Footer";
import { TableSizeProvider } from "@/refresh-components/table/TableSizeContext";
import { ColumnVisibilityPopover } from "@/refresh-components/table/ColumnVisibilityPopover";
import { SortingPopover } from "@/refresh-components/table/SortingPopover";
import type { WidthConfig } from "@/refresh-components/table/hooks/useColumnWidths";
import type { ColumnDef } from "@tanstack/react-table";
import type {
DataTableProps,
DataTableFooterConfig,
OnyxColumnDef,
OnyxDataColumn,
OnyxQualifierColumn,
OnyxActionsColumn,
} from "@/refresh-components/table/types";
import type { TableSize } from "@/refresh-components/table/TableSizeContext";
const noopGetRowId = () => "";
// ---------------------------------------------------------------------------
// Internal: resolve size-dependent widths and build TanStack columns
// ---------------------------------------------------------------------------
interface ProcessedColumns<TData> {
tanstackColumns: ColumnDef<TData, any>[];
widthConfig: WidthConfig;
qualifierColumn: OnyxQualifierColumn<TData> | null;
/** Map from column ID → OnyxColumnDef for dispatch in render loops. */
columnKindMap: Map<string, OnyxColumnDef<TData>>;
}
function processColumns<TData>(
columns: OnyxColumnDef<TData>[],
size: TableSize
): ProcessedColumns<TData> {
const tanstackColumns: ColumnDef<TData, any>[] = [];
const fixedColumnIds = new Set<string>();
const columnWeights: Record<string, number> = {};
const columnMinWidths: Record<string, number> = {};
const columnKindMap = new Map<string, OnyxColumnDef<TData>>();
let qualifierColumn: OnyxQualifierColumn<TData> | null = null;
for (const col of columns) {
const resolvedWidth =
typeof col.width === "function" ? col.width(size) : col.width;
// Clone def to avoid mutating the caller's column definitions
const clonedDef: ColumnDef<TData, any> = {
...col.def,
id: col.id,
size:
"fixed" in resolvedWidth ? resolvedWidth.fixed : resolvedWidth.weight,
};
tanstackColumns.push(clonedDef);
const id = col.id;
columnKindMap.set(id, col);
if ("fixed" in resolvedWidth) {
fixedColumnIds.add(id);
} else {
columnWeights[id] = resolvedWidth.weight;
columnMinWidths[id] = resolvedWidth.minWidth ?? 50;
}
if (col.kind === "qualifier") qualifierColumn = col;
}
return {
tanstackColumns,
widthConfig: { fixedColumnIds, columnWeights, columnMinWidths },
qualifierColumn,
columnKindMap,
};
}
// ---------------------------------------------------------------------------
// DataTable component
// ---------------------------------------------------------------------------
/**
* Config-driven table component that wires together `useDataTable`,
* `useColumnWidths`, and `useDraggableRows` automatically.
*
* Full flexibility via the column definitions from `createTableColumns()`.
*
* @example
* ```tsx
* const tc = createTableColumns<TeamMember>();
* const columns = [
* tc.qualifier({ content: "avatar-user", getInitials: (r) => r.initials }),
* tc.column("name", { header: "Name", weight: 23, minWidth: 120 }),
* tc.column("email", { header: "Email", weight: 28 }),
* tc.actions(),
* ];
*
* <DataTable data={items} columns={columns} footer={{ mode: "selection" }} />
* ```
*/
export default function DataTable<TData>(props: DataTableProps<TData>) {
const {
data,
columns,
pageSize,
initialSorting,
initialColumnVisibility,
draggable,
footer,
size = "regular",
onRowClick,
height,
headerBackground,
} = props;
const effectivePageSize = pageSize ?? (footer ? 10 : data.length);
// 1. Process columns (memoized on columns + size)
const { tanstackColumns, widthConfig, qualifierColumn, columnKindMap } =
useMemo(() => processColumns(columns, size), [columns, size]);
// 2. Call useDataTable
const {
table,
currentPage,
totalPages,
totalItems,
setPage,
pageSize: resolvedPageSize,
selectionState,
selectedCount,
clearSelection,
toggleAllPageRowsSelected,
isAllPageRowsSelected,
} = useDataTable({
data,
columns: tanstackColumns,
pageSize: effectivePageSize,
initialSorting,
initialColumnVisibility,
});
// 3. Call useColumnWidths
const { containerRef, columnWidths, createResizeHandler } = useColumnWidths({
headers: table.getHeaderGroups()[0]?.headers ?? [],
...widthConfig,
});
// 4. Call useDraggableRows (conditional)
const draggableReturn = useDraggableRows({
data,
getRowId: draggable?.getRowId ?? noopGetRowId,
enabled: !!draggable && table.getState().sorting.length === 0,
onReorder: draggable?.onReorder,
});
const hasDraggable = !!draggable;
const rowVariant = hasDraggable ? "table" : "list";
const isSelectable =
qualifierColumn != null && qualifierColumn.selectable !== false;
// ---------------------------------------------------------------------------
// Render
// ---------------------------------------------------------------------------
function renderContent() {
return (
<div>
<div
className="overflow-x-auto"
ref={containerRef}
style={{
...(height != null
? {
maxHeight:
typeof height === "number" ? `${height}px` : height,
overflowY: "auto" as const,
}
: undefined),
...(headerBackground
? ({
"--table-header-bg": headerBackground,
} as React.CSSProperties)
: undefined),
}}
>
<Table>
<TableHeader>
{table.getHeaderGroups().map((headerGroup) => (
<TableRow key={headerGroup.id}>
{headerGroup.headers.map((header, headerIndex) => {
const colDef = columnKindMap.get(header.id);
// Qualifier header
if (colDef?.kind === "qualifier") {
if (qualifierColumn?.header === false) {
return (
<QualifierContainer key={header.id} type="head" />
);
}
return (
<QualifierContainer key={header.id} type="head">
<TableQualifier
content={
qualifierColumn?.headerContentType ?? "simple"
}
selectable={isSelectable}
selected={isSelectable && isAllPageRowsSelected}
onSelectChange={
isSelectable
? (checked) =>
toggleAllPageRowsSelected(checked)
: undefined
}
/>
</QualifierContainer>
);
}
// Actions header
if (colDef?.kind === "actions") {
const actionsDef = colDef as OnyxActionsColumn<TData>;
return (
<ActionsContainer key={header.id} type="head">
{actionsDef.showColumnVisibility !== false && (
<ColumnVisibilityPopover
table={table}
columnVisibility={
table.getState().columnVisibility
}
size={size}
/>
)}
{actionsDef.showSorting !== false && (
<SortingPopover
table={table}
sorting={table.getState().sorting}
size={size}
footerText={actionsDef.sortingFooterText}
/>
)}
</ActionsContainer>
);
}
// Data / Display header
const canSort = header.column.getCanSort();
const sortDir = header.column.getIsSorted();
const nextHeader = headerGroup.headers[headerIndex + 1];
const canResize =
header.column.getCanResize() &&
!!nextHeader &&
!widthConfig.fixedColumnIds.has(nextHeader.id);
const dataCol =
colDef?.kind === "data"
? (colDef as OnyxDataColumn<TData>)
: null;
return (
<TableHead
key={header.id}
width={columnWidths[header.id]}
sorted={
canSort ? toOnyxSortDirection(sortDir) : undefined
}
onSort={
canSort
? () => header.column.toggleSorting()
: undefined
}
icon={dataCol?.icon}
resizable={canResize}
onResizeStart={
canResize
? createResizeHandler(header.id, nextHeader.id)
: undefined
}
>
{flexRender(
header.column.columnDef.header,
header.getContext()
)}
</TableHead>
);
})}
</TableRow>
))}
</TableHeader>
<TableBody
dndSortable={hasDraggable ? draggableReturn : undefined}
renderDragOverlay={
hasDraggable
? (activeId) => {
const row = table
.getRowModel()
.rows.find(
(r) => draggable!.getRowId(r.original) === activeId
);
if (!row) return null;
return <DragOverlayRow row={row} variant={rowVariant} />;
}
: undefined
}
>
{table.getRowModel().rows.map((row) => {
const rowId = hasDraggable
? draggable!.getRowId(row.original)
: undefined;
return (
<TableRow
key={row.id}
variant={rowVariant}
sortableId={rowId}
selected={row.getIsSelected()}
onClick={() => {
if (onRowClick) {
onRowClick(row.original);
} else if (isSelectable) {
row.toggleSelected();
}
}}
>
{row.getVisibleCells().map((cell) => {
const cellColDef = columnKindMap.get(cell.column.id);
// Qualifier cell
if (cellColDef?.kind === "qualifier") {
const qDef = cellColDef as OnyxQualifierColumn<TData>;
return (
<QualifierContainer
key={cell.id}
type="cell"
onClick={(e) => e.stopPropagation()}
>
<TableQualifier
content={qDef.content}
initials={qDef.getInitials?.(row.original)}
icon={qDef.getIcon?.(row.original)}
imageSrc={qDef.getImageSrc?.(row.original)}
selectable={isSelectable}
selected={isSelectable && row.getIsSelected()}
onSelectChange={
isSelectable
? (checked) => {
row.toggleSelected(checked);
}
: undefined
}
/>
</QualifierContainer>
);
}
// Actions cell
if (cellColDef?.kind === "actions") {
return (
<ActionsContainer key={cell.id} type="cell">
{flexRender(
cell.column.columnDef.cell,
cell.getContext()
)}
</ActionsContainer>
);
}
// Data / Display cell
return (
<TableCell key={cell.id}>
{flexRender(
cell.column.columnDef.cell,
cell.getContext()
)}
</TableCell>
);
})}
</TableRow>
);
})}
</TableBody>
</Table>
</div>
{footer && renderFooter(footer)}
</div>
);
}
function renderFooter(footerConfig: DataTableFooterConfig) {
if (footerConfig.mode === "selection") {
return (
<Footer
mode="selection"
multiSelect={footerConfig.multiSelect !== false}
selectionState={selectionState}
selectedCount={selectedCount}
onClear={footerConfig.onClear ?? clearSelection}
onView={footerConfig.onView}
pageSize={resolvedPageSize}
totalItems={totalItems}
currentPage={currentPage}
totalPages={totalPages}
onPageChange={setPage}
/>
);
}
// Summary mode
const rangeStart =
totalItems === 0
? 0
: !isFinite(resolvedPageSize)
? 1
: (currentPage - 1) * resolvedPageSize + 1;
const rangeEnd = !isFinite(resolvedPageSize)
? totalItems
: Math.min(currentPage * resolvedPageSize, totalItems);
return (
<Footer
mode="summary"
rangeStart={rangeStart}
rangeEnd={rangeEnd}
totalItems={totalItems}
currentPage={currentPage}
totalPages={totalPages}
onPageChange={setPage}
/>
);
}
return <TableSizeProvider size={size}>{renderContent()}</TableSizeProvider>;
}

View File

@@ -0,0 +1,260 @@
"use client";
import { cn } from "@/lib/utils";
import { Button } from "@opal/components";
import Text from "@/refresh-components/texts/Text";
import Pagination from "@/refresh-components/table/Pagination";
import { useTableSize } from "@/refresh-components/table/TableSizeContext";
import type { TableSize } from "@/refresh-components/table/TableSizeContext";
import { SvgEye, SvgXCircle } from "@opal/icons";
type SelectionState = "none" | "partial" | "all";
/**
* Footer mode for tables with selectable rows.
* Displays a selection message on the left (with optional view/clear actions)
* and a `count`-type pagination on the right.
*/
interface FooterSelectionModeProps {
mode: "selection";
/** Whether the table supports selecting multiple rows. */
multiSelect: boolean;
/** Current selection state: `"none"`, `"partial"`, or `"all"`. */
selectionState: SelectionState;
/** Number of currently selected items. */
selectedCount: number;
/** If provided, renders a "View" icon button when items are selected. */
onView?: () => void;
/** If provided, renders a "Clear" icon button when items are selected. */
onClear?: () => void;
/** Number of items displayed per page. */
pageSize: number;
/** Total number of items across all pages. */
totalItems: number;
/** The 1-based current page number. */
currentPage: number;
/** Total number of pages. */
totalPages: number;
/** Called when the user navigates to a different page. */
onPageChange: (page: number) => void;
/** Controls overall footer sizing. `"regular"` (default) or `"small"`. */
size?: TableSize;
className?: string;
}
/**
* Footer mode for read-only tables (no row selection).
* Displays "Showing X~Y of Z" on the left and a `list`-type pagination
* on the right.
*/
interface FooterSummaryModeProps {
mode: "summary";
/** First item number in the current page (e.g. `1`). */
rangeStart: number;
/** Last item number in the current page (e.g. `25`). */
rangeEnd: number;
/** Total number of items across all pages. */
totalItems: number;
/** The 1-based current page number. */
currentPage: number;
/** Total number of pages. */
totalPages: number;
/** Called when the user navigates to a different page. */
onPageChange: (page: number) => void;
/** Controls overall footer sizing. `"regular"` (default) or `"small"`. */
size?: TableSize;
className?: string;
}
/**
* Discriminated union of footer modes.
* Use `mode: "selection"` for tables with selectable rows, or
* `mode: "summary"` for read-only tables.
*/
export type FooterProps = FooterSelectionModeProps | FooterSummaryModeProps;
function getSelectionMessage(
state: SelectionState,
multi: boolean,
count: number
): string {
if (state === "none") {
return multi ? "Select items to continue" : "Select an item to continue";
}
if (!multi) return "Item selected";
return `${count} item${count !== 1 ? "s" : ""} selected`;
}
/**
* Table footer combining status information on the left with pagination on the
* right. Use `mode: "selection"` for tables with selectable rows, or
* `mode: "summary"` for read-only tables.
*/
export default function Footer(props: FooterProps) {
const contextSize = useTableSize();
const resolvedSize = props.size ?? contextSize;
const isSmall = resolvedSize === "small";
return (
<div
className={cn(
"table-footer",
"flex w-full items-center justify-between border-t border-border-01",
props.className
)}
data-size={resolvedSize}
>
{/* Left side */}
<div className="flex items-center gap-1 px-1">
{props.mode === "selection" ? (
<SelectionLeft
selectionState={props.selectionState}
multiSelect={props.multiSelect}
selectedCount={props.selectedCount}
onView={props.onView}
onClear={props.onClear}
isSmall={isSmall}
/>
) : (
<SummaryLeft
rangeStart={props.rangeStart}
rangeEnd={props.rangeEnd}
totalItems={props.totalItems}
isSmall={isSmall}
/>
)}
</div>
{/* Right side */}
<div className="flex items-center gap-2 px-1 py-2">
{props.mode === "selection" ? (
<Pagination
type="count"
pageSize={props.pageSize}
totalItems={props.totalItems}
currentPage={props.currentPage}
totalPages={props.totalPages}
onPageChange={props.onPageChange}
showUnits
size={isSmall ? "sm" : "md"}
/>
) : (
<Pagination
type="list"
currentPage={props.currentPage}
totalPages={props.totalPages}
onPageChange={props.onPageChange}
size={isSmall ? "md" : "lg"}
/>
)}
</div>
</div>
);
}
interface SelectionLeftProps {
selectionState: SelectionState;
multiSelect: boolean;
selectedCount: number;
onView?: () => void;
onClear?: () => void;
isSmall: boolean;
}
function SelectionLeft({
selectionState,
multiSelect,
selectedCount,
onView,
onClear,
isSmall,
}: SelectionLeftProps) {
const message = getSelectionMessage(
selectionState,
multiSelect,
selectedCount
);
const hasSelection = selectionState !== "none";
return (
<div className="flex flex-row gap-1 items-center justify-center w-fit flex-shrink-0 h-fit px-1">
{isSmall ? (
<Text
secondaryAction={hasSelection}
secondaryBody={!hasSelection}
text03
>
{message}
</Text>
) : (
<Text mainUiBody={hasSelection} mainUiMuted={!hasSelection} text03>
{message}
</Text>
)}
{hasSelection && (
<div className="flex flex-row items-center w-fit flex-shrink-0 h-fit">
{onView && (
<Button
icon={SvgEye}
onClick={onView}
tooltip="View"
size={isSmall ? "sm" : "md"}
prominence="tertiary"
/>
)}
{onClear && (
<Button
icon={SvgXCircle}
onClick={onClear}
tooltip="Clear selection"
size={isSmall ? "sm" : "md"}
prominence="tertiary"
/>
)}
</div>
)}
</div>
);
}
interface SummaryLeftProps {
rangeStart: number;
rangeEnd: number;
totalItems: number;
isSmall: boolean;
}
function SummaryLeft({
rangeStart,
rangeEnd,
totalItems,
isSmall,
}: SummaryLeftProps) {
return (
<div className="flex flex-row gap-1 items-center w-fit h-fit px-1">
{isSmall ? (
<Text secondaryBody text03>
Showing{" "}
<Text as="span" secondaryMono text03>
{rangeStart}~{rangeEnd}
</Text>{" "}
of{" "}
<Text as="span" secondaryMono text03>
{totalItems}
</Text>
</Text>
) : (
<Text mainUiMuted text03>
Showing{" "}
<Text as="span" mainUiMono text03>
{rangeStart}~{rangeEnd}
</Text>{" "}
of{" "}
<Text as="span" mainUiMono text03>
{totalItems}
</Text>
</Text>
)}
</div>
);
}

View File

@@ -0,0 +1,393 @@
"use client";
import { Button } from "@opal/components";
import Text from "@/refresh-components/texts/Text";
import { cn } from "@/lib/utils";
import { SvgChevronLeft, SvgChevronRight } from "@opal/icons";
type PaginationSize = "lg" | "md" | "sm";
/**
* Minimal page navigation showing `currentPage / totalPages` with prev/next arrows.
* Use when you only need simple forward/backward navigation.
*/
interface SimplePaginationProps {
type: "simple";
/** The 1-based current page number. */
currentPage: number;
/** Total number of pages. */
totalPages: number;
/** Called when the user navigates to a different page. */
onPageChange: (page: number) => void;
/** When `true`, displays the word "pages" after the page indicator. */
showUnits?: boolean;
/** When `false`, hides the page indicator between the prev/next arrows. Defaults to `true`. */
showPageIndicator?: boolean;
/** Controls button and text sizing. Defaults to `"lg"`. */
size?: PaginationSize;
className?: string;
}
/**
* Item-count pagination showing `currentItems of totalItems` with optional page
* controls and a "Go to" button. Use inside table footers that need to communicate
* how many items the user is viewing.
*/
interface CountPaginationProps {
type: "count";
/** Number of items displayed per page. Used to compute the visible range. */
pageSize: number;
/** Total number of items across all pages. */
totalItems: number;
/** The 1-based current page number. */
currentPage: number;
/** Total number of pages. */
totalPages: number;
/** Called when the user navigates to a different page. */
onPageChange: (page: number) => void;
/** When `false`, hides the page number between the prev/next arrows (arrows still visible). Defaults to `true`. */
showPageIndicator?: boolean;
/** When `true`, renders a "Go to" button. Requires `onGoTo`. */
showGoTo?: boolean;
/** Callback invoked when the "Go to" button is clicked. */
onGoTo?: () => void;
/** When `true`, displays the word "items" after the total count. */
showUnits?: boolean;
/** Controls button and text sizing. Defaults to `"lg"`. */
size?: PaginationSize;
className?: string;
}
/**
* Numbered page-list pagination with clickable page buttons and ellipsis
* truncation for large page counts. Does not support `"sm"` size.
*/
interface ListPaginationProps {
type: "list";
/** The 1-based current page number. */
currentPage: number;
/** Total number of pages. */
totalPages: number;
/** Called when the user navigates to a different page. */
onPageChange: (page: number) => void;
/** When `false`, hides the page buttons between the prev/next arrows. Defaults to `true`. */
showPageIndicator?: boolean;
/** Controls button and text sizing. Defaults to `"lg"`. Only `"lg"` and `"md"` are supported. */
size?: Exclude<PaginationSize, "sm">;
className?: string;
}
/**
* Discriminated union of all pagination variants.
* Use the `type` prop to select between `"simple"`, `"count"`, and `"list"`.
*/
export type PaginationProps =
| SimplePaginationProps
| CountPaginationProps
| ListPaginationProps;
function getPageNumbers(currentPage: number, totalPages: number) {
const pages: (number | string)[] = [];
const maxPagesToShow = 7;
if (totalPages <= maxPagesToShow) {
for (let i = 1; i <= totalPages; i++) {
pages.push(i);
}
} else {
pages.push(1);
let startPage = Math.max(2, currentPage - 1);
let endPage = Math.min(totalPages - 1, currentPage + 1);
if (currentPage <= 3) {
endPage = 5;
} else if (currentPage >= totalPages - 2) {
startPage = totalPages - 4;
}
if (startPage > 2) {
if (startPage === 3) {
pages.push(2);
} else {
pages.push("start-ellipsis");
}
}
for (let i = startPage; i <= endPage; i++) {
pages.push(i);
}
if (endPage < totalPages - 1) {
if (endPage === totalPages - 2) {
pages.push(totalPages - 1);
} else {
pages.push("end-ellipsis");
}
}
pages.push(totalPages);
}
return pages;
}
function sizedTextProps(isSmall: boolean, variant: "mono" | "muted") {
if (variant === "mono") {
return isSmall ? { secondaryMono: true } : { mainUiMono: true };
}
return isSmall ? { secondaryBody: true } : { mainUiMuted: true };
}
interface NavButtonsProps {
currentPage: number;
totalPages: number;
onPageChange: (page: number) => void;
size: PaginationSize;
children?: React.ReactNode;
}
function NavButtons({
currentPage,
totalPages,
onPageChange,
size,
children,
}: NavButtonsProps) {
return (
<>
<Button
icon={SvgChevronLeft}
onClick={() => onPageChange(currentPage - 1)}
disabled={currentPage <= 1}
size={size}
prominence="tertiary"
tooltip="Previous page"
/>
{children}
<Button
icon={SvgChevronRight}
onClick={() => onPageChange(currentPage + 1)}
disabled={currentPage >= totalPages}
size={size}
prominence="tertiary"
tooltip="Next page"
/>
</>
);
}
/**
* Table pagination component with three variants: `simple`, `count`, and `list`.
* Pass the `type` prop to select the variant, and the component will render the
* appropriate UI.
*/
export default function Pagination(props: PaginationProps) {
const normalized = { ...props, totalPages: Math.max(1, props.totalPages) };
switch (normalized.type) {
case "simple":
return <SimplePaginationInner {...normalized} />;
case "count":
return <CountPaginationInner {...normalized} />;
case "list":
return <ListPaginationInner {...normalized} />;
}
}
function SimplePaginationInner({
currentPage,
totalPages,
onPageChange,
showUnits,
showPageIndicator = true,
size = "lg",
className,
}: SimplePaginationProps) {
const isSmall = size === "sm";
return (
<div className={cn("flex items-center gap-1", className)}>
<NavButtons
currentPage={currentPage}
totalPages={totalPages}
onPageChange={onPageChange}
size={size}
>
{showPageIndicator && (
<>
<Text {...sizedTextProps(isSmall, "mono")} text03>
{currentPage}
<Text as="span" {...sizedTextProps(isSmall, "muted")} text03>
/
</Text>
{totalPages}
</Text>
{showUnits && (
<Text {...sizedTextProps(isSmall, "muted")} text03>
pages
</Text>
)}
</>
)}
</NavButtons>
</div>
);
}
function CountPaginationInner({
pageSize,
totalItems,
currentPage,
totalPages,
onPageChange,
showPageIndicator = true,
showGoTo,
onGoTo,
showUnits,
size = "lg",
className,
}: CountPaginationProps) {
const isSmall = size === "sm";
const rangeStart = totalItems === 0 ? 0 : (currentPage - 1) * pageSize + 1;
const rangeEnd = Math.min(currentPage * pageSize, totalItems);
const currentItems = `${rangeStart}~${rangeEnd}`;
return (
<div className={cn("flex items-center gap-1", className)}>
<Text {...sizedTextProps(isSmall, "mono")} text03>
{currentItems}
</Text>
<Text {...sizedTextProps(isSmall, "muted")} text03>
of
</Text>
<Text {...sizedTextProps(isSmall, "mono")} text03>
{totalItems}
</Text>
{showUnits && (
<Text {...sizedTextProps(isSmall, "muted")} text03>
items
</Text>
)}
<NavButtons
currentPage={currentPage}
totalPages={totalPages}
onPageChange={onPageChange}
size={size}
>
{showPageIndicator && (
<Text {...sizedTextProps(isSmall, "mono")} text03>
{currentPage}
</Text>
)}
</NavButtons>
{showGoTo && onGoTo && (
<Button onClick={onGoTo} size={size} prominence="tertiary">
Go to
</Button>
)}
</div>
);
}
interface PageNumberIconProps {
className?: string;
pageNum: number;
isActive: boolean;
isLarge: boolean;
}
function PageNumberIcon({
className: iconClassName,
pageNum,
isActive,
isLarge,
}: PageNumberIconProps) {
return (
<div className={cn(iconClassName, "flex flex-col justify-center")}>
{isLarge ? (
<Text
mainUiBody={isActive}
mainUiMuted={!isActive}
text04={isActive}
text02={!isActive}
>
{pageNum}
</Text>
) : (
<Text
secondaryAction={isActive}
secondaryBody={!isActive}
text04={isActive}
text02={!isActive}
>
{pageNum}
</Text>
)}
</div>
);
}
function ListPaginationInner({
currentPage,
totalPages,
onPageChange,
showPageIndicator = true,
size = "lg",
className,
}: ListPaginationProps) {
const pageNumbers = getPageNumbers(currentPage, totalPages);
const isLarge = size === "lg";
return (
<div className={cn("flex items-center gap-1", className)}>
<NavButtons
currentPage={currentPage}
totalPages={totalPages}
onPageChange={onPageChange}
size={size}
>
{showPageIndicator && (
<div className="flex items-center">
{pageNumbers.map((page) => {
if (typeof page === "string") {
return (
<Text
key={page}
mainUiMuted={isLarge}
secondaryBody={!isLarge}
text03
>
...
</Text>
);
}
const pageNum = page as number;
const isActive = pageNum === currentPage;
return (
<Button
key={pageNum}
onClick={() => onPageChange(pageNum)}
size={size}
prominence="tertiary"
transient={isActive}
icon={({ className: iconClassName }) => (
<PageNumberIcon
className={iconClassName}
pageNum={pageNum}
isActive={isActive}
isLarge={isLarge}
/>
)}
/>
);
})}
</div>
)}
</NavButtons>
</div>
);
}

View File

@@ -0,0 +1,317 @@
# DataTable
Config-driven table built on [TanStack Table](https://tanstack.com/table). Handles column sizing (weight-based proportional distribution), drag-and-drop row reordering, pagination, row selection, column visibility, and sorting out of the box.
## Quick Start
```tsx
import DataTable from "@/refresh-components/table/DataTable";
import { createTableColumns } from "@/refresh-components/table/columns";
interface Person {
name: string;
email: string;
role: string;
}
// Define columns at module scope (stable reference, no re-renders)
const tc = createTableColumns<Person>();
const columns = [
tc.qualifier(),
tc.column("name", { header: "Name", weight: 30, minWidth: 120 }),
tc.column("email", { header: "Email", weight: 40, minWidth: 150 }),
tc.column("role", { header: "Role", weight: 30, minWidth: 80 }),
tc.actions(),
];
function PeopleTable({ data }: { data: Person[] }) {
return (
<DataTable
data={data}
columns={columns}
pageSize={10}
footer={{ mode: "selection" }}
/>
);
}
```
## Column Builder API
`createTableColumns<TData>()` returns a typed builder with four methods. Each returns an `OnyxColumnDef<TData>` that you pass to the `columns` prop.
### `tc.qualifier(config?)`
Leading column for avatars, icons, images, or checkboxes.
| Option | Type | Default | Description |
|---|---|---|---|
| `content` | `"simple" \| "icon" \| "image" \| "avatar-icon" \| "avatar-user"` | `"simple"` | Body row content type |
| `headerContentType` | same as `content` | `"simple"` | Header row content type |
| `getInitials` | `(row: TData) => string` | - | Extract initials (for `"avatar-user"`) |
| `getIcon` | `(row: TData) => IconFunctionComponent` | - | Extract icon (for `"icon"` / `"avatar-icon"`) |
| `getImageSrc` | `(row: TData) => string` | - | Extract image src (for `"image"`) |
| `selectable` | `boolean` | `true` | Show selection checkboxes |
| `header` | `boolean` | `true` | Render qualifier content in the header |
Width is fixed: 56px at `"regular"` size, 40px at `"small"`.
```ts
tc.qualifier({
content: "avatar-user",
getInitials: (row) => row.initials,
})
```
### `tc.column(accessor, config)`
Data column with sorting, resizing, and hiding. The `accessor` is a type-safe deep key into `TData`.
| Option | Type | Default | Description |
|---|---|---|---|
| `header` | `string` | **required** | Column header label |
| `cell` | `(value: TValue, row: TData) => ReactNode` | renders value as string | Custom cell renderer |
| `enableSorting` | `boolean` | `true` | Allow sorting |
| `enableResizing` | `boolean` | `true` | Allow column resize |
| `enableHiding` | `boolean` | `true` | Allow hiding via actions popover |
| `icon` | `(sorted: SortDirection) => IconFunctionComponent` | - | Override the sort indicator icon |
| `weight` | `number` | `20` | Proportional width weight |
| `minWidth` | `number` | `50` | Minimum width in pixels |
```ts
tc.column("email", {
header: "Email",
weight: 28,
minWidth: 150,
cell: (value) => <Content sizePreset="main-ui" variant="body" title={value} prominence="muted" />,
})
```
### `tc.displayColumn(config)`
Non-accessor column for custom content (e.g. computed values, action buttons per row).
| Option | Type | Default | Description |
|---|---|---|---|
| `id` | `string` | **required** | Unique column ID |
| `header` | `string` | - | Optional header label |
| `cell` | `(row: TData) => ReactNode` | **required** | Cell renderer |
| `width` | `ColumnWidth` | **required** | `{ weight, minWidth? }` or `{ fixed }` |
| `enableHiding` | `boolean` | `true` | Allow hiding |
```ts
tc.displayColumn({
id: "fullName",
header: "Full Name",
cell: (row) => `${row.firstName} ${row.lastName}`,
width: { weight: 25, minWidth: 100 },
})
```
### `tc.actions(config?)`
Fixed-width column rendered at the trailing edge. Houses column visibility and sorting popovers in the header.
| Option | Type | Default | Description |
|---|---|---|---|
| `showColumnVisibility` | `boolean` | `true` | Show the column visibility popover |
| `showSorting` | `boolean` | `true` | Show the sorting popover |
| `sortingFooterText` | `string` | - | Footer text inside the sorting popover |
Width is fixed: 88px at `"regular"`, 20px at `"small"`.
```ts
tc.actions({
sortingFooterText: "Everyone will see agents in this order.",
})
```
## DataTable Props
`DataTableProps<TData>`:
| Prop | Type | Default | Description |
|---|---|---|---|
| `data` | `TData[]` | **required** | Row data |
| `columns` | `OnyxColumnDef<TData>[]` | **required** | Columns from `createTableColumns()` |
| `pageSize` | `number` | `10` (with footer) or `data.length` (without) | Rows per page. `Infinity` disables pagination |
| `initialSorting` | `SortingState` | `[]` | TanStack sorting state |
| `initialColumnVisibility` | `VisibilityState` | `{}` | Map of column ID to `false` to hide initially |
| `draggable` | `DataTableDraggableConfig<TData>` | - | Enable drag-and-drop (see below) |
| `footer` | `DataTableFooterConfig` | - | Footer mode (see below) |
| `size` | `"regular" \| "small"` | `"regular"` | Table density variant |
| `onRowClick` | `(row: TData) => void` | toggles selection | Called on row click, replaces default selection toggle |
| `height` | `number \| string` | - | Max height for scrollable body (header stays pinned). `300` or `"50vh"` |
| `headerBackground` | `string` | - | CSS color for the sticky header (prevents content showing through) |
## Footer Config
The `footer` prop accepts a discriminated union on `mode`.
### Selection mode
For tables with selectable rows. Shows a selection message + count pagination.
```ts
footer={{
mode: "selection",
multiSelect: true, // default true
onView: () => { ... }, // optional "View" button
onClear: () => { ... }, // optional "Clear" button (falls back to default clearSelection)
}}
```
### Summary mode
For read-only tables. Shows "Showing X~Y of Z" + list pagination.
```ts
footer={{ mode: "summary" }}
```
## Draggable Config
Enable drag-and-drop row reordering. DnD is automatically disabled when column sorting is active.
```ts
<DataTable
data={items}
columns={columns}
draggable={{
getRowId: (row) => row.id,
onReorder: (ids, changedOrders) => {
// ids: new ordered array of all row IDs
// changedOrders: { [id]: newIndex } for rows that moved
setItems(ids.map((id) => items.find((r) => r.id === id)!));
},
}}
/>
```
| Option | Type | Description |
|---|---|---|
| `getRowId` | `(row: TData) => string` | Extract a unique string ID from each row |
| `onReorder` | `(ids: string[], changedOrders: Record<string, number>) => void \| Promise<void>` | Called after a successful reorder |
## Sizing
The `size` prop (`"regular"` or `"small"`) affects:
- Qualifier column width (56px vs 40px)
- Actions column width (88px vs 20px)
- Footer text styles and pagination size
- All child components via `TableSizeContext`
Column widths can be responsive to size using a function:
```ts
// In types.ts, width accepts:
width: ColumnWidth | ((size: TableSize) => ColumnWidth)
// Example (this is what qualifier/actions use internally):
width: (size) => size === "small" ? { fixed: 40 } : { fixed: 56 }
```
### Width system
Data columns use **weight-based proportional distribution**. A column with `weight: 40` gets twice the space of one with `weight: 20`. When the container is narrower than the sum of `minWidth` values, columns clamp to their minimums.
Fixed columns (`{ fixed: N }`) take exactly N pixels and don't participate in proportional distribution.
Resizing uses **splitter semantics**: dragging a column border grows that column and shrinks its neighbor by the same amount, keeping total width constant.
## Advanced Examples
### Scrollable table with pinned header
```tsx
<DataTable
data={allRows}
columns={columns}
height={300}
headerBackground="var(--background-tint-00)"
/>
```
### Hidden columns on load
```tsx
<DataTable
data={data}
columns={columns}
initialColumnVisibility={{ department: false, joinDate: false }}
footer={{ mode: "selection" }}
/>
```
### Icon-based data column
```tsx
const STATUS_ICONS = {
active: SvgCheckCircle,
pending: SvgClock,
inactive: SvgAlertCircle,
} as const;
tc.column("status", {
header: "Status",
weight: 14,
minWidth: 80,
cell: (value) => (
<Content
sizePreset="main-ui"
variant="body"
icon={STATUS_ICONS[value]}
title={value.charAt(0).toUpperCase() + value.slice(1)}
/>
),
})
```
### Non-selectable qualifier with icons
```ts
tc.qualifier({
content: "icon",
getIcon: (row) => row.icon,
selectable: false,
header: false,
})
```
### Small variant in a bordered container
```tsx
<div className="border border-border-01 rounded-lg overflow-hidden">
<DataTable
data={data}
columns={columns}
size="small"
pageSize={10}
footer={{ mode: "selection" }}
/>
</div>
```
### Custom row click handler
```tsx
<DataTable
data={data}
columns={columns}
onRowClick={(row) => router.push(`/users/${row.id}`)}
/>
```
## Source Files
| File | Purpose |
|---|---|
| `DataTable.tsx` | Main component |
| `columns.ts` | `createTableColumns` builder |
| `types.ts` | All TypeScript interfaces |
| `hooks/useDataTable.ts` | TanStack table wrapper hook |
| `hooks/useColumnWidths.ts` | Weight-based width system |
| `hooks/useDraggableRows.ts` | DnD hook (`@dnd-kit`) |
| `Footer.tsx` | Selection / Summary footer modes |
| `TableSizeContext.tsx` | Size context provider |

View File

@@ -0,0 +1,181 @@
"use client";
import { useState } from "react";
import {
type Table,
type ColumnDef,
type RowData,
type SortingState,
} from "@tanstack/react-table";
import { Button } from "@opal/components";
import { SvgArrowUpDown, SvgSortOrder, SvgCheck } from "@opal/icons";
import Popover from "@/refresh-components/Popover";
import Divider from "@/refresh-components/Divider";
import LineItem from "@/refresh-components/buttons/LineItem";
import Text from "@/refresh-components/texts/Text";
// ---------------------------------------------------------------------------
// Popover UI
// ---------------------------------------------------------------------------
interface SortingPopoverProps<TData extends RowData = RowData> {
table: Table<TData>;
sorting: SortingState;
size?: "regular" | "small";
footerText?: string;
ascendingLabel?: string;
descendingLabel?: string;
}
function SortingPopover<TData extends RowData>({
table,
sorting,
size = "regular",
footerText,
ascendingLabel = "Ascending",
descendingLabel = "Descending",
}: SortingPopoverProps<TData>) {
const [open, setOpen] = useState(false);
const sortableColumns = table
.getAllLeafColumns()
.filter((col) => col.getCanSort());
const currentSort = sorting[0] ?? null;
return (
<Popover open={open} onOpenChange={setOpen}>
<Popover.Trigger asChild>
<Button
icon={currentSort === null ? SvgArrowUpDown : SvgSortOrder}
transient={open}
size={size === "small" ? "sm" : "md"}
prominence="internal"
tooltip="Sort"
/>
</Popover.Trigger>
<Popover.Content width="lg" align="end" side="bottom">
<Popover.Menu
footer={
footerText ? (
<div className="px-2 py-1">
<Text secondaryBody text03>
{footerText}
</Text>
</div>
) : undefined
}
>
<Divider showTitle text="Sort by" />
<LineItem
selected={currentSort === null}
emphasized
rightChildren={
currentSort === null ? <SvgCheck size={16} /> : undefined
}
onClick={() => {
table.resetSorting();
}}
>
Manual Ordering
</LineItem>
{sortableColumns.map((column) => {
const isSorted = currentSort?.id === column.id;
const label =
typeof column.columnDef.header === "string"
? column.columnDef.header
: column.id;
return (
<LineItem
key={column.id}
selected={isSorted}
emphasized
rightChildren={isSorted ? <SvgCheck size={16} /> : undefined}
onClick={() => {
if (isSorted) {
table.resetSorting();
return;
}
column.toggleSorting(false);
}}
>
{label}
</LineItem>
);
})}
{currentSort !== null && (
<>
<Divider showTitle text="Sorting Order" />
<LineItem
selected={!currentSort.desc}
emphasized
rightChildren={
!currentSort.desc ? <SvgCheck size={16} /> : undefined
}
onClick={() => {
table.setSorting([{ id: currentSort.id, desc: false }]);
}}
>
{ascendingLabel}
</LineItem>
<LineItem
selected={currentSort.desc}
emphasized
rightChildren={
currentSort.desc ? <SvgCheck size={16} /> : undefined
}
onClick={() => {
table.setSorting([{ id: currentSort.id, desc: true }]);
}}
>
{descendingLabel}
</LineItem>
</>
)}
</Popover.Menu>
</Popover.Content>
</Popover>
);
}
// ---------------------------------------------------------------------------
// Column definition factory
// ---------------------------------------------------------------------------
interface CreateSortingColumnOptions {
size?: "regular" | "small";
footerText?: string;
ascendingLabel?: string;
descendingLabel?: string;
}
function createSortingColumn<TData>(
options?: CreateSortingColumnOptions
): ColumnDef<TData, unknown> {
return {
id: "__sorting",
size: 44,
enableHiding: false,
enableSorting: false,
enableResizing: false,
header: ({ table }) => (
<SortingPopover
table={table}
sorting={table.getState().sorting}
size={options?.size}
footerText={options?.footerText}
ascendingLabel={options?.ascendingLabel}
descendingLabel={options?.descendingLabel}
/>
),
cell: () => null,
};
}
export { SortingPopover, createSortingColumn };

View File

@@ -830,12 +830,8 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
ref={chatInputBarRef}
deepResearchEnabled={deepResearchEnabled}
toggleDeepResearch={toggleDeepResearch}
toggleDocumentSidebar={toggleDocumentSidebar}
filterManager={filterManager}
llmManager={llmManager}
removeDocs={() => setSelectedDocuments([])}
retrievalEnabled={retrievalEnabled}
selectedDocuments={selectedDocuments}
initialMessage={
searchParams?.get(SEARCH_PARAM_NAMES.USER_PROMPT) ||
""

View File

@@ -173,19 +173,21 @@ export function FileCard({
removeFile && doneUploading ? () => removeFile(file.id) : undefined
}
>
<div className="max-w-[12rem]">
<div className="min-w-0 max-w-[12rem]">
<Interactive.Container border heightVariant="fit">
<AttachmentItemLayout
icon={isProcessing ? SimpleLoader : SvgFileText}
title={file.name}
description={
isProcessing
? file.status === UserFileStatus.UPLOADING
? "Uploading..."
: "Processing..."
: typeLabel
}
/>
<div className="[&_.opal-content-md-body]:min-w-0 [&_.opal-content-md-title]:break-all">
<AttachmentItemLayout
icon={isProcessing ? SimpleLoader : SvgFileText}
title={file.name}
description={
isProcessing
? file.status === UserFileStatus.UPLOADING
? "Uploading..."
: "Processing..."
: typeLabel
}
/>
</div>
<Spacer horizontal rem={0.5} />
</Interactive.Container>
</div>

View File

@@ -16,16 +16,18 @@ import { FilterManager, LlmManager, useFederatedConnectors } from "@/lib/hooks";
import usePromptShortcuts from "@/hooks/usePromptShortcuts";
import useFilter from "@/hooks/useFilter";
import useCCPairs from "@/hooks/useCCPairs";
import { OnyxDocument, MinimalOnyxDocument } from "@/lib/search/interfaces";
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
import { ChatState } from "@/app/app/interfaces";
import { useForcedTools } from "@/lib/hooks/useForcedTools";
import { useAppMode } from "@/providers/AppModeProvider";
import useAppFocus from "@/hooks/useAppFocus";
import { getFormattedDateRangeString } from "@/lib/dateUtils";
import { truncateString, cn, isImageFile } from "@/lib/utils";
import { cn, isImageFile } from "@/lib/utils";
import { Disabled } from "@/refresh-components/Disabled";
import { useUser } from "@/providers/UserProvider";
import { SettingsContext } from "@/providers/SettingsProvider";
import {
SettingsContext,
useVectorDbEnabled,
} from "@/providers/SettingsProvider";
import { useProjectsContext } from "@/providers/ProjectsContext";
import { FileCard } from "@/sections/cards/FileCard";
import {
@@ -40,9 +42,6 @@ import {
} from "@/app/app/services/actionUtils";
import {
SvgArrowUp,
SvgCalendar,
SvgFiles,
SvgFileText,
SvgGlobe,
SvgHourglass,
SvgPlus,
@@ -51,64 +50,22 @@ import {
SvgStop,
SvgX,
} from "@opal/icons";
import { Button, OpenButton } from "@opal/components";
import { Button } from "@opal/components";
import Popover from "@/refresh-components/Popover";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import { useQueryController } from "@/providers/QueryControllerProvider";
import { Section } from "@/layouts/general-layouts";
import Spacer from "@/refresh-components/Spacer";
const LINE_HEIGHT = 24;
const MIN_INPUT_HEIGHT = 44;
const MAX_INPUT_HEIGHT = 200;
export interface SourceChipProps {
icon?: React.ReactNode;
title: string;
onRemove?: () => void;
onClick?: () => void;
truncateTitle?: boolean;
}
export function SourceChip({
icon,
title,
onRemove,
onClick,
truncateTitle = true,
}: SourceChipProps) {
return (
<div
onClick={onClick ? onClick : undefined}
className={cn(
"flex-none flex items-center px-1 bg-background-neutral-01 text-xs text-text-04 border border-border-01 rounded-08 box-border gap-x-1 h-6",
onClick && "cursor-pointer"
)}
>
{icon}
{truncateTitle ? truncateString(title, 20) : title}
{onRemove && (
<SvgX
size={12}
className="text-text-01 ml-auto cursor-pointer"
onClick={(e: React.MouseEvent<SVGSVGElement>) => {
e.stopPropagation();
onRemove();
}}
/>
)}
</div>
);
}
export interface AppInputBarHandle {
reset: () => void;
focus: () => void;
}
export interface AppInputBarProps {
removeDocs: () => void;
selectedDocuments: OnyxDocument[];
initialMessage?: string;
stopGenerating: () => void;
onSubmit: (message: string) => void;
@@ -120,10 +77,8 @@ export interface AppInputBarProps {
// agents
selectedAgent: MinimalPersonaSnapshot | undefined;
toggleDocumentSidebar: () => void;
handleFileUpload: (files: File[]) => void;
filterManager: FilterManager;
retrievalEnabled: boolean;
deepResearchEnabled: boolean;
setPresentingDocument?: (document: MinimalOnyxDocument) => void;
toggleDeepResearch: () => void;
@@ -137,18 +92,13 @@ export interface AppInputBarProps {
const AppInputBar = React.memo(
({
retrievalEnabled,
removeDocs,
toggleDocumentSidebar,
filterManager,
selectedDocuments,
initialMessage = "",
stopGenerating,
onSubmit,
chatState,
currentSessionFileTokenCount,
availableContextTokens,
// agents
selectedAgent,
handleFileUpload,
@@ -165,6 +115,9 @@ const AppInputBar = React.memo(
// Internal message state - kept local to avoid parent re-renders on every keystroke
const [message, setMessage] = useState(initialMessage);
const textAreaRef = useRef<HTMLTextAreaElement>(null);
const textAreaWrapperRef = useRef<HTMLDivElement>(null);
const filesWrapperRef = useRef<HTMLDivElement>(null);
const filesContentRef = useRef<HTMLDivElement>(null);
const containerRef = useRef<HTMLDivElement>(null);
const { user } = useUser();
const { isClassifying, classification } = useQueryController();
@@ -178,6 +131,16 @@ const AppInputBar = React.memo(
textAreaRef.current?.focus();
},
}));
// Sync non-empty prop changes to internal state (e.g. NRFPage reads URL params
// after mount). Intentionally skips empty strings — clearing is handled via the
// imperative ref.reset() method, not by passing initialMessage="".
useEffect(() => {
if (initialMessage) {
setMessage(initialMessage);
}
}, [initialMessage]);
const { appMode } = useAppMode();
const appFocus = useAppFocus();
const isSearchMode =
@@ -227,46 +190,39 @@ const AppInputBar = React.memo(
const combinedSettings = useContext(SettingsContext);
// Track previous message to detect when lines might decrease
const prevMessageRef = useRef("");
// Auto-resize textarea based on content
// TODO(@raunakab): Replace this useEffect with CSS `field-sizing: content` once
// Firefox ships it unflagged (currently behind `layout.css.field-sizing.enabled`).
// Auto-resize textarea based on content (chat mode only).
// Reset to min-height first so scrollHeight reflects actual content size,
// then clamp between min and max. This handles both growing and shrinking.
useEffect(() => {
if (isSearchMode) return;
const wrapper = textAreaWrapperRef.current;
const textarea = textAreaRef.current;
if (textarea) {
const prevLineCount = (prevMessageRef.current.match(/\n/g) || [])
.length;
const currLineCount = (message.match(/\n/g) || []).length;
const lineRemoved = currLineCount < prevLineCount;
prevMessageRef.current = message;
if (!wrapper || !textarea) return;
if (message.length === 0) {
textarea.style.height = `${MIN_INPUT_HEIGHT}px`;
return;
} else if (lineRemoved) {
const linesRemoved = prevLineCount - currLineCount;
textarea.style.height = `${Math.max(
MIN_INPUT_HEIGHT,
Math.min(
textarea.scrollHeight - LINE_HEIGHT * linesRemoved,
MAX_INPUT_HEIGHT
)
)}px`;
} else {
textarea.style.height = `${Math.min(
textarea.scrollHeight,
MAX_INPUT_HEIGHT
)}px`;
}
}
wrapper.style.height = `${MIN_INPUT_HEIGHT}px`;
wrapper.style.height = `${Math.min(
Math.max(textarea.scrollHeight, MIN_INPUT_HEIGHT),
MAX_INPUT_HEIGHT
)}px`;
}, [message, isSearchMode]);
// Animate attached files wrapper to its content height so CSS transitions
// can interpolate between concrete pixel values (0px ↔ Npx).
const showFiles = !isSearchMode && currentMessageFiles.length > 0;
useEffect(() => {
if (initialMessage) {
setMessage(initialMessage);
const wrapper = filesWrapperRef.current;
const content = filesContentRef.current;
if (!wrapper || !content) return;
if (showFiles) {
// Measure the inner content's actual height, then add padding (p-1 = 8px total)
const PADDING = 8;
wrapper.style.height = `${content.offsetHeight + PADDING}px`;
} else {
wrapper.style.height = "0px";
}
}, [initialMessage]);
}, [showFiles, currentMessageFiles]);
function handlePaste(event: React.ClipboardEvent) {
const items = event.clipboardData?.items;
@@ -294,8 +250,7 @@ const AppInputBar = React.memo(
);
const { activePromptShortcuts } = usePromptShortcuts();
const vectorDbEnabled =
combinedSettings?.settings.vector_db_enabled !== false;
const vectorDbEnabled = useVectorDbEnabled();
const { ccPairs, isLoading: ccPairsLoading } = useCCPairs(vectorDbEnabled);
const { data: federatedConnectorsData, isLoading: federatedLoading } =
useFederatedConnectors();
@@ -412,7 +367,9 @@ const AppInputBar = React.memo(
combinedSettings?.settings?.deep_research_enabled,
]);
function handleKeyDown(e: React.KeyboardEvent<HTMLTextAreaElement>) {
function handleKeyDownForPromptShortcuts(
e: React.KeyboardEvent<HTMLTextAreaElement>
) {
if (!user?.preferences?.shortcut_enabled || !showPrompts) return;
if (e.key === "Enter") {
@@ -447,6 +404,171 @@ const AppInputBar = React.memo(
}
}
const chatControls = (
<div
{...(isSearchMode ? { inert: true } : {})}
className={cn(
"flex justify-between items-center w-full",
isSearchMode
? "opacity-0 p-0 h-0 overflow-hidden pointer-events-none"
: "opacity-100 p-1 h-[2.75rem] pointer-events-auto",
"transition-all duration-150"
)}
>
{/* Bottom left controls */}
<div className="flex flex-row items-center">
{/* (+) button - always visible */}
<FilePickerPopover
onFileClick={handleFileClick}
onPickRecent={(file: ProjectFile) => {
// Check if file with same ID already exists
if (
!currentMessageFiles.some(
(existingFile) => existingFile.file_id === file.file_id
)
) {
setCurrentMessageFiles((prev) => [...prev, file]);
}
}}
onUnpickRecent={(file: ProjectFile) => {
setCurrentMessageFiles((prev) =>
prev.filter(
(existingFile) => existingFile.file_id !== file.file_id
)
);
}}
handleUploadChange={handleUploadChange}
trigger={(open) => (
<Button
icon={SvgPlusCircle}
tooltip="Attach Files"
transient={open}
disabled={disabled}
prominence="tertiary"
/>
)}
selectedFileIds={currentMessageFiles.map((f) => f.id)}
/>
{/* Controls that load in when data is ready */}
<div
data-testid="actions-container"
className={cn(
"flex flex-row items-center",
controlsLoading && "invisible"
)}
>
{selectedAgent && selectedAgent.tools.length > 0 && (
<ActionsPopover
selectedAgent={selectedAgent}
filterManager={filterManager}
availableSources={memoizedAvailableSources}
disabled={disabled}
/>
)}
{onToggleTabReading ? (
<Button
icon={SvgGlobe}
onClick={onToggleTabReading}
variant="select"
selected={tabReadingEnabled}
foldable={!tabReadingEnabled}
disabled={disabled}
>
{tabReadingEnabled
? currentTabUrl
? (() => {
try {
return new URL(currentTabUrl).hostname;
} catch {
return currentTabUrl;
}
})()
: "Reading tab..."
: "Read this tab"}
</Button>
) : (
showDeepResearch && (
<Button
icon={SvgHourglass}
onClick={toggleDeepResearch}
variant="select"
selected={deepResearchEnabled}
foldable={!deepResearchEnabled}
disabled={disabled}
>
Deep Research
</Button>
)
)}
{selectedAgent &&
forcedToolIds.length > 0 &&
forcedToolIds.map((toolId) => {
const tool = selectedAgent.tools.find(
(tool) => tool.id === toolId
);
if (!tool) {
return null;
}
return (
<Button
key={toolId}
icon={getIconForAction(tool)}
onClick={() => {
setForcedToolIds(
forcedToolIds.filter((id) => id !== toolId)
);
}}
variant="select"
selected
disabled={disabled}
>
{tool.display_name}
</Button>
);
})}
</div>
</div>
{/* Bottom right controls */}
<div className="flex flex-row items-center gap-1">
<div
data-testid="AppInputBar/llm-popover-trigger"
className={cn(controlsLoading && "invisible")}
>
<LLMPopover
llmManager={llmManager}
requiresImageInput={hasImageFiles}
disabled={disabled}
/>
</div>
<Button
id="onyx-chat-input-send-button"
icon={
isClassifying
? SimpleLoader
: chatState === "input"
? SvgArrowUp
: SvgStop
}
disabled={
(chatState === "input" && !message) ||
hasUploadingFiles ||
isClassifying
}
onClick={() => {
if (chatState == "streaming") {
stopGenerating();
} else if (message) {
onSubmit(message);
}
}}
/>
</div>
</div>
);
return (
<Disabled disabled={disabled} allowClick>
<div
@@ -467,8 +589,17 @@ const AppInputBar = React.memo(
)}
>
{/* Attached Files */}
{currentMessageFiles.length > 0 && (
<div className="p-2 rounded-t-16 flex flex-wrap gap-1">
<div
ref={filesWrapperRef}
{...(!showFiles ? { inert: true } : {})}
className={cn(
"transition-all duration-150",
showFiles
? "opacity-100 p-1"
: "opacity-0 p-0 overflow-hidden pointer-events-none"
)}
>
<div ref={filesContentRef} className="flex flex-wrap gap-1">
{currentMessageFiles.map((file) => (
<FileCard
key={file.id}
@@ -480,76 +611,61 @@ const AppInputBar = React.memo(
/>
))}
</div>
)}
</div>
{/* Input area */}
<div
className={cn(
"flex flex-row items-center w-full",
isSearchMode && "p-1"
)}
>
<div className="flex flex-row items-center w-full">
<Popover
open={user?.preferences?.shortcut_enabled && showPrompts}
onOpenChange={setShowPrompts}
>
<Popover.Anchor asChild>
<textarea
onPaste={handlePaste}
onKeyDownCapture={handleKeyDown}
onChange={handleInputChange}
ref={textAreaRef}
id="onyx-chat-input-textarea"
className={cn(
"w-full",
"outline-none",
"bg-transparent",
"resize-none",
"placeholder:text-text-03",
"whitespace-pre-wrap",
"break-word",
"overscroll-contain",
"px-3",
isSearchMode
? "h-[40px] py-2.5 overflow-hidden"
: [
"h-[44px]", // Fixed initial height to prevent flash - useEffect will adjust as needed
"overflow-y-auto",
"pb-2",
"pt-3",
]
)}
autoFocus
style={{ scrollbarWidth: "thin" }}
role="textarea"
aria-multiline
placeholder={
isSearchMode
? "Search connected sources"
: "How can I help you today"
}
value={message}
onKeyDown={(event) => {
if (
event.key === "Enter" &&
!showPrompts &&
!event.shiftKey &&
!(event.nativeEvent as any).isComposing
) {
event.preventDefault();
if (
message &&
!disabled &&
!isClassifying &&
!hasUploadingFiles
) {
onSubmit(message);
}
<div
ref={textAreaWrapperRef}
className="px-3 py-2 flex-1 flex h-[2.75rem]"
>
<textarea
id="onyx-chat-input-textarea"
role="textarea"
ref={textAreaRef}
onPaste={handlePaste}
onKeyDownCapture={handleKeyDownForPromptShortcuts}
onChange={handleInputChange}
className={cn(
"p-[2px] w-full h-full outline-none bg-transparent resize-none placeholder:text-text-03 whitespace-pre-wrap break-words",
"overflow-y-auto"
)}
autoFocus
rows={1}
style={{ scrollbarWidth: "thin" }}
aria-multiline={true}
placeholder={
isSearchMode
? "Search connected sources"
: "How can I help you today?"
}
}}
suppressContentEditableWarning={true}
disabled={disabled}
/>
value={message}
onKeyDown={(event) => {
if (
event.key === "Enter" &&
!showPrompts &&
!event.shiftKey &&
!(event.nativeEvent as any).isComposing
) {
event.preventDefault();
if (
message &&
!disabled &&
!isClassifying &&
!hasUploadingFiles
) {
onSubmit(message);
}
}
}}
suppressContentEditableWarning={true}
disabled={disabled}
/>
</div>
</Popover.Anchor>
<Popover.Content
@@ -616,214 +732,7 @@ const AppInputBar = React.memo(
)}
</div>
{/* Source chips */}
{(selectedDocuments.length > 0 ||
filterManager.timeRange ||
filterManager.selectedDocumentSets.length > 0) && (
<div className="flex gap-x-.5 px-2">
<div className="flex gap-x-1 px-2 overflow-visible overflow-x-scroll items-end miniscroll">
{filterManager.timeRange && (
<SourceChip
truncateTitle={false}
key="time-range"
icon={<SvgCalendar size={12} />}
title={`${getFormattedDateRangeString(
filterManager.timeRange.from,
filterManager.timeRange.to
)}`}
onRemove={() => {
filterManager.setTimeRange(null);
}}
/>
)}
{filterManager.selectedDocumentSets.length > 0 &&
filterManager.selectedDocumentSets.map((docSet, index) => (
<SourceChip
key={`doc-set-${index}`}
icon={<SvgFiles size={16} />}
title={docSet}
onRemove={() => {
filterManager.setSelectedDocumentSets(
filterManager.selectedDocumentSets.filter(
(ds) => ds !== docSet
)
);
}}
/>
))}
{selectedDocuments.length > 0 && (
<SourceChip
key="selected-documents"
onClick={() => {
toggleDocumentSidebar();
}}
icon={<SvgFileText size={16} />}
title={`${selectedDocuments.length} selected`}
onRemove={removeDocs}
/>
)}
</div>
</div>
)}
{!isSearchMode && (
<div className="flex justify-between items-center w-full p-1 min-h-[40px]">
{/* Bottom left controls */}
<div className="flex flex-row items-center">
{/* (+) button - always visible */}
<FilePickerPopover
onFileClick={handleFileClick}
onPickRecent={(file: ProjectFile) => {
// Check if file with same ID already exists
if (
!currentMessageFiles.some(
(existingFile) => existingFile.file_id === file.file_id
)
) {
setCurrentMessageFiles((prev) => [...prev, file]);
}
}}
onUnpickRecent={(file: ProjectFile) => {
setCurrentMessageFiles((prev) =>
prev.filter(
(existingFile) => existingFile.file_id !== file.file_id
)
);
}}
handleUploadChange={handleUploadChange}
trigger={(open) => (
<Button
icon={SvgPlusCircle}
tooltip="Attach Files"
transient={open}
disabled={disabled}
prominence="tertiary"
/>
)}
selectedFileIds={currentMessageFiles.map((f) => f.id)}
/>
{/* Controls that load in when data is ready */}
<div
data-testid="actions-container"
className={cn(
"flex flex-row items-center",
controlsLoading && "invisible"
)}
>
{selectedAgent && selectedAgent.tools.length > 0 && (
<ActionsPopover
selectedAgent={selectedAgent}
filterManager={filterManager}
availableSources={memoizedAvailableSources}
disabled={disabled}
/>
)}
{onToggleTabReading ? (
<Button
icon={SvgGlobe}
onClick={onToggleTabReading}
variant="select"
selected={tabReadingEnabled}
foldable={!tabReadingEnabled}
disabled={disabled}
>
{tabReadingEnabled
? currentTabUrl
? (() => {
try {
return new URL(currentTabUrl).hostname;
} catch {
return currentTabUrl;
}
})()
: "Reading tab..."
: "Read this tab"}
</Button>
) : (
showDeepResearch && (
<Button
icon={SvgHourglass}
onClick={toggleDeepResearch}
variant="select"
selected={deepResearchEnabled}
foldable={!deepResearchEnabled}
disabled={disabled}
>
Deep Research
</Button>
)
)}
{selectedAgent &&
forcedToolIds.length > 0 &&
forcedToolIds.map((toolId) => {
const tool = selectedAgent.tools.find(
(tool) => tool.id === toolId
);
if (!tool) {
return null;
}
return (
<Button
key={toolId}
icon={getIconForAction(tool)}
onClick={() => {
setForcedToolIds(
forcedToolIds.filter((id) => id !== toolId)
);
}}
variant="select"
selected
disabled={disabled}
>
{tool.display_name}
</Button>
);
})}
</div>
</div>
{/* Bottom right controls */}
<div className="flex flex-row items-center gap-1">
{/* LLM popover - loads when ready */}
<div
data-testid="AppInputBar/llm-popover-trigger"
className={cn(controlsLoading && "invisible")}
>
<LLMPopover
llmManager={llmManager}
requiresImageInput={hasImageFiles}
disabled={disabled}
/>
</div>
{/* Submit button */}
<Button
id="onyx-chat-input-send-button"
icon={
isClassifying
? SimpleLoader
: chatState === "input"
? SvgArrowUp
: SvgStop
}
disabled={
(chatState === "input" && !message) ||
hasUploadingFiles ||
isClassifying
}
onClick={() => {
if (chatState == "streaming") {
stopGenerating();
} else if (message) {
onSubmit(message);
}
}}
/>
</div>
</div>
)}
{chatControls}
</div>
</Disabled>
);

View File

@@ -116,8 +116,6 @@ function ViewerOpenApiToolCard({ tool }: { tool: ToolSnapshot }) {
);
}
const EMPTY_DOCS: [] = [];
/**
* Floating ChatInputBar below the AgentViewerModal.
* On submit, navigates to the agent's chat with the message pre-filled.
@@ -137,14 +135,10 @@ function AgentChatInput({ agent, onSubmit }: AgentChatInputProps) {
chatState="input"
filterManager={filterManager}
selectedAgent={agent}
selectedDocuments={EMPTY_DOCS}
removeDocs={() => {}}
stopGenerating={() => {}}
handleFileUpload={() => {}}
toggleDocumentSidebar={() => {}}
currentSessionFileTokenCount={0}
availableContextTokens={Infinity}
retrievalEnabled={false}
deepResearchEnabled={false}
toggleDeepResearch={() => {}}
disabled={false}

View File

@@ -0,0 +1,20 @@
import { test, expect } from "@tests/e2e/fixtures/eeFeatures";
test.describe("EE Feature Redirect", () => {
test("redirects to /chat with toast when EE features are not licensed", async ({
page,
eeEnabled,
}) => {
test.skip(eeEnabled, "Redirect only happens without Enterprise license");
await page.goto("/admin/theme");
await expect(page).toHaveURL(/\/chat/, { timeout: 10_000 });
const toastContainer = page.getByTestId("toast-container");
await expect(toastContainer).toBeVisible({ timeout: 5_000 });
await expect(
toastContainer.getByText(/only accessible with a paid license/i)
).toBeVisible();
});
});

View File

@@ -1,4 +1,4 @@
import { test, expect } from "@playwright/test";
import { test, expect } from "@tests/e2e/fixtures/eeFeatures";
import { loginAs } from "@tests/e2e/utils/auth";
test.describe("Appearance Theme Settings @exclusive", () => {
@@ -12,24 +12,21 @@ test.describe("Appearance Theme Settings @exclusive", () => {
consentPrompt: "I agree to the terms",
};
test.beforeEach(async ({ page }) => {
test.beforeEach(async ({ page, eeEnabled }) => {
test.skip(
!eeEnabled,
"Enterprise license not active — skipping theme tests"
);
// Fresh session — the eeEnabled fixture already logged in to check the
// setting, so clear cookies and re-login for a clean test state.
await page.context().clearCookies();
await loginAs(page, "admin");
// Navigate first so localStorage is accessible (API-based login
// doesn't navigate, leaving the page on about:blank).
await page.goto("/admin/theme");
await page.waitForLoadState("networkidle");
// Skip the entire test when Enterprise features are not licensed.
// The /admin/theme page is gated behind ee_features_enabled and
// renders a license-required message instead of the settings form.
const eeLocked = page.getByText(
"This functionality requires an active Enterprise license."
);
if (await eeLocked.isVisible({ timeout: 1000 }).catch(() => false)) {
test.skip(true, "Enterprise license not active — skipping theme tests");
}
await expect(
page.locator('[data-label="application-name-input"]')
).toBeVisible({ timeout: 10_000 });
// Clear localStorage to ensure consent modal shows
await page.evaluate(() => {

View File

@@ -0,0 +1,43 @@
/**
* Playwright fixture that detects EE (Enterprise Edition) license state.
*
* Usage:
* ```ts
* import { test, expect } from "@tests/e2e/fixtures/eeFeatures";
*
* test("my EE-gated test", async ({ page, eeEnabled }) => {
* test.skip(!eeEnabled, "Requires active Enterprise license");
* // ... rest of test
* });
* ```
*
* The fixture:
* - Authenticates as admin
* - Fetches /api/settings to check ee_features_enabled
* - Provides a boolean to the test BEFORE any navigation happens
*
* This lets tests call test.skip() synchronously at the top, which is the
* correct Playwright pattern — never navigate then decide to skip.
*/
import { test as base, expect } from "@playwright/test";
import { loginAs } from "@tests/e2e/utils/auth";
export const test = base.extend<{
/** Whether EE features are enabled (valid enterprise license). */
eeEnabled: boolean;
}>({
eeEnabled: async ({ page }, use) => {
await loginAs(page, "admin");
const res = await page.request.get("/api/settings");
if (!res.ok()) {
// Fail open — if we can't determine, assume EE is not enabled
await use(false);
return;
}
const settings = await res.json();
await use(settings.ee_features_enabled === true);
},
});
export { expect };