Compare commits

..

6 Commits

175 changed files with 2276 additions and 9114 deletions

View File

@@ -47,8 +47,7 @@ jobs:
done
- name: Publish Helm charts to gh-pages
# NOTE: HEAD of https://github.com/stefanprodan/helm-gh-pages/pull/43
uses: stefanprodan/helm-gh-pages@ad32ad3b8720abfeaac83532fd1e9bdfca5bbe27 # zizmor: ignore[impostor-commit]
uses: stefanprodan/helm-gh-pages@0ad2bb377311d61ac04ad9eb6f252fb68e207260 # ratchet:stefanprodan/helm-gh-pages@v1.7.0
with:
token: ${{ secrets.GITHUB_TOKEN }}
charts_dir: deployment/helm/charts

View File

@@ -35,7 +35,6 @@ jobs:
needs: [provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 5
steps:
- name: Checkout

View File

@@ -183,7 +183,6 @@ jobs:
- cherry-pick-to-latest-release
if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success' && needs.cherry-pick-to-latest-release.result == 'success'
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 10
steps:
- name: Checkout
@@ -233,7 +232,6 @@ jobs:
- cherry-pick-to-latest-release
if: always() && needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && (needs.resolve-cherry-pick-request.result == 'failure' || needs.cherry-pick-to-latest-release.result == 'failure')
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 10
steps:
- name: Checkout

View File

@@ -63,7 +63,7 @@ jobs:
targets: ${{ matrix.target }}
- name: Cache Cargo registry and build
uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # zizmor: ignore[cache-poisoning]
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # zizmor: ignore[cache-poisoning]
with:
path: |
~/.cargo/bin/

View File

@@ -284,7 +284,7 @@ jobs:
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
@@ -626,7 +626,7 @@ jobs:
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}

View File

@@ -56,7 +56,7 @@ jobs:
- name: Cache mypy cache
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: .mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}

View File

@@ -31,7 +31,6 @@ jobs:
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-model-check"
- "extras=ecr-cache"
environment: ci-protected
timeout-minutes: 45
env:

View File

@@ -15,7 +15,6 @@ permissions:
jobs:
Deploy-Preview:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd

View File

@@ -13,6 +13,15 @@ jobs:
permissions:
id-token: write
timeout-minutes: 10
strategy:
matrix:
os-arch:
- { goos: "linux", goarch: "amd64" }
- { goos: "linux", goarch: "arm64" }
- { goos: "windows", goarch: "amd64" }
- { goos: "windows", goarch: "arm64" }
- { goos: "darwin", goarch: "amd64" }
- { goos: "darwin", goarch: "arm64" }
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
@@ -22,11 +31,9 @@ jobs:
enable-cache: false
version: "0.9.9"
- run: |
for goos in linux windows darwin; do
for goarch in amd64 arm64; do
GOOS="$goos" GOARCH="$goarch" uv build --wheel
done
done
GOOS="${{ matrix.os-arch.goos }}" \
GOARCH="${{ matrix.os-arch.goarch }}" \
uv build --wheel
working-directory: cli
- run: uv publish
working-directory: cli

View File

@@ -25,7 +25,6 @@ permissions:
jobs:
Deploy-Storybook:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
@@ -55,7 +54,6 @@ jobs:
needs: Deploy-Storybook
if: always() && needs.Deploy-Storybook.result == 'failure'
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 10
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4

View File

@@ -9,7 +9,6 @@ on:
jobs:
sync-foss:
runs-on: ubuntu-latest
environment: ci-protected
timeout-minutes: 45
permissions:
contents: read

View File

@@ -11,7 +11,6 @@ permissions:
jobs:
create-and-push-tag:
runs-on: ubuntu-slim
environment: ci-protected
timeout-minutes: 45
steps:

View File

@@ -24,16 +24,6 @@ When hardcoding a boolean variable to a constant value, remove the variable enti
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
## Nginx Routing — New Backend Routes
Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
- `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt)
Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`.
## Full vs Lite Deployments
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.

View File

@@ -122,7 +122,7 @@ repos:
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
hooks:
- id: golangci-lint
language_version: "1.26.1"
language_version: "1.26.0"
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
- repo: https://github.com/astral-sh/ruff-pre-commit

View File

@@ -35,7 +35,7 @@ Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep
> [!TIP]
> Run Onyx with one command (or see deployment section below):
> ```
> curl -fsSL https://onyx.app/install_onyx.sh | bash
> curl -fsSL https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.sh > install.sh && chmod +x install.sh && ./install.sh
> ```
****

View File

@@ -474,8 +474,6 @@ def connector_permission_sync_generator_task(
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_connector=True,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(

View File

@@ -8,7 +8,6 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import HierarchyNode
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call
@@ -106,11 +105,9 @@ def _get_slack_document_access(
slack_connector: SlackConnector,
channel_permissions: dict[str, ExternalAccess], # noqa: ARG001
callback: IndexingHeartbeatInterface | None,
indexing_start: SecondsSinceUnixEpoch | None = None,
) -> Generator[DocExternalAccess, None, None]:
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
callback=callback,
start=indexing_start,
callback=callback
)
for doc_metadata_batch in slim_doc_generator:
@@ -183,15 +180,9 @@ def slack_doc_sync(
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.set_credentials_provider(provider)
indexing_start_ts: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)
yield from _get_slack_document_access(
slack_connector=slack_connector,
slack_connector,
channel_permissions=channel_permissions,
callback=callback,
indexing_start=indexing_start_ts,
)

View File

@@ -6,7 +6,6 @@ from onyx.access.models import ElementExternalAccess
from onyx.access.models import ExternalAccess
from onyx.access.models import NodeExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
@@ -41,19 +40,10 @@ def generic_doc_sync(
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
indexing_start: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)
newly_fetched_doc_ids: set[str] = set()
logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(
start=indexing_start,
callback=callback,
):
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
if callback:

View File

@@ -1,8 +1,19 @@
import threading
import time
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.tools.models import ToolCallInfo
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background
# Type alias for search doc deduplication key
# Simple key: just document_id (str)
@@ -148,3 +159,114 @@ class ChatStateContainer:
"""Thread-safe getter for emitted citations (returns a copy)."""
with self._lock:
return self._emitted_citations.copy()
def run_chat_loop_with_state_containers(
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
) -> Generator[Packet, None]:
"""
Explicit wrapper function that runs a function in a background thread
with event streaming capabilities.
The wrapped function should accept emitter as first arg and use it to emit
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
Args:
func: The function to wrap (should accept emitter and state_container as first and second args)
completion_callback: Callback function to call when the function completes
emitter: Emitter instance for sending packets
state_container: ChatStateContainer instance for accumulating state
is_connected: Callable that returns False when stop signal is set
Usage:
packets = run_chat_loop_with_state_containers(
my_func,
completion_callback=completion_callback,
emitter=emitter,
state_container=state_container,
is_connected=check_func,
)
for packet in packets:
# Process packets
pass
"""
def run_with_exception_capture() -> None:
try:
chat_loop_func(emitter, state_container)
except Exception as e:
# If execution fails, emit an exception packet
emitter.emit(
Packet(
placement=Placement(turn_index=0),
obj=PacketException(type="error", exception=e),
)
)
# Run the function in a background thread
thread = run_in_background(run_with_exception_capture)
pkt: Packet | None = None
last_turn_index = 0 # Track the highest turn_index seen for stop packet
last_cancel_check = time.monotonic()
cancel_check_interval = 0.3 # Check for cancellation every 300ms
try:
while True:
# Poll queue with 300ms timeout for natural stop signal checking
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
try:
pkt = emitter.bus.get(timeout=0.3)
except Empty:
if not is_connected():
# Stop signal detected
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = time.monotonic()
continue
if pkt is not None:
# Track the highest turn_index for the stop packet
if pkt.placement and pkt.placement.turn_index > last_turn_index:
last_turn_index = pkt.placement.turn_index
if isinstance(pkt.obj, OverallStop):
yield pkt
break
elif isinstance(pkt.obj, PacketException):
raise pkt.obj.exception
else:
yield pkt
# Check for cancellation periodically even when packets are flowing
# This ensures stop signal is checked during active streaming
current_time = time.monotonic()
if current_time - last_cancel_check >= cancel_check_interval:
if not is_connected():
# Stop signal detected during streaming
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = current_time
finally:
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
# Skip waiting if user disconnected to exit quickly.
if is_connected():
wait_on_background(thread)
try:
completion_callback(state_container)
except Exception as e:
emitter.emit(
Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=PacketException(type="error", exception=e),
)
)

View File

@@ -1,40 +1,19 @@
import threading
from queue import Queue
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
class Emitter:
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
"""Use this inside tools to emit arbitrary UI progress."""
Tags every packet with ``model_index`` and places it on ``merged_queue``
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
Args:
merged_queue: Shared queue owned by ``_run_models``.
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
drain_done: Optional event set by ``_run_models`` when the drain loop
exits early (e.g. HTTP disconnect). When set, ``emit`` returns
immediately so worker threads can exit fast.
"""
def __init__(
self,
merged_queue: Queue[tuple[int, Packet | Exception | object]],
model_idx: int = 0,
drain_done: threading.Event | None = None,
) -> None:
self._model_idx = model_idx
self._merged_queue = merged_queue
self._drain_done = drain_done
def __init__(self, bus: Queue):
self.bus = bus
def emit(self, packet: Packet) -> None:
if self._drain_done is not None and self._drain_done.is_set():
return
base = packet.placement or Placement(turn_index=0)
tagged = Packet(
placement=base.model_copy(update={"model_index": self._model_idx}),
obj=packet.obj,
)
self._merged_queue.put((self._model_idx, tagged))
self.bus.put(packet) # Thread-safe
def get_default_emitter() -> Emitter:
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
return emitter

File diff suppressed because it is too large Load Diff

View File

@@ -44,31 +44,6 @@ SEND_USER_METADATA_TO_LLM_PROVIDER = (
# User Facing Features Configs
#####
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
# Hard ceiling for the admin-configurable file upload size (in MB).
# Self-hosted customers can raise or lower this via the environment variable.
_raw_max_upload_size_mb = int(os.environ.get("MAX_ALLOWED_UPLOAD_SIZE_MB", "250"))
if _raw_max_upload_size_mb < 0:
logger.warning(
"MAX_ALLOWED_UPLOAD_SIZE_MB=%d is negative; falling back to 250",
_raw_max_upload_size_mb,
)
_raw_max_upload_size_mb = 250
MAX_ALLOWED_UPLOAD_SIZE_MB = _raw_max_upload_size_mb
# Default fallback for the per-user file upload size limit (in MB) when no
# admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at
# runtime so this never silently exceeds the hard ceiling.
_raw_default_upload_size_mb = int(
os.environ.get("DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", "100")
)
if _raw_default_upload_size_mb < 0:
logger.warning(
"DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=%d is negative; falling back to 100",
_raw_default_upload_size_mb,
)
_raw_default_upload_size_mb = 100
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _raw_default_upload_size_mb
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
) # 1 day
@@ -86,6 +61,17 @@ CACHE_BACKEND = CacheBackendType(
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
)
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
# Defaults to 100k tokens (or 10M when vector DB is disabled).
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
FILE_TOKEN_COUNT_THRESHOLD = int(
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
)
# Maximum upload size for a single user file (chat/projects) in MB.
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
# If set to true, will show extra/uncommon connectors in the "Other" category
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
@@ -805,10 +791,6 @@ MINI_CHUNK_SIZE = 150
# This is the number of regular chunks per large chunk
LARGE_CHUNK_RATIO = 4
# The maximum number of chunks that can be held for 1 document processing batch
# The purpose of this is to set an upper bound on memory usage
MAX_CHUNKS_PER_DOC_BATCH = int(os.environ.get("MAX_CHUNKS_PER_DOC_BATCH") or 1000)
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"

View File

@@ -890,8 +890,8 @@ class ConfluenceConnector(
def _retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
@@ -915,8 +915,8 @@ class ConfluenceConnector(
self.confluence_client, doc_id, restrictions, ancestors
) or space_level_access_info.get(page_space_key)
# Query pages (with optional time filtering for indexing_start)
page_query = self._construct_page_cql_query(start, end)
# Query pages
page_query = self.base_cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
@@ -950,9 +950,7 @@ class ConfluenceConnector(
# Query attachments for each page
page_hierarchy_node_yielded = False
attachment_query = self._construct_attachment_query(
_get_page_id(page), start, end
)
attachment_query = self._construct_attachment_query(_get_page_id(page))
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_query,
expand=restrictions_expand,

View File

@@ -1765,11 +1765,7 @@ class SharepointConnector(
checkpoint.current_drive_delta_next_link = None
checkpoint.seen_document_ids.clear()
def _fetch_slim_documents_from_sharepoint(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateSlimDocumentOutput:
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
site_descriptors = self._filter_excluded_sites(
self.site_descriptors or self.fetch_sites()
)
@@ -1790,9 +1786,7 @@ class SharepointConnector(
# Process site documents if flag is True
if self.include_site_documents:
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
site_descriptor=site_descriptor,
start=start,
end=end,
site_descriptor=site_descriptor
):
if self._is_driveitem_excluded(driveitem):
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
@@ -1847,9 +1841,7 @@ class SharepointConnector(
# Process site pages if flag is True
if self.include_site_pages:
site_pages = self._fetch_site_pages(
site_descriptor, start=start, end=end
)
site_pages = self._fetch_site_pages(site_descriptor)
for site_page in site_pages:
logger.debug(
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
@@ -2573,22 +2565,12 @@ class SharepointConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
) -> GenerateSlimDocumentOutput:
start_dt = (
datetime.fromtimestamp(start, tz=timezone.utc)
if start is not None
else None
)
end_dt = (
datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
)
yield from self._fetch_slim_documents_from_sharepoint(
start=start_dt,
end=end_dt,
)
yield from self._fetch_slim_documents_from_sharepoint()
if __name__ == "__main__":

View File

@@ -516,8 +516,6 @@ def _get_all_doc_ids(
] = default_msg_filter,
callback: IndexingHeartbeatInterface | None = None,
workspace_url: str | None = None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
"""
Get all document ids in the workspace, channel by channel
@@ -548,8 +546,6 @@ def _get_all_doc_ids(
client=client,
channel=channel,
callback=callback,
oldest=str(start) if start else None, # 0.0 -> None intentionally
latest=str(end) if end is not None else None,
)
for message_batch in channel_message_batches:
@@ -851,8 +847,8 @@ class SlackConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
if self.client is None:
@@ -865,8 +861,6 @@ class SlackConnector(
msg_filter_func=self.msg_filter_func,
callback=callback,
workspace_url=self._workspace_url,
start=start,
end=end,
)
def _load_from_checkpoint(

View File

@@ -8,7 +8,6 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
@@ -132,32 +131,47 @@ def get_chat_sessions_by_user(
if before is not None:
stmt = stmt.where(ChatSession.time_updated < before)
if limit:
stmt = stmt.limit(limit)
if project_id is not None:
stmt = stmt.where(ChatSession.project_id == project_id)
elif only_non_project_chats:
stmt = stmt.where(ChatSession.project_id.is_(None))
if not include_failed_chats:
non_system_message_exists_subq = (
exists()
.where(ChatMessage.chat_session_id == ChatSession.id)
.where(ChatMessage.message_type != MessageType.SYSTEM)
.correlate(ChatSession)
)
# Leeway for newly created chats that don't have messages yet
time = datetime.now(timezone.utc) - timedelta(minutes=5)
recently_created = ChatSession.time_created >= time
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
# When filtering out failed chats, we apply the limit in Python after
# filtering rather than in SQL, since the post-filter may remove rows.
if limit and include_failed_chats:
stmt = stmt.limit(limit)
result = db_session.execute(stmt)
chat_sessions = result.scalars().all()
chat_sessions = list(result.scalars().all())
return list(chat_sessions)
if not include_failed_chats and chat_sessions:
# Filter out "failed" sessions (those with only SYSTEM messages)
# using a separate efficient query instead of a correlated EXISTS
# subquery, which causes full sequential scans of chat_message.
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
if session_ids:
valid_session_ids_stmt = (
select(ChatMessage.chat_session_id)
.where(ChatMessage.chat_session_id.in_(session_ids))
.where(ChatMessage.message_type != MessageType.SYSTEM)
.distinct()
)
valid_session_ids = set(
db_session.execute(valid_session_ids_stmt).scalars().all()
)
chat_sessions = [
cs
for cs in chat_sessions
if cs.time_created >= leeway or cs.id in valid_session_ids
]
if limit:
chat_sessions = chat_sessions[:limit]
return chat_sessions
def delete_orphaned_search_docs(db_session: Session) -> None:
@@ -617,92 +631,6 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
db_session: Session,
chat_session_id: UUID,
parent_message_id: int,
model_display_names: list[str],
) -> list[ChatMessage]:
"""Reserve N assistant message placeholders for multi-model parallel streaming.
All messages share the same parent (the user message). The parent's
latest_child_message_id points to the LAST reserved message so that the
default history-chain walker picks it up.
"""
reserved: list[ChatMessage] = []
for display_name in model_display_names:
msg = ChatMessage(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
latest_child_message_id=None,
message="Response was terminated prior to completion, try regenerating.",
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
message_type=MessageType.ASSISTANT,
model_display_name=display_name,
)
db_session.add(msg)
reserved.append(msg)
# Flush to assign IDs without committing yet
db_session.flush()
# Point parent's latest_child to the last reserved message
parent = (
db_session.query(ChatMessage)
.filter(ChatMessage.id == parent_message_id)
.first()
)
if parent:
parent.latest_child_message_id = reserved[-1].id
db_session.commit()
return reserved
def set_preferred_response(
db_session: Session,
user_message_id: int,
preferred_assistant_message_id: int,
) -> None:
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
Also advances ``latest_child_message_id`` so the preferred response becomes
the active branch for any subsequent messages in the conversation.
Args:
db_session: Active database session.
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
preferred response is being set.
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
Raises:
ValueError: If either message is not found, if ``user_message_id`` does not
refer to a USER message, or if the assistant message is not a direct child
of the user message.
"""
user_msg = db_session.get(ChatMessage, user_message_id)
if user_msg is None:
raise ValueError(f"User message {user_message_id} not found")
if user_msg.message_type != MessageType.USER:
raise ValueError(f"Message {user_message_id} is not a user message")
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
if assistant_msg is None:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} not found"
)
if assistant_msg.parent_message_id != user_message_id:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} is not a child "
f"of user message {user_message_id}"
)
user_msg.preferred_response_id = preferred_assistant_message_id
user_msg.latest_child_message_id = preferred_assistant_message_id
db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -925,8 +853,6 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -5,7 +5,6 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
out against a nonexistent Vespa/OpenSearch instance.
"""
from collections.abc import Iterable
from typing import Any
from onyx.context.search.models import IndexFilters
@@ -67,7 +66,7 @@ class DisabledDocumentIndex(DocumentIndex):
# ------------------------------------------------------------------
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
index_batch_params: IndexBatchParams, # noqa: ARG002
) -> set[DocumentInsertionRecord]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)

View File

@@ -1,5 +1,4 @@
import abc
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@@ -207,7 +206,7 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[DocumentInsertionRecord]:
"""
@@ -227,8 +226,8 @@ class Indexable(abc.ABC):
it is done automatically outside of this code.
Parameters:
- chunks: Document chunks with all of the information needed for
indexing to the document index.
- chunks: Document chunks with all of the information needed for indexing to the document
index.
- tenant_id: The tenant id of the user whose chunks are being indexed
- large_chunks_enabled: Whether large chunks are enabled

View File

@@ -1,5 +1,4 @@
import abc
from collections.abc import Iterable
from typing import Self
from pydantic import BaseModel
@@ -210,10 +209,10 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes an iterable of document chunks into the document index.
"""Indexes a list of document chunks into the document index.
This is often a batch operation including chunks from multiple
documents.

View File

@@ -1,12 +1,11 @@
import json
from collections.abc import Iterable
from collections import defaultdict
from typing import Any
import httpx
from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -352,7 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
@@ -648,10 +647,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata, # noqa: ARG002
) -> list[DocumentInsertionRecord]:
"""Indexes an iterable of document chunks into the document index.
"""Indexes a list of document chunks into the document index.
Groups chunks by document ID and for each document, deletes existing
chunks and indexes the new chunks in bulk.
@@ -674,34 +673,29 @@ class OpenSearchDocumentIndex(DocumentIndex):
document is newly indexed or had already existed and was just
updated.
"""
total_chunks = sum(
cc.new_chunk_cnt
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
# Group chunks by document ID.
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
list
)
for chunk in chunks:
doc_id_to_chunks[chunk.source_document.id].append(chunk)
logger.debug(
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
f"documents for index {self._index_name}."
)
document_indexing_results: list[DocumentInsertionRecord] = []
deleted_doc_ids: set[str] = set()
# Buffer chunks per document as they arrive from the iterable.
# When the document ID changes flush the buffered chunks.
current_doc_id: str | None = None
current_chunks: list[DocMetadataAwareIndexChunk] = []
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
assert len(doc_chunks) > 0, "doc_chunks is empty"
# Try to index per-document.
for _, chunks in doc_id_to_chunks.items():
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
# Since we are doing this in batches, an error occurring midway
# can result in a state where chunks are deleted and not all the
# new chunks have been indexed.
# Do this before deleting existing chunks to reduce the amount of
# time the document index has no content for a given document, and
# to reduce the chance of entering a state where we delete chunks,
# then some error happens, and never successfully index new chunks.
chunk_batch: list[DocumentChunk] = [
_convert_onyx_chunk_to_opensearch_document(chunk)
for chunk in doc_chunks
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
]
onyx_document: Document = doc_chunks[0].source_document
onyx_document: Document = chunks[0].source_document
# First delete the doc's chunks from the index. This is so that
# there are no dangling chunks in the index, in the event that the
# new document's content contains fewer chunks than the previous
@@ -710,43 +704,22 @@ class OpenSearchDocumentIndex(DocumentIndex):
# if the chunk count has actually decreased. This assumes that
# overlapping chunks are perfectly overwritten. If we can't
# guarantee that then we need the code as-is.
if onyx_document.id not in deleted_doc_ids:
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
deleted_doc_ids.add(onyx_document.id)
# If we see that chunks were deleted we assume the doc already
# existed. We record the result before bulk_index_documents
# runs. If indexing raises, this entire result list is discarded
# by the caller's retry logic, so early recording is safe.
document_indexing_results.append(
DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
)
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
# If we see that chunks were deleted we assume the doc already
# existed.
document_insertion_record = DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
for chunk in chunks:
doc_id = chunk.source_document.id
if doc_id != current_doc_id:
if current_chunks:
_flush_chunks(current_chunks)
current_doc_id = doc_id
current_chunks = [chunk]
elif len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH:
_flush_chunks(current_chunks)
current_chunks = [chunk]
else:
current_chunks.append(chunk)
if current_chunks:
_flush_chunks(current_chunks)
document_indexing_results.append(document_insertion_record)
return document_indexing_results

View File

@@ -6,7 +6,6 @@ import re
import time
import urllib
import zipfile
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
@@ -462,7 +461,7 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""

View File

@@ -1,8 +1,6 @@
import concurrent.futures
import logging
import random
from collections.abc import Generator
from collections.abc import Iterable
from typing import Any
from uuid import UUID
@@ -10,7 +8,6 @@ import httpx
from pydantic import BaseModel
from retry import retry
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
from onyx.configs.app_configs import RERANK_COUNT
from onyx.configs.chat_configs import DOC_TIME_DECAY
@@ -321,7 +318,7 @@ class VespaDocumentIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
@@ -341,31 +338,22 @@ class VespaDocumentIndex(DocumentIndex):
# Vespa has restrictions on valid characters, yet document IDs come from
# external w.r.t. this class. We need to sanitize them.
#
# Instead of materializing all cleaned chunks upfront, we stream them
# through a generator that cleans IDs and builds the original-ID mapping
# incrementally as chunks flow into Vespa.
def _clean_and_track(
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
id_map: dict[str, str],
seen_ids: set[str],
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
"""Cleans chunk IDs and builds the original-ID mapping
incrementally as chunks flow through, avoiding a separate
materialization pass."""
for chunk in chunks_iter:
original_id = chunk.source_document.id
cleaned = clean_chunk_id_copy(chunk)
cleaned_id = cleaned.source_document.id
# Needed so the final DocumentInsertionRecord returned can have
# the original document ID. cleaned_chunks might not contain IDs
# exactly as callers supplied them.
id_map[cleaned_id] = original_id
seen_ids.add(cleaned_id)
yield cleaned
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
clean_chunk_id_copy(chunk) for chunk in chunks
]
assert len(cleaned_chunks) == len(
chunks
), "Bug: Cleaned chunks and input chunks have different lengths."
new_document_id_to_original_document_id: dict[str, str] = {}
all_cleaned_doc_ids: set[str] = set()
# Needed so the final DocumentInsertionRecord returned can have the
# original document ID. cleaned_chunks might not contain IDs exactly as
# callers supplied them.
new_document_id_to_original_document_id: dict[str, str] = dict()
for i, cleaned_chunk in enumerate(cleaned_chunks):
old_chunk = chunks[i]
new_document_id_to_original_document_id[
cleaned_chunk.source_document.id
] = old_chunk.source_document.id
existing_docs: set[str] = set()
@@ -421,16 +409,8 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
# Insert new Vespa documents, streaming through the cleaning
# pipeline so chunks are never fully materialized.
cleaned_chunks = _clean_and_track(
chunks,
new_document_id_to_original_document_id,
all_cleaned_doc_ids,
)
for chunk_batch in batch_generator(
cleaned_chunks, min(BATCH_SIZE, MAX_CHUNKS_PER_DOC_BATCH)
):
# Insert new Vespa documents.
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
batch_index_vespa_chunks(
chunks=chunk_batch,
index_name=self._index_name,
@@ -439,6 +419,10 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
all_cleaned_doc_ids: set[str] = {
chunk.source_document.id for chunk in cleaned_chunks
}
return [
DocumentInsertionRecord(
document_id=new_document_id_to_original_document_id[cleaned_doc_id],

View File

@@ -44,7 +44,6 @@ KNOWN_OPENPYXL_BUGS = [
"Value must be either numerical or a string containing a wildcard",
"File contains no valid workbook part",
"Unable to read workbook: could not read stylesheet from None",
"Colors must be aRGB hex values",
]

View File

@@ -19,8 +19,7 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
@@ -86,21 +85,14 @@ class DocumentIndexingBatchAdapter:
) as transaction:
yield transaction
def prepare_enrichment(
def build_metadata_aware_chunks(
self,
context: DocumentBatchPrepareContext,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
chunks: list[DocAwareChunk],
) -> "DocumentChunkEnricher":
"""Do all DB lookups once and return a per-chunk enricher."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
"""Enrich chunks with access, document sets, boosts, token counts, and hierarchy."""
no_access = DocumentAccess.build(
user_emails=[],
@@ -110,30 +102,67 @@ class DocumentIndexingBatchAdapter:
is_public=False,
)
return DocumentChunkEnricher(
doc_id_to_access_info=get_access_for_documents(
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_access_info = get_access_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
doc_id_to_document_set = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=updatable_ids, db_session=self.db_session
),
doc_id_to_document_set={
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=updatable_ids, db_session=self.db_session
)
},
doc_id_to_ancestor_ids=self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
),
id_to_boost_map=context.id_to_boost_map,
doc_id_to_previous_chunk_cnt={
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
},
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
no_access=no_access,
tenant_id=tenant_id,
)
}
doc_id_to_previous_chunk_cnt: dict[str, int] = {
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
db_session=self.db_session,
)
}
doc_id_to_new_chunk_cnt: dict[str, int] = {
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks_with_embeddings:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
# Get ancestor hierarchy node IDs for each document
doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
context.updatable_docs, tenant_id
)
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
document_sets=set(
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
context.id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in context.id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
ancestor_hierarchy_node_ids=doc_id_to_ancestor_ids[
chunk.source_document.id
],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
user_file_id_to_raw_text={},
user_file_id_to_token_count={},
)
def _get_ancestor_ids_for_documents(
@@ -174,7 +203,7 @@ class DocumentIndexingBatchAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
enrichment: ChunkEnrichmentContext,
result: BuildMetadataAwareChunksResult,
) -> None:
"""Finalize DB updates, store plaintext, and mark docs as indexed."""
updatable_ids = [doc.id for doc in context.updatable_docs]
@@ -198,7 +227,7 @@ class DocumentIndexingBatchAdapter:
update_docs_chunk_count__no_commit(
document_ids=updatable_ids,
doc_id_to_chunk_count=enrichment.doc_id_to_new_chunk_cnt,
doc_id_to_chunk_count=result.doc_id_to_new_chunk_cnt,
db_session=self.db_session,
)
@@ -220,52 +249,3 @@ class DocumentIndexingBatchAdapter:
)
self.db_session.commit()
class DocumentChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of connector documents."""
def __init__(
self,
doc_id_to_access_info: dict[str, DocumentAccess],
doc_id_to_document_set: dict[str, list[str]],
doc_id_to_ancestor_ids: dict[str, list[int]],
id_to_boost_map: dict[str, int],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._doc_id_to_access_info = doc_id_to_access_info
self._doc_id_to_document_set = doc_id_to_document_set
self._doc_id_to_ancestor_ids = doc_id_to_ancestor_ids
self._id_to_boost_map = id_to_boost_map
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._doc_id_to_access_info.get(
chunk.source_document.id, self._no_access
),
document_sets=set(
self._doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
self._id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in self._id_to_boost_map
else DEFAULT_BOOST
),
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
ancestor_hierarchy_node_ids=self._doc_id_to_ancestor_ids[
chunk.source_document.id
],
)

View File

@@ -1,9 +1,6 @@
from __future__ import annotations
import contextlib
import datetime
import time
from collections import defaultdict
from collections.abc import Generator
from uuid import UUID
@@ -27,13 +24,11 @@ from onyx.db.user_file import fetch_persona_ids_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.llm.factory import get_default_llm
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
@@ -106,20 +101,13 @@ class UserFileIndexingAdapter:
f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for user files: {[doc.id for doc in documents]}"
)
def prepare_enrichment(
def build_metadata_aware_chunks(
self,
context: DocumentBatchPrepareContext,
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
chunks: list[DocAwareChunk],
) -> UserFileChunkEnricher:
"""Do all DB lookups and pre-compute file metadata from chunks."""
updatable_ids = [doc.id for doc in context.updatable_docs]
doc_id_to_new_chunk_cnt: dict[str, int] = defaultdict(int)
content_by_file: dict[str, list[str]] = defaultdict(list)
for chunk in chunks:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
content_by_file[chunk.source_document.id].append(chunk.content)
context: DocumentBatchPrepareContext,
) -> BuildMetadataAwareChunksResult:
no_access = DocumentAccess.build(
user_emails=[],
@@ -129,6 +117,7 @@ class UserFileIndexingAdapter:
is_public=False,
)
updatable_ids = [doc.id for doc in context.updatable_docs]
user_file_id_to_project_ids = fetch_user_project_ids_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
@@ -149,6 +138,17 @@ class UserFileIndexingAdapter:
)
}
user_file_id_to_new_chunk_cnt: dict[str, int] = {
user_file_id: len(
[
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
)
for user_file_id in updatable_ids
}
# Initialize tokenizer used for token count calculation
try:
llm = get_default_llm()
@@ -163,30 +163,46 @@ class UserFileIndexingAdapter:
user_file_id_to_raw_text: dict[str, str] = {}
user_file_id_to_token_count: dict[str, int | None] = {}
for user_file_id in updatable_ids:
contents = content_by_file.get(user_file_id)
if contents:
combined_content = " ".join(contents)
user_file_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == user_file_id
]
if user_file_chunks:
combined_content = " ".join(
[chunk.content for chunk in user_file_chunks]
)
user_file_id_to_raw_text[str(user_file_id)] = combined_content
token_count: int = (
count_tokens(combined_content, llm_tokenizer)
if llm_tokenizer
else 0
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
)
user_file_id_to_token_count[str(user_file_id)] = token_count
else:
user_file_id_to_raw_text[str(user_file_id)] = ""
user_file_id_to_token_count[str(user_file_id)] = None
return UserFileChunkEnricher(
user_file_id_to_access=user_file_id_to_access,
user_file_id_to_project_ids=user_file_id_to_project_ids,
user_file_id_to_persona_ids=user_file_id_to_persona_ids,
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=user_file_id_to_access.get(chunk.source_document.id, no_access),
document_sets=set(),
user_project=user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
boost=DEFAULT_BOOST,
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
)
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
return BuildMetadataAwareChunksResult(
chunks=access_aware_chunks,
doc_id_to_previous_chunk_cnt=user_file_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
doc_id_to_new_chunk_cnt=user_file_id_to_new_chunk_cnt,
user_file_id_to_raw_text=user_file_id_to_raw_text,
user_file_id_to_token_count=user_file_id_to_token_count,
no_access=no_access,
tenant_id=tenant_id,
)
def _notify_assistant_owners_if_files_ready(
@@ -230,9 +246,8 @@ class UserFileIndexingAdapter:
context: DocumentBatchPrepareContext,
updatable_chunk_data: list[UpdatableChunkData], # noqa: ARG002
filtered_documents: list[Document], # noqa: ARG002
enrichment: ChunkEnrichmentContext,
result: BuildMetadataAwareChunksResult,
) -> None:
assert isinstance(enrichment, UserFileChunkEnricher)
user_file_ids = [doc.id for doc in context.updatable_docs]
user_files = (
@@ -248,10 +263,8 @@ class UserFileIndexingAdapter:
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)
user_file.chunk_count = enrichment.doc_id_to_new_chunk_cnt.get(
str(user_file.id), 0
)
user_file.token_count = enrichment.user_file_id_to_token_count[
user_file.chunk_count = result.doc_id_to_new_chunk_cnt[str(user_file.id)]
user_file.token_count = result.user_file_id_to_token_count[
str(user_file.id)
]
@@ -263,54 +276,8 @@ class UserFileIndexingAdapter:
# Store the plaintext in the file store for faster retrieval
# NOTE: this creates its own session to avoid committing the overall
# transaction.
for user_file_id, raw_text in enrichment.user_file_id_to_raw_text.items():
for user_file_id, raw_text in result.user_file_id_to_raw_text.items():
store_user_file_plaintext(
user_file_id=UUID(user_file_id),
plaintext_content=raw_text,
)
class UserFileChunkEnricher:
"""Pre-computed metadata for per-chunk enrichment of user-uploaded files."""
def __init__(
self,
user_file_id_to_access: dict[str, DocumentAccess],
user_file_id_to_project_ids: dict[str, list[int]],
user_file_id_to_persona_ids: dict[str, list[int]],
doc_id_to_previous_chunk_cnt: dict[str, int],
doc_id_to_new_chunk_cnt: dict[str, int],
user_file_id_to_raw_text: dict[str, str],
user_file_id_to_token_count: dict[str, int | None],
no_access: DocumentAccess,
tenant_id: str,
) -> None:
self._user_file_id_to_access = user_file_id_to_access
self._user_file_id_to_project_ids = user_file_id_to_project_ids
self._user_file_id_to_persona_ids = user_file_id_to_persona_ids
self._no_access = no_access
self._tenant_id = tenant_id
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
self.user_file_id_to_raw_text = user_file_id_to_raw_text
self.user_file_id_to_token_count = user_file_id_to_token_count
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk:
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=self._user_file_id_to_access.get(
chunk.source_document.id, self._no_access
),
document_sets=set(),
user_project=self._user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
personas=self._user_file_id_to_persona_ids.get(
chunk.source_document.id, []
),
boost=DEFAULT_BOOST,
tenant_id=self._tenant_id,
aggregated_chunk_boost_factor=score,
)

View File

@@ -1,89 +0,0 @@
import pickle
import shutil
import tempfile
from collections.abc import Iterator
from pathlib import Path
from onyx.indexing.models import IndexChunk
class ChunkBatchStore:
"""Manages serialization of embedded chunks to a temporary directory.
Owns the temp directory lifetime and provides save/load/stream/scrub
operations.
Use as a context manager to ensure cleanup::
with ChunkBatchStore() as store:
store.save(chunks, batch_idx=0)
for chunk in store.stream():
...
"""
_EXT = ".pkl"
def __init__(self) -> None:
self._tmpdir: Path | None = None
# -- context manager -----------------------------------------------------
def __enter__(self) -> "ChunkBatchStore":
self._tmpdir = Path(tempfile.mkdtemp(prefix="onyx_embeddings_"))
return self
def __exit__(self, *_exc: object) -> None:
if self._tmpdir is not None:
shutil.rmtree(self._tmpdir, ignore_errors=True)
self._tmpdir = None
@property
def _dir(self) -> Path:
assert self._tmpdir is not None, "ChunkBatchStore used outside context manager"
return self._tmpdir
# -- storage primitives --------------------------------------------------
def save(self, chunks: list[IndexChunk], batch_idx: int) -> None:
"""Serialize a batch of embedded chunks to disk."""
with open(self._dir / f"batch_{batch_idx}{self._EXT}", "wb") as f:
pickle.dump(chunks, f)
def _load(self, batch_file: Path) -> list[IndexChunk]:
"""Deserialize a batch of embedded chunks from a file."""
with open(batch_file, "rb") as f:
return pickle.load(f)
def _batch_files(self) -> list[Path]:
"""Return batch files sorted by numeric index."""
return sorted(
self._dir.glob(f"batch_*{self._EXT}"),
key=lambda p: int(p.stem.removeprefix("batch_")),
)
# -- higher-level operations ---------------------------------------------
def stream(self) -> Iterator[IndexChunk]:
"""Yield all chunks across all batch files.
Each call returns a fresh generator, so the data can be iterated
multiple times (e.g. once per document index).
"""
for batch_file in self._batch_files():
yield from self._load(batch_file)
def scrub_failed_docs(self, failed_doc_ids: set[str]) -> None:
"""Remove chunks belonging to *failed_doc_ids* from all batch files.
When a document fails embedding in batch N, earlier batches may
already contain successfully embedded chunks for that document.
This ensures the output is all-or-nothing per document.
"""
for batch_file in self._batch_files():
batch_chunks = self._load(batch_file)
cleaned = [
c for c in batch_chunks if c.source_document.id not in failed_doc_ids
]
if len(cleaned) != len(batch_chunks):
with open(batch_file, "wb") as f:
pickle.dump(cleaned, f)

View File

@@ -1,8 +1,5 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Protocol
from pydantic import BaseModel
@@ -12,7 +9,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
@@ -47,12 +43,10 @@ from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_store.file_store import get_default_file_store
from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
@@ -69,7 +63,6 @@ from onyx.natural_language_processing.utils import tokenizer_trim_middle
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -98,20 +91,6 @@ class IndexingPipelineResult(BaseModel):
failures: list[ConnectorFailure]
@classmethod
def empty(cls, total_docs: int) -> "IndexingPipelineResult":
return cls(
new_docs=0,
total_docs=total_docs,
total_chunks=0,
failures=[],
)
class ChunkEmbeddingResult(BaseModel):
successful_chunk_ids: list[tuple[int, str]] # (chunk_id, document_id)
connector_failures: list[ConnectorFailure]
class IndexingPipelineProtocol(Protocol):
def __call__(
@@ -160,110 +139,6 @@ def _upsert_documents_in_db(
)
def _get_failed_doc_ids(failures: list[ConnectorFailure]) -> set[str]:
"""Extract document IDs from a list of connector failures."""
return {f.failed_document.document_id for f in failures if f.failed_document}
def _embed_chunks_to_store(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
tenant_id: str,
request_id: str | None,
store: ChunkBatchStore,
) -> ChunkEmbeddingResult:
"""Embed chunks in batches, spilling each batch to *store*.
If a document fails embedding in any batch, its chunks are excluded from
all batches (including earlier ones already written) so that the output
is all-or-nothing per document.
"""
successful_chunk_ids: list[tuple[int, str]] = []
all_embedding_failures: list[ConnectorFailure] = []
# Track failed doc IDs across all batches so that a failure in batch N
# causes chunks for that doc to be skipped in batch N+1 and stripped
# from earlier batches.
all_failed_doc_ids: set[str] = set()
for batch_idx, chunk_batch in enumerate(
batch_generator(chunks, MAX_CHUNKS_PER_DOC_BATCH)
):
# Skip chunks belonging to documents that failed in earlier batches.
chunk_batch = [
c for c in chunk_batch if c.source_document.id not in all_failed_doc_ids
]
if not chunk_batch:
continue
logger.debug(f"Embedding batch {batch_idx}: {len(chunk_batch)} chunks")
chunks_with_embeddings, embedding_failures = embed_chunks_with_failure_handling(
chunks=chunk_batch,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
)
all_embedding_failures.extend(embedding_failures)
all_failed_doc_ids.update(_get_failed_doc_ids(embedding_failures))
# Only keep successfully embedded chunks for non-failed docs.
chunks_with_embeddings = [
c
for c in chunks_with_embeddings
if c.source_document.id not in all_failed_doc_ids
]
successful_chunk_ids.extend(
(c.chunk_id, c.source_document.id) for c in chunks_with_embeddings
)
store.save(chunks_with_embeddings, batch_idx)
del chunks_with_embeddings
# Scrub earlier batches for docs that failed in later batches.
if all_failed_doc_ids:
store.scrub_failed_docs(all_failed_doc_ids)
successful_chunk_ids = [
(chunk_id, doc_id)
for chunk_id, doc_id in successful_chunk_ids
if doc_id not in all_failed_doc_ids
]
return ChunkEmbeddingResult(
successful_chunk_ids=successful_chunk_ids,
connector_failures=all_embedding_failures,
)
@contextmanager
def embed_and_stream(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
tenant_id: str,
request_id: str | None,
) -> Generator[tuple[ChunkEmbeddingResult, ChunkBatchStore], None, None]:
"""Embed chunks to disk and yield a ``(result, store)`` pair.
The store owns the temp directory — files are cleaned up when the context
manager exits.
Usage::
with embed_and_stream(chunks, embedder, tenant_id, req_id) as (result, store):
for chunk in store.stream():
...
"""
with ChunkBatchStore() as store:
result = _embed_chunks_to_store(
chunks=chunks,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
store=store,
)
yield result, store
def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
@@ -762,29 +637,6 @@ def add_contextual_summaries(
return chunks
def _verify_indexing_completeness(
insertion_records: list[DocumentInsertionRecord],
write_failures: list[ConnectorFailure],
embedding_failed_doc_ids: set[str],
updatable_ids: list[str],
document_index_name: str,
) -> None:
"""Verify that every updatable document was either indexed or reported as failed."""
all_returned_doc_ids = (
{r.document_id for r in insertion_records}
| {f.failed_document.document_id for f in write_failures if f.failed_document}
| embedding_failed_doc_ids
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
f"This should never happen. "
f"This occured for document index {document_index_name}"
)
@log_function_time(debug_only=True)
def index_doc_batch(
*,
@@ -820,7 +672,12 @@ def index_doc_batch(
filtered_documents = filter_fnc(document_batch)
context = adapter.prepare(filtered_documents, ignore_time_skip)
if not context:
return IndexingPipelineResult.empty(len(filtered_documents))
return IndexingPipelineResult(
new_docs=0,
total_docs=len(filtered_documents),
total_chunks=0,
failures=[],
)
# Convert documents to IndexingDocument objects with processed section
# logger.debug("Processing image sections")
@@ -859,99 +716,117 @@ def index_doc_batch(
)
logger.debug("Starting embedding")
with embed_and_stream(chunks, embedder, tenant_id, request_id) as (
embedding_result,
chunk_store,
):
updatable_ids = [doc.id for doc in context.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk_id,
document_id=document_id,
boost_score=1.0,
)
for chunk_id, document_id in embedding_result.successful_chunk_ids
]
chunks_with_embeddings, embedding_failures = (
embed_chunks_with_failure_handling(
chunks=chunks,
embedder=embedder,
tenant_id=tenant_id,
request_id=request_id,
)
if chunks
else ([], [])
)
embedding_failed_doc_ids = _get_failed_doc_ids(
embedding_result.connector_failures
chunk_content_scores = [1.0] * len(chunks_with_embeddings)
updatable_ids = [doc.id for doc in context.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk.chunk_id,
document_id=chunk.source_document.id,
boost_score=score,
)
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
]
# Acquires a lock on the documents so that no other process can modify them
# NOTE: don't need to acquire till here, since this is when the actual race condition
# with Vespa can occur.
with adapter.lock_context(context.updatable_docs):
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
# we still write data here for the immediate and most likely correct sync, but
# to resolve this, an update of the last modified field at the end of this loop
# always triggers a final metadata sync via the celery queue
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=chunks_with_embeddings,
chunk_content_scores=chunk_content_scores,
tenant_id=tenant_id,
context=context,
)
# Filter to only successfully embedded chunks so
# doc_id_to_new_chunk_cnt reflects what's actually written to Vespa.
embedded_chunks = [
c for c in chunks if c.source_document.id not in embedding_failed_doc_ids
]
short_descriptor_list = [chunk.to_short_descriptor() for chunk in result.chunks]
short_descriptor_log = str(short_descriptor_list)[:1024]
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
# Acquires a lock on the documents so that no other process can modify
# them. Not needed until here, since this is when the actual race
# condition with vector db can occur.
with adapter.lock_context(context.updatable_docs):
enricher = adapter.prepare_enrichment(
context=context,
tenant_id=tenant_id,
chunks=embedded_chunks,
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
for document_index in document_indices:
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
# in this set
(
insertion_records,
vector_db_write_failures,
) = write_chunks_to_vector_db_with_backoff(
document_index=document_index,
chunks=result.chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
),
)
index_batch_params = IndexBatchParams(
doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
tenant_id=tenant_id,
large_chunks_enabled=chunker.enable_large_chunks,
)
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = (
None
)
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = (
None
)
for document_index in document_indices:
def _enriched_stream() -> Iterator[DocMetadataAwareIndexChunk]:
for chunk in chunk_store.stream():
yield enricher.enrich_chunk(chunk, 1.0)
insertion_records, write_failures = (
write_chunks_to_vector_db_with_backoff(
document_index=document_index,
make_chunks=_enriched_stream,
index_batch_params=index_batch_params,
)
all_returned_doc_ids: set[str] = (
{record.document_id for record in insertion_records}
.union(
{
record.failed_document.document_id
for record in vector_db_write_failures
if record.failed_document
}
)
_verify_indexing_completeness(
insertion_records=insertion_records,
write_failures=write_failures,
embedding_failed_doc_ids=embedding_failed_doc_ids,
updatable_ids=updatable_ids,
document_index_name=document_index.__class__.__name__,
.union(
{
record.failed_document.document_id
for record in embedding_failures
if record.failed_document
}
)
# We treat the first document index we got as the primary one used
# for reporting the state of indexing.
if primary_doc_idx_insertion_records is None:
primary_doc_idx_insertion_records = insertion_records
if primary_doc_idx_vector_db_write_failures is None:
primary_doc_idx_vector_db_write_failures = write_failures
adapter.post_index(
context=context,
updatable_chunk_data=updatable_chunk_data,
filtered_documents=filtered_documents,
enrichment=enricher,
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
"This should never happen."
f"This occured for document index {document_index.__class__.__name__}"
)
# We treat the first document index we got as the primary one used
# for reporting the state of indexing.
if primary_doc_idx_insertion_records is None:
primary_doc_idx_insertion_records = insertion_records
if primary_doc_idx_vector_db_write_failures is None:
primary_doc_idx_vector_db_write_failures = vector_db_write_failures
adapter.post_index(
context=context,
updatable_chunk_data=updatable_chunk_data,
filtered_documents=filtered_documents,
result=result,
)
assert primary_doc_idx_insertion_records is not None
assert primary_doc_idx_vector_db_write_failures is not None
return IndexingPipelineResult(
new_docs=sum(
1 for r in primary_doc_idx_insertion_records if not r.already_existed
new_docs=len(
[r for r in primary_doc_idx_insertion_records if not r.already_existed]
),
total_docs=len(filtered_documents),
total_chunks=len(embedding_result.successful_chunk_ids),
failures=primary_doc_idx_vector_db_write_failures
+ embedding_result.connector_failures,
total_chunks=len(chunks_with_embeddings),
failures=primary_doc_idx_vector_db_write_failures + embedding_failures,
)

View File

@@ -235,16 +235,12 @@ class UpdatableChunkData(BaseModel):
boost_score: float
class ChunkEnrichmentContext(Protocol):
"""Returned by prepare_enrichment. Holds pre-computed metadata lookups
and provides per-chunk enrichment."""
class BuildMetadataAwareChunksResult(BaseModel):
chunks: list[DocMetadataAwareIndexChunk]
doc_id_to_previous_chunk_cnt: dict[str, int]
doc_id_to_new_chunk_cnt: dict[str, int]
def enrich_chunk(
self, chunk: IndexChunk, score: float
) -> DocMetadataAwareIndexChunk: ...
user_file_id_to_raw_text: dict[str, str]
user_file_id_to_token_count: dict[str, int | None]
class IndexingBatchAdapter(Protocol):
@@ -258,24 +254,18 @@ class IndexingBatchAdapter(Protocol):
) -> Generator[TransactionalContext, None, None]:
"""Provide a transaction/row-lock context for critical updates."""
def prepare_enrichment(
def build_metadata_aware_chunks(
self,
context: "DocumentBatchPrepareContext",
chunks_with_embeddings: list[IndexChunk],
chunk_content_scores: list[float],
tenant_id: str,
chunks: list[DocAwareChunk],
) -> ChunkEnrichmentContext:
"""Prepare per-chunk enrichment data (access, document sets, boost, etc.).
Precondition: ``chunks`` have already been through the embedding step
(i.e. they are ``IndexChunk`` instances with populated embeddings,
passed here as the base ``DocAwareChunk`` type).
"""
...
context: "DocumentBatchPrepareContext",
) -> BuildMetadataAwareChunksResult: ...
def post_index(
self,
context: "DocumentBatchPrepareContext",
updatable_chunk_data: list[UpdatableChunkData],
filtered_documents: list[Document],
enrichment: ChunkEnrichmentContext,
result: BuildMetadataAwareChunksResult,
) -> None: ...

View File

@@ -1,9 +1,6 @@
import time
from collections.abc import Callable
from collections.abc import Iterable
from collections import defaultdict
from http import HTTPStatus
from itertools import chain
from itertools import groupby
import httpx
@@ -31,22 +28,22 @@ def _log_insufficient_storage_error(e: Exception) -> None:
def write_chunks_to_vector_db_with_backoff(
document_index: DocumentIndex,
make_chunks: Callable[[], Iterable[DocMetadataAwareIndexChunk]],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
goes document by document to isolate the failure(s).
IMPORTANT: must pass in whole documents at a time not individual chunks, since the
vector DB interface assumes that all chunks for a single document are present. The
chunks must also be in contiguous batches
vector DB interface assumes that all chunks for a single document are present.
"""
# first try to write the chunks to the vector db
try:
return (
list(
document_index.index(
chunks=make_chunks(),
chunks=chunks,
index_batch_params=index_batch_params,
)
),
@@ -63,23 +60,14 @@ def write_chunks_to_vector_db_with_backoff(
# wait a couple seconds just to give the vector db a chance to recover
time.sleep(2)
# try writing each doc one by one
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
for chunk in chunks:
chunks_for_docs[chunk.source_document.id].append(chunk)
insertion_records: list[DocumentInsertionRecord] = []
failures: list[ConnectorFailure] = []
def key(chunk: DocMetadataAwareIndexChunk) -> str:
return chunk.source_document.id
seen_doc_ids: set[str] = set()
for doc_id, chunks_for_doc in groupby(make_chunks(), key=key):
if doc_id in seen_doc_ids:
raise RuntimeError(
f"Doc chunks are not arriving in order. Current doc_id={doc_id}, seen_doc_ids={list(seen_doc_ids)}"
)
seen_doc_ids.add(doc_id)
first_chunk = next(chunks_for_doc)
chunks_for_doc = chain([first_chunk], chunks_for_doc)
for doc_id, chunks_for_doc in chunks_for_docs.items():
try:
insertion_records.extend(
document_index.index(
@@ -99,7 +87,9 @@ def write_chunks_to_vector_db_with_backoff(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=first_chunk.get_link(),
document_link=(
chunks_for_doc[0].get_link() if chunks_for_doc else None
),
),
failure_message=str(e),
exception=e,

View File

@@ -8,24 +8,6 @@ from pydantic import BaseModel
class LLMOverride(BaseModel):
"""Per-request LLM settings that override persona defaults.
All fields are optional — only the fields that differ from the persona's
configured LLM need to be supplied. Used both over the wire (API requests)
and for multi-model comparison, where one override is supplied per model.
Attributes:
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
When ``None``, the persona's default provider is used.
model_version: Specific model version string (e.g. ``"gpt-4o"``).
When ``None``, the persona's default model is used.
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
persona's default temperature is used.
display_name: Human-readable label shown in the UI for this model,
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
when not set.
"""
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None

View File

@@ -175,32 +175,6 @@ def get_tokenizer(
return _check_tokenizer_cache(provider_type, model_name)
# Max characters per encode() call.
_ENCODE_CHUNK_SIZE = 500_000
def count_tokens(
text: str,
tokenizer: BaseTokenizer,
token_limit: int | None = None,
) -> int:
"""Count tokens, chunking the input to avoid tiktoken stack overflow.
If token_limit is provided and the text is large enough to require
multiple chunks (> 500k chars), stops early once the count exceeds it.
When early-exiting, the returned value exceeds token_limit but may be
less than the true full token count.
"""
if len(text) <= _ENCODE_CHUNK_SIZE:
return len(tokenizer.encode(text))
total = 0
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
total += len(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
if token_limit is not None and total > token_limit:
return total # Already over — skip remaining chunks
return total
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:

View File

@@ -3844,9 +3844,9 @@
}
},
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
"version": "5.0.5",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.5.tgz",
"integrity": "sha512-VZznLgtwhn+Mact9tfiwx64fA9erHH/MCXEUfB/0bX/6Fz6ny5EGTXYltMocqg4xFAQZtnO3DHWWXi8RiuN7cQ==",
"version": "5.0.3",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
"license": "MIT",
"dependencies": {
"balanced-match": "^4.0.2"
@@ -4224,9 +4224,9 @@
}
},
"node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": {
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz",
"integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==",
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -5007,9 +5007,9 @@
}
},
"node_modules/brace-expansion": {
"version": "1.1.13",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz",
"integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==",
"version": "1.1.12",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
"dev": true,
"license": "MIT",
"dependencies": {

View File

@@ -123,8 +123,9 @@ def _validate_endpoint(
(not reachable — indicates the api_key is invalid).
Timeout handling:
- Any httpx.TimeoutException (ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout) →
timeout (operator should consider increasing timeout_seconds).
- ConnectTimeout: TCP handshake never completed → cannot_connect.
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
(operator should consider increasing timeout_seconds).
- All other exceptions → cannot_connect.
"""
_check_ssrf_safety(endpoint_url)

View File

@@ -9,15 +9,20 @@ from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.password_validation import is_file_password_protected
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SKIP_USERFILE_THRESHOLD
from shared_configs.configs import SKIP_USERFILE_THRESHOLD_TENANT_LIST
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -156,8 +161,8 @@ def categorize_uploaded_files(
document formats (.pdf, .docx, …) and falls back to a text-detection
heuristic for unknown extensions (.py, .js, .rs, …).
- Uses default tokenizer to compute token length.
- If token length exceeds the admin-configured threshold, reject file.
- If extension unsupported or text cannot be extracted, reject file.
- If token length > threshold, reject file (unless threshold skip is enabled).
- If text cannot be extracted, reject file.
- Otherwise marked as acceptable.
"""
@@ -168,33 +173,36 @@ def categorize_uploaded_files(
provider_type = default_model.llm_provider.provider if default_model else None
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
# Derive limits from admin-configurable settings.
# For upload size: load_settings() resolves 0/None to a positive default.
# For token threshold: 0 means "no limit" (converted to None below).
settings = load_settings()
max_upload_size_mb = (
settings.user_file_max_upload_size_mb
) # always positive after load_settings()
max_upload_size_bytes = (
max_upload_size_mb * 1024 * 1024 if max_upload_size_mb else None
)
token_threshold_k = settings.file_token_count_threshold_k
token_threshold = (
token_threshold_k * 1000 if token_threshold_k else None
) # 0 → None = no limit
# Check if threshold checks should be skipped
skip_threshold = False
# Check global skip flag (works for both single-tenant and multi-tenant)
if SKIP_USERFILE_THRESHOLD:
skip_threshold = True
logger.info("Skipping userfile threshold check (global setting)")
# Check tenant-specific skip list (only applicable in multi-tenant)
elif MULTI_TENANT and SKIP_USERFILE_THRESHOLD_TENANT_LIST:
try:
current_tenant_id = get_current_tenant_id()
skip_threshold = current_tenant_id in SKIP_USERFILE_THRESHOLD_TENANT_LIST
if skip_threshold:
logger.info(
f"Skipping userfile threshold check for tenant: {current_tenant_id}"
)
except RuntimeError as e:
logger.warning(f"Failed to get current tenant ID: {str(e)}")
for upload in files:
try:
filename = get_safe_filename(upload)
# Size limit is a hard safety cap.
if max_upload_size_bytes is not None and is_upload_too_large(
upload, max_upload_size_bytes
):
# Size limit is a hard safety cap and is enforced even when token
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {max_upload_size_mb} MB file size limit",
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
)
)
continue
@@ -216,11 +224,11 @@ def categorize_uploaded_files(
)
continue
if token_threshold is not None and token_count > token_threshold:
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {token_threshold_k}K token limit",
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
)
)
else:
@@ -261,14 +269,12 @@ def categorize_uploaded_files(
)
continue
token_count = count_tokens(
text_content, tokenizer, token_limit=token_threshold
)
if token_threshold is not None and token_count > token_threshold:
token_count = len(tokenizer.encode(text_content))
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {token_threshold_k}K token limit",
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
)
)
else:

View File

@@ -28,7 +28,6 @@ from onyx.chat.chat_utils import extract_headers
from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_multi_model_stream
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
@@ -47,7 +46,6 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -62,8 +60,6 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -85,7 +81,6 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -575,46 +570,6 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
# Narrowed here; is_multi_model already checked llm_overrides is not None
llm_overrides = chat_message_req.llm_overrides or []
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in handle_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
if is_multi_model and not chat_message_req.stream:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -705,30 +660,6 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
try:
# Ownership check: get_chat_message raises ValueError if the message
# doesn't belong to this user, preventing cross-user mutation.
get_chat_message(
chat_message_id=request_body.user_message_id,
user_id=user.id if user else None,
db_session=db_session,
)
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -2,25 +2,11 @@ from pydantic import BaseModel
class Placement(BaseModel):
"""Coordinates that identify where a streaming packet belongs in the UI.
The frontend uses these fields to route each packet to the correct turn,
tool tab, agent sub-turn, and (in multi-model mode) response column.
Attributes:
turn_index: Monotonically increasing index of the iterative reasoning block
(e.g. tool call round) within this chat message. Lower values happened first.
tab_index: Disambiguates parallel tool calls within the same turn so each
tool's output can be displayed in its own tab.
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
top-level packets; an integer for tool-within-tool output.
model_index: Which model this packet belongs to. ``0`` for single-model
responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
for pre-LLM setup packets (e.g. message ID info) that are yielded
before any Emitter runs.
"""
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
turn_index: int
# For parallel tool calls to preserve order of execution
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
model_index: int | None = None

View File

@@ -9,9 +9,7 @@ from onyx import __version__ as onyx_version
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import is_user_admin
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.configs.constants import NotificationType
from onyx.db.engine.sql_engine import get_session
@@ -19,16 +17,10 @@ from onyx.db.models import User
from onyx.db.notification import dismiss_all_notifications
from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Notification
from onyx.server.settings.models import Settings
from onyx.server.settings.models import UserSettings
@@ -49,15 +41,6 @@ basic_router = APIRouter(prefix="/settings")
def admin_put_settings(
settings: Settings, _: User = Depends(current_admin_user)
) -> None:
if (
settings.user_file_max_upload_size_mb is not None
and settings.user_file_max_upload_size_mb > 0
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
):
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"File upload size limit cannot exceed {MAX_ALLOWED_UPLOAD_SIZE_MB} MB",
)
store_settings(settings)
@@ -100,16 +83,6 @@ def fetch_settings(
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=HOOKS_AVAILABLE,
version=onyx_version,
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
default_user_file_max_upload_size_mb=min(
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
MAX_ALLOWED_UPLOAD_SIZE_MB,
),
default_file_token_count_threshold_k=(
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
),
)

View File

@@ -2,19 +2,12 @@ from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import Field
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.constants import NotificationType
from onyx.configs.constants import QueryHistoryType
from onyx.db.models import Notification as NotificationDBModel
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB = 200
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB = 10000
class PageType(str, Enum):
CHAT = "chat"
@@ -85,12 +78,7 @@ class Settings(BaseModel):
# User Knowledge settings
user_knowledge_enabled: bool | None = True
user_file_max_upload_size_mb: int | None = Field(
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
)
file_token_count_threshold_k: int | None = Field(
default=None, ge=0 # thousands of tokens; None = context-aware default
)
user_file_max_upload_size_mb: int | None = None
# Connector settings
show_extra_connectors: bool | None = True
@@ -120,14 +108,3 @@ class UserSettings(Settings):
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None
# Hard ceiling for user_file_max_upload_size_mb, derived from env var.
max_allowed_upload_size_mb: int = MAX_ALLOWED_UPLOAD_SIZE_MB
# Factory defaults so the frontend can show a "restore default" button.
default_user_file_max_upload_size_mb: int = DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
default_file_token_count_threshold_k: int = Field(
default_factory=lambda: (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
)

View File

@@ -1,19 +1,13 @@
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -57,36 +51,9 @@ def load_settings() -> Settings:
if DISABLE_USER_KNOWLEDGE:
settings.user_knowledge_enabled = False
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
# Resolve context-aware defaults for token threshold.
# None = admin hasn't set a value yet → use context-aware default.
# 0 = admin explicitly chose "no limit" → preserve as-is.
if settings.file_token_count_threshold_k is None:
settings.file_token_count_threshold_k = (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
# Upload size: 0 and None are treated as "unset" (not "no limit") →
# fall back to min(configured default, hard ceiling).
if not settings.user_file_max_upload_size_mb:
settings.user_file_max_upload_size_mb = min(
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
MAX_ALLOWED_UPLOAD_SIZE_MB,
)
# Clamp to env ceiling so stale KV values are capped even if the
# operator lowered MAX_ALLOWED_UPLOAD_SIZE_MB after a higher value
# was already saved (api.py only guards new writes).
if (
settings.user_file_max_upload_size_mb > 0
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
):
settings.user_file_max_upload_size_mb = MAX_ALLOWED_UPLOAD_SIZE_MB
return settings

View File

@@ -1,4 +1,3 @@
import queue
import time
from collections.abc import Callable
from typing import Any
@@ -709,6 +708,7 @@ def run_research_agent_calls(
if __name__ == "__main__":
from queue import Queue
from uuid import uuid4
from onyx.chat.chat_state import ChatStateContainer
@@ -744,8 +744,8 @@ if __name__ == "__main__":
if user is None:
raise ValueError("No users found in database. Please create a user first.")
emitter_queue: queue.Queue = queue.Queue()
emitter = Emitter(merged_queue=emitter_queue)
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
state_container = ChatStateContainer()
tool_dict = construct_tools(
@@ -792,4 +792,4 @@ if __name__ == "__main__":
print(result.intermediate_report)
print("=" * 80)
print(f"Citations: {result.citation_mapping}")
print(f"Total packets emitted: {emitter_queue.qsize()}")
print(f"Total packets emitted: {bus.qsize()}")

View File

@@ -1,6 +1,5 @@
import csv
import json
import queue
import uuid
from io import BytesIO
from io import StringIO
@@ -12,6 +11,7 @@ import requests
from requests import JSONDecodeError
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.configs.constants import FileOrigin
from onyx.file_store.file_store import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
# Use a discard emitter if none provided (packets go nowhere)
# Use default emitter if none provided
if emitter is None:
emitter = Emitter(merged_queue=queue.Queue())
emitter = get_default_emitter()
return [
CustomTool(
@@ -367,7 +367,7 @@ if __name__ == "__main__":
tools = build_custom_tools_from_openapi_schema_and_headers(
tool_id=0, # dummy tool id
openapi_schema=openapi_schema,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
dynamic_schema_info=None,
)

View File

@@ -187,7 +187,7 @@ coloredlogs==15.0.1
# via onnxruntime
courlan==1.3.2
# via trafilatura
cryptography==46.0.6
cryptography==46.0.5
# via
# authlib
# google-auth
@@ -449,7 +449,7 @@ kombu==5.5.4
# via celery
kubernetes==31.0.0
# via onyx
langchain-core==1.2.22
langchain-core==1.2.11
# via onyx
langdetect==1.0.9
# via unstructured
@@ -735,7 +735,7 @@ pyee==13.0.0
# via playwright
pygithub==2.5.0
# via onyx
pygments==2.20.0
pygments==2.19.2
# via rich
pyjwt==2.12.0
# via

View File

@@ -97,7 +97,7 @@ comm==0.2.3
# via ipykernel
contourpy==1.3.3
# via matplotlib
cryptography==46.0.6
cryptography==46.0.5
# via
# google-auth
# pyjwt
@@ -263,7 +263,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.7.2
onyx-devtools==0.7.1
# via onyx
openai==2.14.0
# via
@@ -349,7 +349,7 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pygments==2.20.0
pygments==2.19.2
# via
# ipython
# ipython-pygments-lexers

View File

@@ -76,7 +76,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.6
cryptography==46.0.5
# via
# google-auth
# pyjwt

View File

@@ -92,7 +92,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.6
cryptography==46.0.5
# via
# google-auth
# pyjwt

View File

@@ -191,6 +191,25 @@ IGNORED_SYNCING_TENANT_LIST = (
else None
)
# Global flag to skip userfile threshold for all users/tenants
SKIP_USERFILE_THRESHOLD = (
os.environ.get("SKIP_USERFILE_THRESHOLD", "").lower() == "true"
)
# Comma-separated list of specific tenant IDs to skip threshold (multi-tenant only)
SKIP_USERFILE_THRESHOLD_TENANT_IDS = os.environ.get(
"SKIP_USERFILE_THRESHOLD_TENANT_IDS"
)
SKIP_USERFILE_THRESHOLD_TENANT_LIST = (
[
tenant.strip()
for tenant in SKIP_USERFILE_THRESHOLD_TENANT_IDS.split(",")
if tenant.strip()
]
if SKIP_USERFILE_THRESHOLD_TENANT_IDS
else None
)
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"

View File

@@ -1,6 +1,4 @@
import time
from datetime import datetime
from datetime import timezone
import pytest
@@ -19,10 +17,6 @@ PRIVATE_CHANNEL_USERS = [
"test_user_2@onyx-test.com",
]
# Predates any test workspace messages, so the result set should match
# the "no start time" case while exercising the oldest= parameter.
OLDEST_TS_2016 = datetime(2016, 1, 1, tzinfo=timezone.utc).timestamp()
pytestmark = pytest.mark.usefixtures("enable_ee")
@@ -111,17 +105,15 @@ def test_load_from_checkpoint_access__private_channel(
],
indirect=True,
)
@pytest.mark.parametrize("start_ts", [None, OLDEST_TS_2016])
def test_slim_documents_access__public_channel(
slack_connector: SlackConnector,
start_ts: float | None,
) -> None:
"""Test that retrieve_all_slim_docs_perm_sync returns correct access information for slim documents."""
if not slack_connector.client:
raise RuntimeError("Web client must be defined")
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
start=start_ts,
start=0.0,
end=time.time(),
)
@@ -157,7 +149,7 @@ def test_slim_documents_access__private_channel(
raise RuntimeError("Web client must be defined")
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
start=None,
start=0.0,
end=time.time(),
)

View File

@@ -27,13 +27,11 @@ def create_placement(
turn_index: int,
tab_index: int = 0,
sub_turn_index: int | None = None,
model_index: int | None = 0,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
model_index=model_index,
)

View File

@@ -1,7 +1,7 @@
"""
External dependency unit tests for UserFileIndexingAdapter metadata writing.
Validates that prepare_enrichment produces DocMetadataAwareIndexChunk
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
objects with both `user_project` and `personas` fields populated correctly
based on actual DB associations.
@@ -127,7 +127,7 @@ def _make_index_chunk(user_file: UserFile) -> IndexChunk:
class TestAdapterWritesBothMetadataFields:
"""prepare_enrichment must populate user_project AND personas."""
"""build_metadata_aware_chunks must populate user_project AND personas."""
@patch(
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
@@ -153,13 +153,15 @@ class TestAdapterWritesBothMetadataFields:
doc = chunk.source_document
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert aware_chunk.user_project == []
@@ -188,13 +190,15 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
assert len(result.chunks) == 1
aware_chunk = result.chunks[0]
assert project.id in aware_chunk.user_project
assert aware_chunk.personas == []
@@ -225,13 +229,14 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert persona.id in aware_chunk.personas
assert project.id in aware_chunk.user_project
@@ -256,13 +261,14 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert aware_chunk.personas == []
assert aware_chunk.user_project == []
@@ -294,11 +300,12 @@ class TestAdapterWritesBothMetadataFields:
updatable_docs=[chunk.source_document], id_to_boost_map={}
)
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id=TEST_TENANT_ID,
chunks=[chunk],
context=context,
)
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
aware_chunk = result.chunks[0]
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}

View File

@@ -6,7 +6,6 @@ These tests assume Vespa and OpenSearch are running.
import time
import uuid
from collections.abc import Generator
from collections.abc import Iterator
import httpx
import pytest
@@ -22,7 +21,6 @@ from onyx.document_index.opensearch.opensearch_document_index import (
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import DocMetadataAwareIndexChunk
from tests.external_dependency_unit.constants import TEST_TENANT_ID
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
from tests.external_dependency_unit.document_index.conftest import make_chunk
@@ -203,25 +201,3 @@ class TestDocumentIndexNew:
assert len(result_map) == 2
assert result_map[existing_doc] is True
assert result_map[new_doc] is False
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndexNew],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk(doc_id, chunk_id=i)
results = document_index.index(
chunks=chunk_gen(), indexing_metadata=metadata
)
assert len(results) == 1
assert results[0].document_id == doc_id
assert results[0].already_existed is False

View File

@@ -5,7 +5,6 @@ These tests assume Vespa and OpenSearch are running.
import time
from collections.abc import Generator
from collections.abc import Iterator
import pytest
@@ -167,29 +166,3 @@ class TestDocumentIndexOld:
batch_retrieval=True,
)
assert len(inference_chunks) == 0
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndex],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk("test_doc_gen", chunk_id=i)
index_batch_params = IndexBatchParams(
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
tenant_id=get_current_tenant_id(),
large_chunks_enabled=False,
)
results = document_index.index(chunk_gen(), index_batch_params)
assert len(results) == 1
record = results.pop()
assert record.document_id == "test_doc_gen"
assert record.already_existed is False

View File

@@ -13,7 +13,6 @@ This test:
All external HTTP calls are mocked, but Postgres and Redis are running.
"""
import queue
from typing import Any
from unittest.mock import patch
from uuid import uuid4
@@ -21,7 +20,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.enums import MCPTransport
@@ -138,7 +137,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
search_tool_config=search_tool_config,
@@ -201,7 +200,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -276,7 +275,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -351,7 +350,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -459,7 +458,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -542,7 +541,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),

View File

@@ -8,7 +8,6 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
All external HTTP calls are mocked, but Postgres and Redis are running.
"""
import queue
from typing import Any
from unittest.mock import Mock
from unittest.mock import patch
@@ -17,7 +16,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.db.models import OAuthAccount
from onyx.db.models import OAuthConfig
from onyx.db.models import Persona
@@ -175,7 +174,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
search_tool_config=search_tool_config,
@@ -233,7 +232,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
)
@@ -285,7 +284,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
)
@@ -346,7 +345,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
)
@@ -417,7 +416,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
)
@@ -484,7 +483,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
)
@@ -537,7 +536,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=Emitter(merged_queue=queue.Queue()),
emitter=get_default_emitter(),
user=user,
llm=llm,
)

View File

@@ -1,173 +0,0 @@
"""Unit tests for the Emitter class.
All tests use the streaming mode (merged_queue required). Emitter has a single
code path — no standalone bus.
"""
import queue
from onyx.chat.emitter import Emitter
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _placement(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
)
def _packet(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Packet:
"""Build a minimal valid packet with an OverallStop payload."""
return Packet(
placement=_placement(turn_index, tab_index, sub_turn_index),
obj=OverallStop(stop_reason="test"),
)
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
"""Return (emitter, queue) wired together."""
mq: queue.Queue = queue.Queue()
return Emitter(merged_queue=mq, model_idx=model_idx), mq
# ---------------------------------------------------------------------------
# Queue routing
# ---------------------------------------------------------------------------
class TestEmitterQueueRouting:
def test_emit_lands_on_merged_queue(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet())
assert not mq.empty()
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
emitter, mq = _make_emitter(model_idx=1)
emitter.emit(_packet())
item = mq.get_nowait()
assert isinstance(item, tuple)
assert len(item) == 2
def test_multiple_packets_delivered_fifo(self) -> None:
emitter, mq = _make_emitter()
p1 = _packet(turn_index=0)
p2 = _packet(turn_index=1)
emitter.emit(p1)
emitter.emit(p2)
_, t1 = mq.get_nowait()
_, t2 = mq.get_nowait()
assert t1.placement.turn_index == 0
assert t2.placement.turn_index == 1
# ---------------------------------------------------------------------------
# model_index tagging
# ---------------------------------------------------------------------------
class TestEmitterModelIndexTagging:
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
"""N=1: default model_idx=0, so packet gets model_index=0."""
emitter, mq = _make_emitter(model_idx=0)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 0
def test_model_idx_one_tags_packet(self) -> None:
emitter, mq = _make_emitter(model_idx=1)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 1
def test_model_idx_two_tags_packet(self) -> None:
"""Boundary: third model in a 3-model run."""
emitter, mq = _make_emitter(model_idx=2)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 2
# ---------------------------------------------------------------------------
# Queue key
# ---------------------------------------------------------------------------
class TestEmitterQueueKey:
def test_key_equals_model_idx(self) -> None:
"""Drain loop uses the key to route packets; it must match model_idx."""
emitter, mq = _make_emitter(model_idx=2)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 2
def test_n1_key_is_zero(self) -> None:
emitter, mq = _make_emitter(model_idx=0)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 0
# ---------------------------------------------------------------------------
# Placement field preservation
# ---------------------------------------------------------------------------
class TestEmitterPlacementPreservation:
def test_turn_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(turn_index=5))
_, tagged = mq.get_nowait()
assert tagged.placement.turn_index == 5
def test_tab_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(tab_index=3))
_, tagged = mq.get_nowait()
assert tagged.placement.tab_index == 3
def test_sub_turn_index_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(sub_turn_index=2))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index == 2
def test_sub_turn_index_none_is_preserved(self) -> None:
emitter, mq = _make_emitter()
emitter.emit(_packet(sub_turn_index=None))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index is None
def test_packet_obj_is_not_modified(self) -> None:
"""The payload object must survive tagging untouched."""
emitter, mq = _make_emitter()
original_obj = OverallStop(stop_reason="sentinel")
pkt = Packet(placement=_placement(), obj=original_obj)
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert tagged.obj is original_obj
def test_different_obj_types_are_handled(self) -> None:
"""Any valid PacketObj type passes through correctly."""
emitter, mq = _make_emitter()
pkt = Packet(placement=_placement(), obj=ReasoningStart())
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert isinstance(tagged.obj, ReasoningStart)

View File

@@ -1,676 +0,0 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in handle_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
import time
from collections.abc import Generator
from typing import Any
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from onyx.chat.models import StreamingError
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.utils.variable_functionality import global_version
@pytest.fixture(autouse=True)
def _restore_ee_version() -> Generator[None, None, None]:
"""Reset EE global state after each test.
Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
(via the celery client import chain). Without this fixture, the EE flag stays
True for the rest of the session and breaks unrelated tests that mock Confluence
or other connectors and assume EE is disabled.
"""
original = global_version._is_ee
yield
global_version._is_ee = original
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
defaults: dict[str, Any] = {
"message": "hello",
"chat_session_id": uuid4(),
}
defaults.update(kwargs)
return SendMessageRequest(**defaults)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
return LLMOverride(model_provider=provider, model_version=version)
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
"""Return the first item yielded by handle_multi_model_stream."""
from onyx.chat.process_message import handle_multi_model_stream
user = MagicMock()
user.is_anonymous = False
user.email = "test@example.com"
db = MagicMock()
gen = handle_multi_model_stream(req, user, db, overrides)
return next(gen)
# ---------------------------------------------------------------------------
# handle_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
def test_single_override_yields_error(self) -> None:
"""Exactly 1 override is not multi-model — yields StreamingError."""
req = _make_request()
result = _first_from_stream(req, [_make_override()])
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_four_overrides_yields_error(self) -> None:
"""4 overrides exceeds maximum — yields StreamingError."""
req = _make_request()
result = _first_from_stream(
req,
[
_make_override("openai", "gpt-4"),
_make_override("anthropic", "claude-3"),
_make_override("google", "gemini-pro"),
_make_override("cohere", "command-r"),
],
)
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_zero_overrides_yields_error(self) -> None:
"""Empty override list yields StreamingError."""
req = _make_request()
result = _first_from_stream(req, [])
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_deep_research_yields_error(self) -> None:
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
req = _make_request(deep_research=True)
result = _first_from_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
assert isinstance(result, StreamingError)
assert "not supported" in result.error
def test_exactly_two_overrides_is_minimum(self) -> None:
"""Boundary: 1 override yields error, 2 overrides passes validation."""
req = _make_request()
# 1 override must yield a StreamingError
result = _first_from_stream(req, [_make_override()])
assert isinstance(
result, StreamingError
), "1 override should yield StreamingError"
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
# missing session, that's OK — validation itself passed)
try:
result2 = _first_from_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
if isinstance(result2, StreamingError) and "2-3" in result2.error:
pytest.fail(
f"2 overrides should pass validation, got StreamingError: {result2.error}"
)
except Exception:
pass # Any non-validation error means validation passed
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
def test_user_message_not_found(self) -> None:
db = MagicMock()
db.get.return_value = None
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=999, preferred_assistant_message_id=1
)
def test_wrong_message_type(self) -> None:
"""Cannot set preferred response on a non-USER message."""
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.ASSISTANT # wrong type
db.get.return_value = user_msg
with pytest.raises(ValueError, match="not a user message"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_message_not_found(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
# First call returns user_msg, second call (for assistant) returns None
db.get.side_effect = [user_msg, None]
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_not_child_of_user(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 999 # different parent
db.get.side_effect = [user_msg, assistant_msg]
with pytest.raises(ValueError, match="not a child"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_valid_call_sets_preferred_response_id(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 1 # correct parent
db.get.side_effect = [user_msg, assistant_msg]
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
assert user_msg.preferred_response_id == 2
assert user_msg.latest_child_message_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
def test_display_name_defaults_none(self) -> None:
override = LLMOverride(model_provider="openai", model_version="gpt-4")
assert override.display_name is None
def test_display_name_set(self) -> None:
override = LLMOverride(
model_provider="openai",
model_version="gpt-4",
display_name="GPT-4 Turbo",
)
assert override.display_name == "GPT-4 Turbo"
def test_display_name_serializes(self) -> None:
override = LLMOverride(
model_provider="anthropic",
model_version="claude-opus-4-6",
display_name="Claude Opus",
)
d = override.model_dump()
assert d["display_name"] == "Claude Opus"
# ---------------------------------------------------------------------------
# _run_models — drain loop behaviour
# ---------------------------------------------------------------------------
def _make_setup(n_models: int = 1) -> MagicMock:
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
setup = MagicMock()
setup.llms = [MagicMock() for _ in range(n_models)]
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
setup.check_is_connected = MagicMock(return_value=True)
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
setup.reserved_token_count = 100
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
# constructors inside _run_model — must be typed correctly for Pydantic.
setup.new_msg_req.deep_research = False
setup.new_msg_req.internal_search_filters = None
setup.new_msg_req.allowed_tool_ids = None
setup.new_msg_req.include_citations = True
setup.search_params.project_id_filter = None
setup.search_params.persona_id_filter = None
setup.bypass_acl = False
setup.slack_context = None
setup.available_files.user_file_ids = []
setup.available_files.chat_file_ids = []
setup.forced_tool_id = None
setup.simple_chat_history = []
setup.chat_session.id = uuid4()
setup.user_message.id = None
setup.custom_tool_additional_headers = None
setup.mcp_headers = None
return setup
_RUN_MODELS_PATCHES = [
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch("onyx.chat.process_message.get_llm_token_counter", return_value=lambda _: 0),
]
def _run_models_collect(setup: MagicMock) -> list:
"""Drive _run_models to completion and return all yielded items."""
from onyx.chat.process_message import _run_models
return list(_run_models(setup, MagicMock(), MagicMock()))
class TestRunModels:
"""Tests for the _run_models worker-thread drain loop.
All external dependencies (LLM, DB, tools) are patched out. Worker threads
still run but return immediately since run_llm_loop is mocked.
"""
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
def emit_stop(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(
placement=Placement(turn_index=0),
obj=OverallStop(stop_reason="complete"),
)
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert len(stops) == 1
stop_obj = stops[0].obj
assert isinstance(stop_obj, OverallStop)
assert stop_obj.stop_reason == "complete"
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
def emit_one(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index == 0
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
def emit_one(**kwargs: Any) -> None:
# _model_idx is set by _run_model based on position in setup.llms
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=2))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 2
indices = {p.placement.model_index for p in reasoning}
assert indices == {0, 1}
def test_model_error_yields_streaming_error(self) -> None:
"""An exception inside a worker thread is surfaced as a StreamingError."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("intentional test failure")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
assert errors[0].error_code == "MODEL_ERROR"
assert "intentional test failure" in errors[0].error
def test_one_model_error_does_not_stop_other_models(self) -> None:
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
emitter = kwargs["emitter"]
# _model_idx is always int (0 for N=1, 0/1/2… for N>1)
if emitter._model_idx == 0:
raise RuntimeError("model 0 failed")
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=fail_model_0_succeed_model_1,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=2))
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index == 1
def test_cancellation_yields_user_cancelled_stop(self) -> None:
"""If check_is_connected returns False, drain loop emits user_cancelled."""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
setup = _make_setup(n_models=1)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(setup)
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert any(
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
for s in stops
)
def test_completion_handle_called_on_disconnect(self) -> None:
"""llm_loop_completion_handle must still be called even when user disconnects.
Regression test for the disconnect-cleanup bug: the old
run_chat_loop_with_state_containers always called completion_callback in
its finally block (even on disconnect) so the DB message was updated from
the TERMINATED placeholder to a partial answer. The new _run_models must
replicate this — otherwise the integration test
test_send_message_disconnect_and_cleanup fails because the message stays
as "Response was terminated prior to completion, try regenerating."
"""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3)
setup = _make_setup(n_models=2)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
# Must be called once per model, not zero times
assert mock_handle.call_count == 2
def test_completion_handle_called_for_each_successful_model(self) -> None:
"""llm_loop_completion_handle must be called once per model that succeeded."""
setup = _make_setup(n_models=2)
with (
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
assert mock_handle.call_count == 2
def test_completion_handle_not_called_for_failed_model(self) -> None:
"""llm_loop_completion_handle must be skipped for a model that raised."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("fail")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(_make_setup(n_models=1))
mock_handle.assert_not_called()
def test_http_disconnect_completion_via_generator_exit(self) -> None:
"""GeneratorExit from HTTP disconnect triggers worker self-completion.
When the HTTP client closes the connection, Starlette throws GeneratorExit
into the stream generator. The finally block sets drain_done (signalling
emitters to stop blocking) and calls executor.shutdown(wait=False) so the
server thread is never blocked. Worker threads detect drain_done.is_set()
after run_llm_loop completes and self-persist the result via
llm_loop_completion_handle using their own DB session.
This is the primary regression for test_send_message_disconnect_and_cleanup:
the integration test disconnects mid-stream and expects the DB message to be
updated from the TERMINATED placeholder to the real response.
"""
import threading
# Signals the worker to unblock from run_llm_loop after gen.close() returns.
# This guarantees drain_done is set BEFORE the worker returns from run_llm_loop,
# so the self-completion path (drain_done.is_set() check) is always taken.
disconnect_received = threading.Event()
# Set by the llm_loop_completion_handle mock when called.
completion_called = threading.Event()
def emit_then_complete(**kwargs: Any) -> None:
"""Emit one packet (to give the drain loop a yield point), then block
until the main thread signals that gen.close() has been called. This
ensures drain_done is set before we return so model_succeeded is checked
against a set drain_done — no race condition.
"""
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
disconnect_received.wait(timeout=5)
setup = _make_setup(n_models=1)
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
setup.check_is_connected = MagicMock(return_value=True)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=emit_then_complete,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle",
side_effect=lambda *_, **__: completion_called.set(),
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
from onyx.chat.process_message import _run_models
# cast to Generator so .close() is available; _run_models returns
# AnswerStream (= Iterator) but the actual object is always a generator.
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
# Advance to the first yielded packet — generator suspends at `yield item`.
first = next(gen)
assert isinstance(first, Packet)
# Simulate Starlette closing the stream on HTTP client disconnect.
# GeneratorExit is thrown at the `yield item` suspension point.
gen.close()
# Unblock the worker now that drain_done has been set by gen.close().
disconnect_received.set()
# Worker self-completes asynchronously (executor.shutdown(wait=False)).
# Wait here, inside the patch context, so that get_session_with_current_tenant
# and llm_loop_completion_handle mocks are still active when the worker calls them.
assert completion_called.wait(
timeout=5
), "worker must self-complete via drain_done within 5 seconds"
assert (
mock_handle.call_count == 1
), "completion handle must be called once for the successful model"
def test_external_state_container_used_for_model_zero(self) -> None:
"""When provided, external_state_container is used as state_containers[0]."""
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.process_message import _run_models
external = ChatStateContainer()
setup = _make_setup(n_models=1)
with (
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
list(
_run_models(
setup, MagicMock(), MagicMock(), external_state_container=external
)
)
# The state_container kwarg passed to run_llm_loop must be the external one
call_kwargs = mock_llm.call_args.kwargs
assert call_kwargs["state_container"] is external

View File

@@ -1,5 +1,3 @@
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -33,7 +31,6 @@ def mock_jira_cc_pair(
"jira_base_url": jira_base_url,
"project_key": project_key,
}
mock_cc_pair.connector.indexing_start = None
return mock_cc_pair
@@ -68,75 +65,3 @@ def test_jira_permission_sync(
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
):
print(doc)
def test_jira_doc_sync_passes_indexing_start(
jira_connector: JiraConnector,
mock_jira_cc_pair: MagicMock,
mock_fetch_all_existing_docs_fn: MagicMock,
mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
"""Verify that generic_doc_sync derives indexing_start from cc_pair
and forwards it to retrieve_all_slim_docs_perm_sync."""
indexing_start_dt = datetime(2025, 6, 1, tzinfo=timezone.utc)
mock_jira_cc_pair.connector.indexing_start = indexing_start_dt
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
mock_build_client.return_value = jira_connector._jira_client
assert jira_connector._jira_client is not None
jira_connector._jira_client._options = MagicMock()
jira_connector._jira_client._options.return_value = {
"rest_api_version": JIRA_SERVER_API_VERSION
}
with patch.object(
type(jira_connector),
"retrieve_all_slim_docs_perm_sync",
return_value=iter([]),
) as mock_retrieve:
list(
jira_doc_sync(
cc_pair=mock_jira_cc_pair,
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
)
)
mock_retrieve.assert_called_once()
call_kwargs = mock_retrieve.call_args
assert call_kwargs.kwargs["start"] == indexing_start_dt.timestamp()
def test_jira_doc_sync_passes_none_when_no_indexing_start(
jira_connector: JiraConnector,
mock_jira_cc_pair: MagicMock,
mock_fetch_all_existing_docs_fn: MagicMock,
mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
"""Verify that indexing_start is None when the connector has no indexing_start set."""
mock_jira_cc_pair.connector.indexing_start = None
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
mock_build_client.return_value = jira_connector._jira_client
assert jira_connector._jira_client is not None
jira_connector._jira_client._options = MagicMock()
jira_connector._jira_client._options.return_value = {
"rest_api_version": JIRA_SERVER_API_VERSION
}
with patch.object(
type(jira_connector),
"retrieve_all_slim_docs_perm_sync",
return_value=iter([]),
) as mock_retrieve:
list(
jira_doc_sync(
cc_pair=mock_jira_cc_pair,
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
)
)
mock_retrieve.assert_called_once()
call_kwargs = mock_retrieve.call_args
assert call_kwargs.kwargs["start"] is None

View File

@@ -0,0 +1,225 @@
"""Tests for get_chat_sessions_by_user filtering behavior.
Verifies that failed chat sessions (those with only SYSTEM messages) are
correctly filtered out while preserving recently created sessions, matching
the behavior specified in PR #7233.
"""
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from unittest.mock import MagicMock
from uuid import UUID
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.models import ChatSession
def _make_session(
user_id: UUID,
time_created: datetime | None = None,
time_updated: datetime | None = None,
description: str = "",
) -> MagicMock:
"""Create a mock ChatSession with the given attributes."""
session = MagicMock(spec=ChatSession)
session.id = uuid4()
session.user_id = user_id
session.time_created = time_created or datetime.now(timezone.utc)
session.time_updated = time_updated or session.time_created
session.description = description
session.deleted = False
session.onyxbot_flow = False
session.project_id = None
return session
@pytest.fixture
def user_id() -> UUID:
return uuid4()
@pytest.fixture
def old_time() -> datetime:
"""A timestamp well outside the 5-minute leeway window."""
return datetime.now(timezone.utc) - timedelta(hours=1)
@pytest.fixture
def recent_time() -> datetime:
"""A timestamp within the 5-minute leeway window."""
return datetime.now(timezone.utc) - timedelta(minutes=2)
class TestGetChatSessionsByUser:
"""Tests for the failed chat filtering logic in get_chat_sessions_by_user."""
def test_filters_out_failed_sessions(
self, user_id: UUID, old_time: datetime
) -> None:
"""Sessions with only SYSTEM messages should be excluded."""
valid_session = _make_session(user_id, time_created=old_time)
failed_session = _make_session(user_id, time_created=old_time)
db_session = MagicMock(spec=Session)
# First execute: returns all sessions
# Second execute: returns only the valid session's ID (has non-system msgs)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [
valid_session,
failed_session,
]
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = [valid_session.id]
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert len(result) == 1
assert result[0].id == valid_session.id
def test_keeps_recent_sessions_without_messages(
self, user_id: UUID, recent_time: datetime
) -> None:
"""Recently created sessions should be kept even without messages."""
recent_session = _make_session(user_id, time_created=recent_time)
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [recent_session]
db_session.execute.side_effect = [mock_result_1]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert len(result) == 1
assert result[0].id == recent_session.id
# Should only have been called once — no second query needed
# because the recent session is within the leeway window
assert db_session.execute.call_count == 1
def test_include_failed_chats_skips_filtering(
self, user_id: UUID, old_time: datetime
) -> None:
"""When include_failed_chats=True, no filtering should occur."""
session_a = _make_session(user_id, time_created=old_time)
session_b = _make_session(user_id, time_created=old_time)
db_session = MagicMock(spec=Session)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [session_a, session_b]
db_session.execute.side_effect = [mock_result]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=True,
)
assert len(result) == 2
# Only one DB call — no second query for message validation
assert db_session.execute.call_count == 1
def test_limit_applied_after_filtering(
self, user_id: UUID, old_time: datetime
) -> None:
"""Limit should be applied after filtering, not before."""
sessions = [_make_session(user_id, time_created=old_time) for _ in range(5)]
valid_ids = [s.id for s in sessions[:3]]
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = sessions
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = valid_ids
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
limit=2,
)
assert len(result) == 2
# Should be the first 2 valid sessions (order preserved)
assert result[0].id == sessions[0].id
assert result[1].id == sessions[1].id
def test_mixed_recent_and_old_sessions(
self, user_id: UUID, old_time: datetime, recent_time: datetime
) -> None:
"""Mix of recent and old sessions should filter correctly."""
old_valid = _make_session(user_id, time_created=old_time)
old_failed = _make_session(user_id, time_created=old_time)
recent_no_msgs = _make_session(user_id, time_created=recent_time)
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [
old_valid,
old_failed,
recent_no_msgs,
]
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = [old_valid.id]
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
result_ids = {cs.id for cs in result}
assert old_valid.id in result_ids
assert recent_no_msgs.id in result_ids
assert old_failed.id not in result_ids
def test_empty_result(self, user_id: UUID) -> None:
"""No sessions should return empty list without errors."""
db_session = MagicMock(spec=Session)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
db_session.execute.side_effect = [mock_result]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert result == []
assert db_session.execute.call_count == 1

View File

@@ -1,223 +0,0 @@
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
def _make_chunk(
doc_id: str,
chunk_id: int,
) -> DocMetadataAwareIndexChunk:
"""Creates a minimal DocMetadataAwareIndexChunk for testing."""
doc = Document(
id=doc_id,
sections=[TextSection(text="test", link="http://test.com")],
source=DocumentSource.FILE,
semantic_identifier="test_doc",
metadata={},
)
access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
)
return DocMetadataAwareIndexChunk(
chunk_id=chunk_id,
blurb="test",
content="test content",
source_links={0: "http://test.com"},
image_file_id=None,
section_continuation=False,
source_document=doc,
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
embeddings=ChunkEmbedding(full_embedding=[0.1] * 10, mini_chunk_embeddings=[]),
title_embedding=[0.1] * 10,
tenant_id="test_tenant",
access=access,
document_sets=set(),
user_project=[],
personas=[],
boost=0,
aggregated_chunk_boost_factor=1.0,
ancestor_hierarchy_node_ids=[],
)
def _make_index() -> tuple[OpenSearchDocumentIndex, MagicMock]:
"""Creates an OpenSearchDocumentIndex with a mocked client.
Returns the index and the mock for bulk_index_documents."""
mock_client = MagicMock()
mock_bulk = MagicMock()
mock_client.bulk_index_documents = mock_bulk
tenant_state = TenantState(tenant_id="test_tenant", multitenant=False)
index = OpenSearchDocumentIndex.__new__(OpenSearchDocumentIndex)
index._index_name = "test_index"
index._client = mock_client
index._tenant_state = tenant_state
return index, mock_bulk
def _make_metadata(doc_id: str, chunk_count: int) -> IndexingMetadata:
return IndexingMetadata(
doc_id_to_chunk_cnt_diff={
doc_id: IndexingMetadata.ChunkCounts(
old_chunk_cnt=0,
new_chunk_cnt=chunk_count,
),
},
)
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_single_doc_under_batch_limit_flushes_once() -> None:
"""A document with fewer chunks than MAX_CHUNKS_PER_DOC_BATCH should flush once."""
index, mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 50
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
assert mock_bulk.call_count == 1
batch_arg = mock_bulk.call_args_list[0]
assert len(batch_arg.kwargs["documents"]) == num_chunks
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_single_doc_over_batch_limit_flushes_multiple_times() -> None:
"""A document with more chunks than MAX_CHUNKS_PER_DOC_BATCH should flush multiple times."""
index, mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 250
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
# 250 chunks / 100 per batch = 3 flushes (100 + 100 + 50)
assert mock_bulk.call_count == 3
batch_sizes = [len(call.kwargs["documents"]) for call in mock_bulk.call_args_list]
assert batch_sizes == [100, 100, 50]
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_single_doc_exactly_at_batch_limit() -> None:
"""A document with exactly MAX_CHUNKS_PER_DOC_BATCH chunks should flush once
(the flush happens on the next chunk, not at the boundary)."""
index, mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 100
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
# 100 chunks hit the >= check on chunk 101 which doesn't exist,
# so final flush handles all 100
# Actually: the elif fires when len(current_chunks) >= 100, which happens
# when current_chunks has 100 items and the 101st chunk arrives.
# With exactly 100 chunks, the 100th chunk makes len == 99, then appended -> 100.
# No 101st chunk arrives, so the final flush handles all 100.
assert mock_bulk.call_count == 1
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_single_doc_one_over_batch_limit() -> None:
"""101 chunks for one doc: first 100 flushed when the 101st arrives, then
the 101st is flushed at the end."""
index, mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 101
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
assert mock_bulk.call_count == 2
batch_sizes = [len(call.kwargs["documents"]) for call in mock_bulk.call_args_list]
assert batch_sizes == [100, 1]
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_multiple_docs_each_under_limit_flush_per_doc() -> None:
"""Multiple documents each under the batch limit should flush once per document."""
index, mock_bulk = _make_index()
chunks = []
for doc_idx in range(3):
doc_id = f"doc_{doc_idx}"
for chunk_idx in range(50):
chunks.append(_make_chunk(doc_id, chunk_idx))
metadata = IndexingMetadata(
doc_id_to_chunk_cnt_diff={
f"doc_{i}": IndexingMetadata.ChunkCounts(old_chunk_cnt=0, new_chunk_cnt=50)
for i in range(3)
},
)
with patch.object(index, "delete", return_value=0):
index.index(chunks, metadata)
# 3 documents = 3 flushes (one per doc boundary + final)
assert mock_bulk.call_count == 3
@patch(
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
100,
)
def test_delete_called_once_per_document() -> None:
"""Even with multiple flushes for a single document, delete should only be
called once per document."""
index, _mock_bulk = _make_index()
doc_id = "doc_1"
num_chunks = 250
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
metadata = _make_metadata(doc_id, num_chunks)
with patch.object(index, "delete", return_value=0) as mock_delete:
index.index(chunks, metadata)
mock_delete.assert_called_once_with(doc_id, None)

View File

@@ -1,152 +0,0 @@
"""Unit tests for VespaDocumentIndex.index().
These tests mock all external I/O (HTTP calls, thread pools) and verify
the streaming logic, ID cleaning/mapping, and DocumentInsertionRecord
construction.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
def _make_chunk(
doc_id: str,
chunk_id: int = 0,
content: str = "test content",
) -> DocMetadataAwareIndexChunk:
doc = Document(
id=doc_id,
semantic_identifier="test_doc",
sections=[TextSection(text=content, link=None)],
source=DocumentSource.NOT_APPLICABLE,
metadata={},
)
index_chunk = IndexChunk(
chunk_id=chunk_id,
blurb=content[:50],
content=content,
source_links=None,
image_file_id=None,
section_continuation=False,
source_document=doc,
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
contextual_rag_reserved_tokens=0,
doc_summary="",
chunk_context="",
mini_chunk_texts=None,
large_chunk_id=None,
embeddings=ChunkEmbedding(
full_embedding=[0.1] * 10,
mini_chunk_embeddings=[],
),
title_embedding=None,
)
access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
)
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=index_chunk,
access=access,
document_sets=set(),
user_project=[],
personas=[],
boost=0,
aggregated_chunk_boost_factor=1.0,
tenant_id="test_tenant",
)
def _make_indexing_metadata(
doc_ids: list[str],
old_counts: list[int],
new_counts: list[int],
) -> IndexingMetadata:
return IndexingMetadata(
doc_id_to_chunk_cnt_diff={
doc_id: IndexingMetadata.ChunkCounts(
old_chunk_cnt=old,
new_chunk_cnt=new,
)
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
}
)
def _stub_enrich(
doc_id: str,
old_chunk_cnt: int,
) -> EnrichedDocumentIndexingInfo:
"""Build an EnrichedDocumentIndexingInfo that says 'no chunks to delete'
when old_chunk_cnt == 0, or 'has existing chunks' otherwise."""
return EnrichedDocumentIndexingInfo(
doc_id=doc_id,
chunk_start_index=0,
old_version=False,
chunk_end_index=old_chunk_cnt,
)
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
@patch(
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
return_value=[],
)
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
@patch(
"onyx.document_index.vespa.vespa_document_index.BATCH_SIZE",
3,
)
def test_index_respects_batch_size(
mock_enrich: MagicMock,
mock_get_chunk_ids: MagicMock, # noqa: ARG001
mock_delete: MagicMock, # noqa: ARG001
mock_batch_index: MagicMock,
) -> None:
"""When chunks exceed BATCH_SIZE, batch_index_vespa_chunks is called
multiple times with correctly sized batches."""
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
index = VespaDocumentIndex(
index_name="test_index",
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
large_chunks_enabled=False,
httpx_client=MagicMock(),
)
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(7)]
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[7])
results = index.index(chunks=chunks, indexing_metadata=metadata)
assert len(results) == 1
# With BATCH_SIZE=3 and 7 chunks: batches of 3, 3, 1
assert mock_batch_index.call_count == 3
batch_sizes = [len(c.kwargs["chunks"]) for c in mock_batch_index.call_args_list]
assert batch_sizes == [3, 3, 1]
# Verify all chunks are accounted for and in order
all_indexed = [
chunk for c in mock_batch_index.call_args_list for chunk in c.kwargs["chunks"]
]
assert len(all_indexed) == 7
assert [c.chunk_id for c in all_indexed] == list(range(7))

View File

@@ -1,391 +0,0 @@
"""Unit tests for _embed_chunks_to_store.
Tests cover:
- Single batch, no failures
- Multiple batches, no failures
- Failure in a single batch
- Cross-batch document failure scrubbing
- Later batches skip already-failed docs
- Empty input
- All chunks fail
"""
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import TextSection
from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.indexing_pipeline import _embed_chunks_to_store
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexChunk
def _make_doc(doc_id: str) -> Document:
return Document(
id=doc_id,
semantic_identifier="test",
source=DocumentSource.FILE,
sections=[TextSection(text="test", link=None)],
metadata={},
)
def _make_chunk(doc_id: str, chunk_id: int) -> DocAwareChunk:
return DocAwareChunk(
chunk_id=chunk_id,
blurb="test",
content="test content",
source_links=None,
image_file_id=None,
section_continuation=False,
source_document=_make_doc(doc_id),
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
)
def _make_index_chunk(doc_id: str, chunk_id: int) -> IndexChunk:
"""Create an IndexChunk (a DocAwareChunk with embeddings)."""
return IndexChunk(
chunk_id=chunk_id,
blurb="test",
content="test content",
source_links=None,
image_file_id=None,
section_continuation=False,
source_document=_make_doc(doc_id),
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
embeddings=ChunkEmbedding(
full_embedding=[0.1] * 10,
mini_chunk_embeddings=[],
),
title_embedding=None,
)
def _make_failure(doc_id: str) -> ConnectorFailure:
return ConnectorFailure(
failed_document=DocumentFailure(document_id=doc_id, document_link=None),
failure_message="embedding failed",
exception=RuntimeError("embedding failed"),
)
def _mock_embed_success(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
"""Simulate successful embedding of all chunks."""
return (
[_make_index_chunk(c.source_document.id, c.chunk_id) for c in chunks],
[],
)
def _mock_embed_fail_doc(
fail_doc_id: str,
) -> Callable[..., tuple[list[IndexChunk], list[ConnectorFailure]]]:
"""Return an embed mock that fails all chunks for a specific doc."""
def _embed(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
successes = [
_make_index_chunk(c.source_document.id, c.chunk_id)
for c in chunks
if c.source_document.id != fail_doc_id
]
failures = (
[_make_failure(fail_doc_id)]
if any(c.source_document.id == fail_doc_id for c in chunks)
else []
)
return successes, failures
return _embed
class TestEmbedChunksInBatches:
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
def test_single_batch_no_failures(self, mock_embed: MagicMock) -> None:
"""All chunks fit in one batch and embed successfully."""
mock_embed.side_effect = _mock_embed_success
with ChunkBatchStore() as store:
chunks = [_make_chunk("doc1", i) for i in range(3)]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.successful_chunk_ids) == 3
assert len(result.connector_failures) == 0
# Verify stored contents
assert len(store._batch_files()) == 1
stored = list(store.stream())
assert len(stored) == 3
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
def test_multiple_batches_no_failures(self, mock_embed: MagicMock) -> None:
"""Chunks are split across multiple batches, all succeed."""
mock_embed.side_effect = _mock_embed_success
with ChunkBatchStore() as store:
chunks = [_make_chunk("doc1", i) for i in range(7)]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.successful_chunk_ids) == 7
assert len(result.connector_failures) == 0
assert len(store._batch_files()) == 3 # 3 + 3 + 1
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
def test_single_batch_with_failure(self, mock_embed: MagicMock) -> None:
"""One doc fails embedding, its chunks are excluded from results."""
mock_embed.side_effect = _mock_embed_fail_doc("doc2")
with ChunkBatchStore() as store:
chunks = [
_make_chunk("doc1", 0),
_make_chunk("doc2", 1),
_make_chunk("doc1", 2),
]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.connector_failures) == 1
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
assert "doc2" not in successful_doc_ids
assert "doc1" in successful_doc_ids
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
def test_cross_batch_failure_scrubs_earlier_batch(
self, mock_embed: MagicMock
) -> None:
"""Doc A spans batches 0 and 1. It succeeds in batch 0 but fails in
batch 1. Its chunks should be scrubbed from batch 0's batch file."""
call_count = 0
def _embed(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
nonlocal call_count
call_count += 1
if call_count == 1:
return _mock_embed_success(chunks)
else:
return _mock_embed_fail_doc("docA")(chunks)
mock_embed.side_effect = _embed
with ChunkBatchStore() as store:
chunks = [
_make_chunk("docA", 0),
_make_chunk("docA", 1),
_make_chunk("docA", 2),
_make_chunk("docA", 3),
_make_chunk("docB", 0),
_make_chunk("docB", 1),
]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
# docA should be fully excluded from results
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
assert "docA" not in successful_doc_ids
assert "docB" in successful_doc_ids
assert len(result.connector_failures) == 1
# Verify batch 0 was scrubbed of docA chunks
all_stored = list(store.stream())
stored_doc_ids = {c.source_document.id for c in all_stored}
assert "docA" not in stored_doc_ids
assert "docB" in stored_doc_ids
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
def test_later_batch_skips_already_failed_doc(self, mock_embed: MagicMock) -> None:
"""If docA fails in batch 0, its chunks in batch 1 are skipped
entirely (never sent to the embedder)."""
embedded_doc_ids: list[str] = []
def _embed(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
for c in chunks:
embedded_doc_ids.append(c.source_document.id)
return _mock_embed_fail_doc("docA")(chunks)
mock_embed.side_effect = _embed
with ChunkBatchStore() as store:
chunks = [
_make_chunk("docA", 0),
_make_chunk("docA", 1),
_make_chunk("docA", 2),
_make_chunk("docA", 3),
_make_chunk("docB", 0),
_make_chunk("docB", 1),
]
_embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
# docA should only appear in batch 0, not batch 1
batch_1_doc_ids = embedded_doc_ids[3:]
assert "docA" not in batch_1_doc_ids
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
def test_failed_doc_skipped_in_later_batch_while_other_doc_succeeds(
self, mock_embed: MagicMock
) -> None:
"""doc1 spans batches 0 and 1, doc2 only in batch 1. Batch 0 fails
doc1. In batch 1, doc1 chunks should be skipped but doc2 chunks
should still be embedded successfully."""
embedded_chunks: list[list[str]] = []
def _embed(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
embedded_chunks.append([c.source_document.id for c in chunks])
return _mock_embed_fail_doc("doc1")(chunks)
mock_embed.side_effect = _embed
with ChunkBatchStore() as store:
chunks = [
_make_chunk("doc1", 0),
_make_chunk("doc1", 1),
_make_chunk("doc1", 2),
_make_chunk("doc1", 3),
_make_chunk("doc2", 0),
_make_chunk("doc2", 1),
]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
# doc1 should be fully excluded, doc2 fully included
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
assert "doc1" not in successful_doc_ids
assert "doc2" in successful_doc_ids
assert len(result.successful_chunk_ids) == 2 # doc2's 2 chunks
# Batch 1 should only contain doc2 (doc1 was filtered before embedding)
assert len(embedded_chunks) == 2
assert "doc1" not in embedded_chunks[1]
assert embedded_chunks[1] == ["doc2", "doc2"]
# Verify on-disk state has no doc1 chunks
all_stored = list(store.stream())
assert all(c.source_document.id == "doc2" for c in all_stored)
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
def test_empty_input(self, mock_embed: MagicMock) -> None:
"""Empty chunk list produces empty results."""
mock_embed.side_effect = _mock_embed_success
with ChunkBatchStore() as store:
result = _embed_chunks_to_store(
chunks=[],
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.successful_chunk_ids) == 0
assert len(result.connector_failures) == 0
mock_embed.assert_not_called()
@patch(
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
)
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
def test_all_chunks_fail(self, mock_embed: MagicMock) -> None:
"""When all documents fail, results have no successful chunks."""
def _fail_all(
chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
doc_ids = {c.source_document.id for c in chunks}
return [], [_make_failure(doc_id) for doc_id in doc_ids]
mock_embed.side_effect = _fail_all
with ChunkBatchStore() as store:
chunks = [_make_chunk("doc1", 0), _make_chunk("doc2", 1)]
result = _embed_chunks_to_store(
chunks=chunks,
embedder=MagicMock(),
tenant_id="test",
request_id=None,
store=store,
)
assert len(result.successful_chunk_ids) == 0
assert len(result.connector_failures) == 2

View File

@@ -116,7 +116,7 @@ def _run_adapter_build(
project_ids_map: dict[str, list[int]],
persona_ids_map: dict[str, list[int]],
) -> list[DocMetadataAwareIndexChunk]:
"""Helper that runs UserFileIndexingAdapter.prepare_enrichment + enrich_chunk
"""Helper that runs UserFileIndexingAdapter.build_metadata_aware_chunks
with all external dependencies mocked."""
from onyx.indexing.adapters.user_file_indexing_adapter import (
UserFileIndexingAdapter,
@@ -155,16 +155,18 @@ def _run_adapter_build(
side_effect=Exception("no LLM in tests"),
),
):
enricher = adapter.prepare_enrichment(
context=context,
result = adapter.build_metadata_aware_chunks(
chunks_with_embeddings=[chunk],
chunk_content_scores=[1.0],
tenant_id="test_tenant",
chunks=[chunk],
context=context,
)
return [enricher.enrich_chunk(chunk, 1.0)]
return result.chunks
def test_prepare_enrichment_includes_persona_ids() -> None:
"""UserFileIndexingAdapter.prepare_enrichment writes persona IDs
def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
"""UserFileIndexingAdapter.build_metadata_aware_chunks writes persona IDs
fetched from the DB into each chunk's metadata."""
file_id = str(uuid4())
persona_ids = [5, 12]
@@ -181,7 +183,7 @@ def test_prepare_enrichment_includes_persona_ids() -> None:
assert chunks[0].user_project == project_ids
def test_prepare_enrichment_missing_file_defaults_to_empty() -> None:
def test_build_metadata_aware_chunks_missing_file_defaults_to_empty() -> None:
"""When a file has no persona or project associations in the DB, the
adapter should default to empty lists (not KeyError or None)."""
file_id = str(uuid4())

View File

@@ -4,23 +4,13 @@ from unittest.mock import MagicMock
import pytest
from fastapi import UploadFile
from onyx.natural_language_processing import utils as nlp_utils
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.server.features.projects import projects_file_utils as utils
from onyx.server.settings.models import Settings
class _Tokenizer(BaseTokenizer):
class _Tokenizer:
def encode(self, text: str) -> list[int]:
return [1] * len(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
class _NonSeekableFile(BytesIO):
def tell(self) -> int:
@@ -39,26 +29,10 @@ def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
return UploadFile(filename=filename, file=BytesIO(content), size=None)
def _make_settings(upload_size_mb: int = 1, token_threshold_k: int = 100) -> Settings:
return Settings(
user_file_max_upload_size_mb=upload_size_mb,
file_token_count_threshold_k=token_threshold_k,
)
def _patch_common_dependencies(
monkeypatch: pytest.MonkeyPatch,
upload_size_mb: int = 1,
token_threshold_k: int = 100,
) -> None:
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
monkeypatch.setattr(
utils,
"load_settings",
lambda: _make_settings(upload_size_mb, token_threshold_k),
)
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
@@ -102,8 +76,9 @@ def test_is_upload_too_large_logs_warning_when_size_unknown(
def test_categorize_uploaded_files_accepts_size_under_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# upload_size_mb=1 → max_bytes = 1*1024*1024; file size 99 is well under
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("small.png", size=99)
@@ -116,7 +91,9 @@ def test_categorize_uploaded_files_accepts_size_under_limit(
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload_no_size("small.png", content=b"x" * 99)
@@ -129,11 +106,12 @@ def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
def test_categorize_uploaded_files_accepts_size_at_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
# 1 MB = 1048576 bytes; file at exactly that boundary should be accepted
upload = _make_upload("edge.png", size=1048576)
upload = _make_upload("edge.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
@@ -143,10 +121,12 @@ def test_categorize_uploaded_files_accepts_size_at_limit(
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("large.png", size=1048577) # 1 byte over 1 MB
upload = _make_upload("large.png", size=101)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
@@ -157,11 +137,13 @@ def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
small = _make_upload("small.png", size=50)
large = _make_upload("large.png", size=1048577)
large = _make_upload("large.png", size=101)
result = utils.categorize_uploaded_files([small, large], MagicMock())
@@ -171,12 +153,15 @@ def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_enforces_size_limit_always(
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
upload = _make_upload("oversized.pdf", size=1048577)
upload = _make_upload("oversized.pdf", size=101)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
@@ -187,12 +172,14 @@ def test_categorize_uploaded_files_enforces_size_limit_always(
def test_categorize_uploaded_files_checks_size_before_text_extraction(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
extract_mock = MagicMock(return_value="this should not run")
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
oversized_doc = _make_upload("oversized.pdf", size=1048577)
oversized_doc = _make_upload("oversized.pdf", size=101)
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
extract_mock.assert_not_called()
@@ -201,219 +188,40 @@ def test_categorize_uploaded_files_checks_size_before_text_extraction(
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_enforces_size_limit_when_upload_size_mb_is_positive(
def test_categorize_uploaded_files_accepts_python_file(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A positive upload_size_mb is always enforced."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
upload = _make_upload("huge.png", size=1048577, content=b"x")
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
def test_categorize_enforces_token_limit_when_threshold_k_is_positive(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A positive token_threshold_k is always enforced."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=5)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
upload = _make_upload("big_image.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
def test_categorize_no_token_limit_when_threshold_k_is_zero(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""token_threshold_k=0 means no token limit; high-token files are accepted."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=0)
py_source = b'def hello():\n print("world")\n'
monkeypatch.setattr(
utils, "estimate_image_tokens_for_upload", lambda _upload: 999_999
utils, "extract_file_text", lambda **_kwargs: py_source.decode()
)
upload = _make_upload("huge_image.png", size=100)
upload = _make_upload("script.py", size=len(py_source), content=py_source)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.rejected) == 0
assert len(result.acceptable) == 1
assert result.acceptable[0].filename == "script.py"
assert len(result.rejected) == 0
def test_categorize_both_limits_enforced(
def test_categorize_uploaded_files_rejects_binary_file(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Both positive limits are enforced; file exceeding token limit is rejected."""
_patch_common_dependencies(monkeypatch, upload_size_mb=10, token_threshold_k=5)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
upload = _make_upload("over_tokens.png", size=100)
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: "")
binary_content = bytes(range(256)) * 4
upload = _make_upload("data.bin", size=len(binary_content), content=binary_content)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
assert result.rejected[0].reason == "Exceeds 5K token limit"
def test_categorize_rejection_reason_contains_dynamic_values(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Rejection reasons reflect the admin-configured limits, not hardcoded values."""
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 8000)
# File within size limit but over token limit
upload = _make_upload("tokens.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert result.rejected[0].reason == "Exceeds 7K token limit"
# File over size limit
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
oversized = _make_upload("big.png", size=42 * 1024 * 1024 + 1)
result2 = utils.categorize_uploaded_files([oversized], MagicMock())
assert result2.rejected[0].reason == "Exceeds 42 MB file size limit"
# --- count_tokens tests ---
def test_count_tokens_small_text() -> None:
"""Small text should be encoded in a single call and return correct count."""
tokenizer = _Tokenizer()
text = "hello world"
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
def test_count_tokens_chunked_matches_single_call() -> None:
"""Chunked encoding should produce the same result as single-call for small text."""
tokenizer = _Tokenizer()
text = "a" * 1000
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
def test_count_tokens_large_text_is_chunked(monkeypatch: pytest.MonkeyPatch) -> None:
"""Text exceeding _ENCODE_CHUNK_SIZE should be split into multiple encode calls."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 250
# _Tokenizer returns 1 token per char, so total should be 250
assert count_tokens(text, tokenizer) == 250
def test_count_tokens_with_token_limit_exits_early(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When token_limit is set and exceeded, count_tokens should stop early."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
encode_call_count = 0
original_tokenizer = _Tokenizer()
class _CountingTokenizer(BaseTokenizer):
def encode(self, text: str) -> list[int]:
nonlocal encode_call_count
encode_call_count += 1
return original_tokenizer.encode(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
tokenizer = _CountingTokenizer()
# 500 chars → 5 chunks of 100; limit=150 → should stop after 2 chunks
text = "a" * 500
result = count_tokens(text, tokenizer, token_limit=150)
assert result == 200 # 2 chunks × 100 tokens each
assert encode_call_count == 2, "Should have stopped after 2 chunks"
def test_count_tokens_with_token_limit_not_exceeded(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When token_limit is set but not exceeded, all chunks are encoded."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 250
result = count_tokens(text, tokenizer, token_limit=1000)
assert result == 250
def test_count_tokens_no_limit_encodes_all_chunks(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Without token_limit, all chunks are encoded regardless of count."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 500
result = count_tokens(text, tokenizer)
assert result == 500
# --- early exit via token_limit in categorize tests ---
def test_categorize_early_exits_tokenization_for_large_text(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Large text files should be rejected via early-exit tokenization
without encoding all chunks."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
# token_threshold = 1000; _ENCODE_CHUNK_SIZE = 100 → text of 500 chars = 5 chunks
# Should stop after 2nd chunk (200 tokens > 1000? No... need 1 token per char)
# With _Tokenizer: 1 token per char. threshold=1000, chunk=100 → need 11 chunks
# Let's use a bigger text
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
large_text = "x" * 5000 # 5000 tokens, threshold 1000
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: large_text)
encode_call_count = 0
original_tokenizer = _Tokenizer()
class _CountingTokenizer(BaseTokenizer):
def encode(self, text: str) -> list[int]:
nonlocal encode_call_count
encode_call_count += 1
return original_tokenizer.encode(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _CountingTokenizer())
upload = _make_upload("big.txt", size=5000, content=large_text.encode())
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.rejected) == 1
assert "token limit" in result.rejected[0].reason
# 5000 chars / 100 chunk_size = 50 chunks total; should stop well before all 50
assert (
encode_call_count < 50
), f"Expected early exit but encoded {encode_call_count} chunks out of 50"
def test_categorize_text_under_token_limit_accepted(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Text files under the token threshold should be accepted with exact count."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
small_text = "x" * 500 # 500 tokens < 1000 threshold
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: small_text)
upload = _make_upload("ok.txt", size=500, content=small_text.encode())
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
assert result.acceptable_file_to_token_count["ok.txt"] == 500
assert result.rejected[0].filename == "data.bin"
assert "Unsupported file type" in result.rejected[0].reason

View File

@@ -1,23 +1,12 @@
import pytest
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings import store as settings_store
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Settings
class _FakeKvStore:
def __init__(self, data: dict | None = None) -> None:
self._data = data
def load(self, _key: str) -> dict:
if self._data is None:
raise KvKeyNotFoundError()
return self._data
raise KvKeyNotFoundError()
class _FakeCache:
@@ -31,140 +20,13 @@ class _FakeCache:
self._vals[key] = value.encode("utf-8")
def test_load_settings_uses_model_defaults_when_no_stored_value(
def test_load_settings_includes_user_file_max_upload_size_mb(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When no settings are stored (vector DB enabled), load_settings() should
resolve the default token threshold to 200."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", False)
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
assert (
settings.file_token_count_threshold_k
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
def test_load_settings_uses_high_token_default_when_vector_db_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When vector DB is disabled and no settings are stored, the token
threshold should default to 10000 (10M tokens)."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
assert (
settings.file_token_count_threshold_k
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
)
def test_load_settings_preserves_explicit_value_when_vector_db_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When vector DB is disabled but admin explicitly set a token threshold,
that value should be preserved (not overridden by the 10000 default)."""
stored = Settings(file_token_count_threshold_k=500).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.file_token_count_threshold_k == 500
def test_load_settings_preserves_zero_token_threshold(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 means 'no limit' and should be preserved."""
stored = Settings(file_token_count_threshold_k=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.file_token_count_threshold_k == 0
def test_load_settings_resolves_zero_upload_size_to_default(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 should be treated as unset and resolved to the default."""
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
def test_load_settings_clamps_upload_size_to_env_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the stored upload size exceeds MAX_ALLOWED_UPLOAD_SIZE_MB, it should
be clamped to the env-configured maximum."""
stored = Settings(user_file_max_upload_size_mb=500).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 250
def test_load_settings_preserves_upload_size_within_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the stored upload size is within MAX_ALLOWED_UPLOAD_SIZE_MB, it should
be preserved unchanged."""
stored = Settings(user_file_max_upload_size_mb=150).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 150
def test_load_settings_zero_upload_size_resolves_to_default(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 should be treated as unset and resolved to the default,
clamped to MAX_ALLOWED_UPLOAD_SIZE_MB."""
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 100)
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 100
def test_load_settings_default_clamped_to_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB exceeds MAX_ALLOWED_UPLOAD_SIZE_MB,
the effective default should be min(DEFAULT, MAX)."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 50)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 50
assert settings.user_file_max_upload_size_mb == 77

View File

@@ -1,6 +1,6 @@
"""Tests for memory tool streaming packet emissions."""
import queue
from queue import Queue
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -18,13 +18,9 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
@pytest.fixture
def emitter_queue() -> queue.Queue:
return queue.Queue()
@pytest.fixture
def emitter(emitter_queue: queue.Queue) -> Emitter:
return Emitter(merged_queue=emitter_queue)
def emitter() -> Emitter:
bus: Queue = Queue()
return Emitter(bus)
@pytest.fixture
@@ -57,27 +53,24 @@ class TestMemoryToolEmitStart:
def test_emit_start_emits_memory_tool_start_packet(
self,
memory_tool: MemoryTool,
emitter_queue: queue.Queue,
emitter: Emitter,
placement: Placement,
) -> None:
memory_tool.emit_start(placement)
_key, packet = emitter_queue.get_nowait()
packet = emitter.bus.get_nowait()
assert isinstance(packet.obj, MemoryToolStart)
assert packet.placement is not None
assert packet.placement.turn_index == placement.turn_index
assert packet.placement.tab_index == placement.tab_index
assert packet.placement.model_index == 0 # emitter stamps model_index=0
assert packet.placement == placement
def test_emit_start_with_different_placement(
self,
memory_tool: MemoryTool,
emitter_queue: queue.Queue,
emitter: Emitter,
) -> None:
placement = Placement(turn_index=2, tab_index=1)
memory_tool.emit_start(placement)
_key, packet = emitter_queue.get_nowait()
packet = emitter.bus.get_nowait()
assert packet.placement.turn_index == 2
assert packet.placement.tab_index == 1
@@ -88,7 +81,7 @@ class TestMemoryToolRun:
self,
mock_process: MagicMock,
memory_tool: MemoryTool,
emitter_queue: queue.Queue,
emitter: Emitter,
placement: Placement,
override_kwargs: MemoryToolOverrideKwargs,
) -> None:
@@ -100,19 +93,21 @@ class TestMemoryToolRun:
memory="User prefers Python",
)
_key, packet = emitter_queue.get_nowait()
# The delta packet should be in the queue
packet = emitter.bus.get_nowait()
assert isinstance(packet.obj, MemoryToolDelta)
assert packet.obj.memory_text == "User prefers Python"
assert packet.obj.operation == "add"
assert packet.obj.memory_id is None
assert packet.obj.index is None
assert packet.placement == placement
@patch("onyx.tools.tool_implementations.memory.memory_tool.process_memory_update")
def test_run_emits_delta_for_update_operation(
self,
mock_process: MagicMock,
memory_tool: MemoryTool,
emitter_queue: queue.Queue,
emitter: Emitter,
placement: Placement,
override_kwargs: MemoryToolOverrideKwargs,
) -> None:
@@ -124,7 +119,7 @@ class TestMemoryToolRun:
memory="User prefers light mode",
)
_key, packet = emitter_queue.get_nowait()
packet = emitter.bus.get_nowait()
assert isinstance(packet.obj, MemoryToolDelta)
assert packet.obj.memory_text == "User prefers light mode"
assert packet.obj.operation == "update"

1
cli/.gitignore vendored
View File

@@ -1,4 +1,3 @@
onyx-cli
cli
onyx.cli
__pycache__

View File

@@ -63,31 +63,6 @@ onyx-cli agents
onyx-cli agents --json
```
### Serve over SSH
```shell
# Start a public SSH endpoint for the CLI TUI
onyx-cli serve --host 0.0.0.0 --port 2222
# Connect as a client
ssh your-host -p 2222
```
Clients can either:
- paste an API key at the login prompt, or
- skip the prompt by sending `ONYX_API_KEY` over SSH:
```shell
export ONYX_API_KEY=your-key
ssh -o SendEnv=ONYX_API_KEY your-host -p 2222
```
Useful hardening flags:
- `--idle-timeout` (default `15m`)
- `--max-session-timeout` (default `8h`)
- `--rate-limit-per-minute` (default `20`)
- `--rate-limit-burst` (default `40`)
## Commands
| Command | Description |
@@ -95,7 +70,6 @@ Useful hardening flags:
| `chat` | Launch the interactive chat TUI (default) |
| `ask` | Ask a one-shot question (non-interactive) |
| `agents` | List available agents |
| `serve` | Serve the interactive chat TUI over SSH |
| `configure` | Configure server URL and API key |
| `validate-config` | Validate configuration and test connection |

View File

@@ -1,17 +1,7 @@
// Package cmd implements Cobra CLI commands for the Onyx CLI.
package cmd
import (
"context"
"fmt"
"time"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/version"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
import "github.com/spf13/cobra"
// Version and Commit are set via ldflags at build time.
var (
@@ -26,69 +16,15 @@ func fullVersion() string {
return Version
}
func printVersion(cmd *cobra.Command) {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Client version: %s\n", fullVersion())
cfg := config.Load()
if !cfg.IsConfigured() {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: unknown (not configured)\n")
return
}
client := api.NewClient(cfg)
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
defer cancel()
log.Debug("fetching backend version from /api/version")
backendVersion, err := client.GetBackendVersion(ctx)
if err != nil {
log.WithError(err).Debug("could not fetch backend version")
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: unknown (could not reach server)\n")
return
}
if backendVersion == "" {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: unknown (empty response)\n")
return
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: %s\n", backendVersion)
min := version.MinServer()
if sv, ok := version.Parse(backendVersion); ok && sv.LessThan(min) {
log.Warnf("Server version %s is below minimum required %d.%d, please upgrade",
backendVersion, min.Major, min.Minor)
}
}
// Execute creates and runs the root command.
func Execute() error {
opts := struct {
Debug bool
}{}
rootCmd := &cobra.Command{
Use: "onyx-cli",
Short: "Terminal UI for chatting with Onyx",
Long: "Onyx CLI — a terminal interface for chatting with your Onyx agent.",
PersistentPreRun: func(cmd *cobra.Command, args []string) {
if opts.Debug {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.InfoLevel)
}
log.SetFormatter(&log.TextFormatter{
DisableTimestamp: true,
})
},
Use: "onyx-cli",
Short: "Terminal UI for chatting with Onyx",
Long: "Onyx CLI — a terminal interface for chatting with your Onyx agent.",
Version: fullVersion(),
}
rootCmd.PersistentFlags().BoolVar(&opts.Debug, "debug", false, "run in debug mode")
// Custom --version flag instead of Cobra's built-in (which only shows one version string)
var showVersion bool
rootCmd.Flags().BoolVarP(&showVersion, "version", "v", false, "Print client and server version information")
// Register subcommands
chatCmd := newChatCmd()
rootCmd.AddCommand(chatCmd)
@@ -96,16 +32,9 @@ func Execute() error {
rootCmd.AddCommand(newAgentsCmd())
rootCmd.AddCommand(newConfigureCmd())
rootCmd.AddCommand(newValidateConfigCmd())
rootCmd.AddCommand(newServeCmd())
// Default command is chat, but intercept --version first
rootCmd.RunE = func(cmd *cobra.Command, args []string) error {
if showVersion {
printVersion(cmd)
return nil
}
return chatCmd.RunE(cmd, args)
}
// Default command is chat
rootCmd.RunE = chatCmd.RunE
return rootCmd.Execute()
}

View File

@@ -1,450 +0,0 @@
package cmd
import (
"context"
"errors"
"fmt"
"net"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/log"
"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish"
"github.com/charmbracelet/wish/activeterm"
"github.com/charmbracelet/wish/bubbletea"
"github.com/charmbracelet/wish/logging"
"github.com/charmbracelet/wish/ratelimiter"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/tui"
"github.com/spf13/cobra"
"golang.org/x/time/rate"
)
const (
defaultServeIdleTimeout = 15 * time.Minute
defaultServeMaxSessionTimeout = 8 * time.Hour
defaultServeRateLimitPerMinute = 20
defaultServeRateLimitBurst = 40
defaultServeRateLimitCacheSize = 4096
maxAPIKeyLength = 512
apiKeyValidationTimeout = 15 * time.Second
maxAPIKeyRetries = 5
)
func sessionEnv(s ssh.Session, key string) string {
prefix := key + "="
for _, env := range s.Environ() {
if strings.HasPrefix(env, prefix) {
return env[len(prefix):]
}
}
return ""
}
func validateAPIKey(serverURL string, apiKey string) error {
trimmedKey := strings.TrimSpace(apiKey)
if len(trimmedKey) > maxAPIKeyLength {
return fmt.Errorf("API key is too long (max %d characters)", maxAPIKeyLength)
}
cfg := config.OnyxCliConfig{
ServerURL: serverURL,
APIKey: trimmedKey,
}
client := api.NewClient(cfg)
ctx, cancel := context.WithTimeout(context.Background(), apiKeyValidationTimeout)
defer cancel()
return client.TestConnection(ctx)
}
// --- auth prompt (bubbletea model) ---
type authState int
const (
authInput authState = iota
authValidating
authDone
)
type authValidatedMsg struct {
key string
err error
}
type authModel struct {
input textinput.Model
serverURL string
state authState
apiKey string // set on successful validation
errMsg string
retries int
aborted bool
}
func newAuthModel(serverURL, initialErr string) authModel {
ti := textinput.New()
ti.Prompt = " API Key: "
ti.EchoMode = textinput.EchoPassword
ti.EchoCharacter = '•'
ti.CharLimit = maxAPIKeyLength
ti.Width = 80
ti.Focus()
return authModel{
input: ti,
serverURL: serverURL,
errMsg: initialErr,
}
}
func (m authModel) Update(msg tea.Msg) (authModel, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.input.Width = max(msg.Width-14, 20) // account for prompt width
return m, nil
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyCtrlD:
m.aborted = true
return m, nil
default:
if m.state == authValidating {
return m, nil
}
}
switch msg.Type {
case tea.KeyEnter:
key := strings.TrimSpace(m.input.Value())
if key == "" {
m.errMsg = "No key entered."
m.retries++
if m.retries >= maxAPIKeyRetries {
m.errMsg = "Too many failed attempts. Disconnecting."
m.aborted = true
return m, nil
}
m.input.SetValue("")
return m, nil
}
m.state = authValidating
m.errMsg = ""
serverURL := m.serverURL
return m, func() tea.Msg {
return authValidatedMsg{key: key, err: validateAPIKey(serverURL, key)}
}
}
case authValidatedMsg:
if msg.err != nil {
m.state = authInput
m.errMsg = msg.err.Error()
m.retries++
if m.retries >= maxAPIKeyRetries {
m.errMsg = "Too many failed attempts. Disconnecting."
m.aborted = true
return m, nil
}
m.input.SetValue("")
return m, m.input.Focus()
}
m.apiKey = msg.key
m.state = authDone
return m, nil
}
if m.state == authInput {
var cmd tea.Cmd
m.input, cmd = m.input.Update(msg)
return m, cmd
}
return m, nil
}
func (m authModel) View() string {
settingsURL := strings.TrimRight(m.serverURL, "/") + "/app/settings/accounts-access"
var b strings.Builder
b.WriteString("\n")
b.WriteString(" \x1b[1;35mOnyx CLI\x1b[0m\n")
b.WriteString(" \x1b[90m" + m.serverURL + "\x1b[0m\n")
b.WriteString("\n")
b.WriteString(" Generate an API key at:\n")
b.WriteString(" \x1b[4;34m" + settingsURL + "\x1b[0m\n")
b.WriteString("\n")
b.WriteString(" \x1b[90mTip: skip this prompt by passing your key via SSH:\x1b[0m\n")
b.WriteString(" \x1b[90m export ONYX_API_KEY=<key>\x1b[0m\n")
b.WriteString(" \x1b[90m ssh -o SendEnv=ONYX_API_KEY <host> -p <port>\x1b[0m\n")
b.WriteString("\n")
if m.errMsg != "" {
b.WriteString(" \x1b[1;31m" + m.errMsg + "\x1b[0m\n\n")
}
switch m.state {
case authDone:
b.WriteString(" \x1b[32mAuthenticated.\x1b[0m\n")
case authValidating:
b.WriteString(" \x1b[90mValidating…\x1b[0m\n")
default:
b.WriteString(m.input.View() + "\n")
}
return b.String()
}
// --- serve model (wraps auth → TUI in a single bubbletea program) ---
type serveModel struct {
auth authModel
tui tea.Model
authed bool
serverCfg config.OnyxCliConfig
width int
height int
}
func newServeModel(serverCfg config.OnyxCliConfig, initialErr string) serveModel {
return serveModel{
auth: newAuthModel(serverCfg.ServerURL, initialErr),
serverCfg: serverCfg,
}
}
func (m serveModel) Init() tea.Cmd {
return textinput.Blink
}
func (m serveModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if !m.authed {
if ws, ok := msg.(tea.WindowSizeMsg); ok {
m.width = ws.Width
m.height = ws.Height
}
var cmd tea.Cmd
m.auth, cmd = m.auth.Update(msg)
if m.auth.aborted {
return m, tea.Quit
}
if m.auth.apiKey != "" {
cfg := config.OnyxCliConfig{
ServerURL: m.serverCfg.ServerURL,
APIKey: m.auth.apiKey,
DefaultAgentID: m.serverCfg.DefaultAgentID,
}
m.tui = tui.NewModel(cfg)
m.authed = true
w, h := m.width, m.height
return m, tea.Batch(
tea.EnterAltScreen,
tea.EnableMouseCellMotion,
m.tui.Init(),
func() tea.Msg { return tea.WindowSizeMsg{Width: w, Height: h} },
)
}
return m, cmd
}
var cmd tea.Cmd
m.tui, cmd = m.tui.Update(msg)
return m, cmd
}
func (m serveModel) View() string {
if !m.authed {
return m.auth.View()
}
return m.tui.View()
}
// --- serve command ---
func newServeCmd() *cobra.Command {
var (
host string
port int
keyPath string
idleTimeout time.Duration
maxSessionTimeout time.Duration
rateLimitPerMin int
rateLimitBurst int
rateLimitCache int
)
cmd := &cobra.Command{
Use: "serve",
Short: "Serve the Onyx TUI over SSH",
Long: `Start an SSH server that presents the interactive Onyx chat TUI to
connecting clients. Each SSH session gets its own independent TUI instance.
Clients are prompted for their Onyx API key on connect. The key can also be
provided via the ONYX_API_KEY environment variable to skip the prompt:
ssh -o SendEnv=ONYX_API_KEY host -p port
The server URL is taken from the server operator's config. The server
auto-generates an Ed25519 host key on first run if the key file does not
already exist. The host key path can also be set via the ONYX_SSH_HOST_KEY
environment variable (the --host-key flag takes precedence).
Example:
onyx-cli serve --port 2222
ssh localhost -p 2222`,
RunE: func(cmd *cobra.Command, args []string) error {
serverCfg := config.Load()
if serverCfg.ServerURL == "" {
return fmt.Errorf("server URL is not configured; run 'onyx-cli configure' first")
}
if !cmd.Flags().Changed("host-key") {
if v := os.Getenv(config.EnvSSHHostKey); v != "" {
keyPath = v
}
}
if rateLimitPerMin <= 0 {
return fmt.Errorf("--rate-limit-per-minute must be > 0")
}
if rateLimitBurst <= 0 {
return fmt.Errorf("--rate-limit-burst must be > 0")
}
if rateLimitCache <= 0 {
return fmt.Errorf("--rate-limit-cache must be > 0")
}
addr := net.JoinHostPort(host, fmt.Sprintf("%d", port))
connectionLimiter := ratelimiter.NewRateLimiter(
rate.Limit(float64(rateLimitPerMin)/60.0),
rateLimitBurst,
rateLimitCache,
)
handler := func(s ssh.Session) (tea.Model, []tea.ProgramOption) {
apiKey := strings.TrimSpace(sessionEnv(s, config.EnvAPIKey))
var envErr string
if apiKey != "" {
if err := validateAPIKey(serverCfg.ServerURL, apiKey); err != nil {
envErr = fmt.Sprintf("ONYX_API_KEY from SSH environment is invalid: %s", err.Error())
apiKey = ""
}
}
if apiKey != "" {
// Env key is valid — go straight to the TUI.
cfg := config.OnyxCliConfig{
ServerURL: serverCfg.ServerURL,
APIKey: apiKey,
DefaultAgentID: serverCfg.DefaultAgentID,
}
return tui.NewModel(cfg), []tea.ProgramOption{
tea.WithAltScreen(),
tea.WithMouseCellMotion(),
}
}
// No valid env key — show auth prompt, then transition
// to the TUI within the same bubbletea program.
return newServeModel(serverCfg, envErr), []tea.ProgramOption{
tea.WithMouseCellMotion(),
}
}
serverOptions := []ssh.Option{
wish.WithAddress(addr),
wish.WithHostKeyPath(keyPath),
wish.WithMiddleware(
bubbletea.Middleware(handler),
activeterm.Middleware(),
ratelimiter.Middleware(connectionLimiter),
logging.Middleware(),
),
}
if idleTimeout > 0 {
serverOptions = append(serverOptions, wish.WithIdleTimeout(idleTimeout))
}
if maxSessionTimeout > 0 {
serverOptions = append(serverOptions, wish.WithMaxTimeout(maxSessionTimeout))
}
s, err := wish.NewServer(serverOptions...)
if err != nil {
return fmt.Errorf("could not create SSH server: %w", err)
}
done := make(chan os.Signal, 1)
signal.Notify(done, os.Interrupt, syscall.SIGTERM)
log.Info("Starting Onyx SSH server", "addr", addr)
log.Info("Connect with", "cmd", fmt.Sprintf("ssh %s -p %d", host, port))
errCh := make(chan error, 1)
go func() {
if err := s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
log.Error("SSH server failed", "error", err)
errCh <- err
}
}()
var serverErr error
select {
case <-done:
case serverErr = <-errCh:
}
signal.Stop(done)
log.Info("Shutting down SSH server")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if shutdownErr := s.Shutdown(ctx); shutdownErr != nil {
return errors.Join(serverErr, shutdownErr)
}
return serverErr
},
}
cmd.Flags().StringVar(&host, "host", "localhost", "Host address to bind to")
cmd.Flags().IntVarP(&port, "port", "p", 2222, "Port to listen on")
cmd.Flags().StringVar(&keyPath, "host-key", filepath.Join(config.ConfigDir(), "host_ed25519"),
"Path to SSH host key (auto-generated if missing)")
cmd.Flags().DurationVar(
&idleTimeout,
"idle-timeout",
defaultServeIdleTimeout,
"Disconnect idle clients after this duration (set 0 to disable)",
)
cmd.Flags().DurationVar(
&maxSessionTimeout,
"max-session-timeout",
defaultServeMaxSessionTimeout,
"Maximum lifetime of a client session (set 0 to disable)",
)
cmd.Flags().IntVar(
&rateLimitPerMin,
"rate-limit-per-minute",
defaultServeRateLimitPerMinute,
"Per-IP connection rate limit (new sessions per minute)",
)
cmd.Flags().IntVar(
&rateLimitBurst,
"rate-limit-burst",
defaultServeRateLimitBurst,
"Per-IP burst limit for connection attempts",
)
cmd.Flags().IntVar(
&rateLimitCache,
"rate-limit-cache",
defaultServeRateLimitCacheSize,
"Maximum number of IP limiter entries tracked in memory",
)
return cmd
}

View File

@@ -1,14 +1,10 @@
package cmd
import (
"context"
"fmt"
"time"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/version"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
@@ -39,25 +35,6 @@ func newValidateConfigCmd() *cobra.Command {
}
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Status: connected and authenticated")
// Check backend version compatibility
vCtx, vCancel := context.WithTimeout(cmd.Context(), 5*time.Second)
defer vCancel()
backendVersion, err := client.GetBackendVersion(vCtx)
if err != nil {
log.WithError(err).Debug("could not fetch backend version")
} else if backendVersion == "" {
log.Debug("server returned empty version string")
} else {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Version: %s\n", backendVersion)
min := version.MinServer()
if sv, ok := version.Parse(backendVersion); ok && sv.LessThan(min) {
log.Warnf("Server version %s is below minimum required %d.%d, please upgrade",
backendVersion, min.Major, min.Minor)
}
}
return nil
},
}

View File

@@ -1,63 +1,45 @@
module github.com/onyx-dot-app/onyx/cli
go 1.26.1
go 1.26.0
require (
github.com/charmbracelet/bubbles v1.0.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/glamour v1.0.0
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834
github.com/charmbracelet/log v1.0.0
github.com/charmbracelet/ssh v0.0.0-20250826160808-ebfa259c7309
github.com/charmbracelet/wish v1.4.7
github.com/sirupsen/logrus v1.9.4
github.com/spf13/cobra v1.10.2
golang.org/x/term v0.41.0
golang.org/x/text v0.35.0
golang.org/x/time v0.15.0
github.com/charmbracelet/bubbles v0.20.0
github.com/charmbracelet/bubbletea v1.3.4
github.com/charmbracelet/glamour v0.8.0
github.com/charmbracelet/lipgloss v1.1.0
github.com/spf13/cobra v1.9.1
golang.org/x/term v0.30.0
golang.org/x/text v0.34.0
)
require (
github.com/alecthomas/chroma/v2 v2.23.1 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/alecthomas/chroma/v2 v2.14.0 // indirect
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/charmbracelet/colorprofile v0.4.3 // indirect
github.com/charmbracelet/keygen v0.5.4 // indirect
github.com/charmbracelet/x/ansi v0.11.6 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/conpty v0.2.0 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca // indirect
github.com/charmbracelet/x/input v0.3.7 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/charmbracelet/x/termios v0.1.1 // indirect
github.com/charmbracelet/x/windows v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.11.0 // indirect
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
github.com/creack/pty v1.1.24 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.8.0 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/go-logfmt/logfmt v0.6.1 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.21 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yuin/goldmark v1.8.2 // indirect
github.com/yuin/goldmark-emoji v1.0.6 // indirect
golang.org/x/crypto v0.49.0 // indirect
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/sys v0.42.0 // indirect
github.com/yuin/goldmark v1.7.4 // indirect
github.com/yuin/goldmark-emoji v1.0.3 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.31.0 // indirect
)

View File

@@ -1,89 +1,55 @@
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/chroma/v2 v2.23.1 h1:nv2AVZdTyClGbVQkIzlDm/rnhk1E9bU9nXwmZ/Vk/iY=
github.com/alecthomas/chroma/v2 v2.23.1/go.mod h1:NqVhfBR0lte5Ouh3DcthuUCTUpDC9cxBOfyMbMQPs3o=
github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE=
github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E=
github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3vj1nolY=
github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E=
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.4.3 h1:QPa1IWkYI+AOB+fE+mg/5/4HRMZcaXex9t5KX76i20Q=
github.com/charmbracelet/colorprofile v0.4.3/go.mod h1:/zT4BhpD5aGFpqQQqw7a+VtHCzu+zrQtt1zhMt9mR4Q=
github.com/charmbracelet/glamour v1.0.0 h1:AWMLOVFHTsysl4WV8T8QgkQ0s/ZNZo7CiE4WKhk8l08=
github.com/charmbracelet/glamour v1.0.0/go.mod h1:DSdohgOBkMr2ZQNhw4LZxSGpx3SvpeujNoXrQyH2hxo=
github.com/charmbracelet/keygen v0.5.4 h1:XQYgf6UEaTGgQSSmiPpIQ78WfseNQp4Pz8N/c1OsrdA=
github.com/charmbracelet/keygen v0.5.4/go.mod h1:t4oBRr41bvK7FaJsAaAQhhkUuHslzFXVjOBwA55CZNM=
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA=
github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdRc4=
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
github.com/charmbracelet/ssh v0.0.0-20250826160808-ebfa259c7309 h1:dCVbCRRtg9+tsfiTXTp0WupDlHruAXyp+YoxGVofHHc=
github.com/charmbracelet/ssh v0.0.0-20250826160808-ebfa259c7309/go.mod h1:R9cISUs5kAH4Cq/rguNbSwcR+slE5Dfm8FEs//uoIGE=
github.com/charmbracelet/wish v1.4.7 h1:O+jdLac3s6GaqkOHHSwezejNK04vl6VjO1A+hl8J8Yc=
github.com/charmbracelet/wish v1.4.7/go.mod h1:OBZ8vC62JC5cvbxJLh+bIWtG7Ctmct+ewziuUWK+G14=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/conpty v0.2.0 h1:eKtA2hm34qNfgJCDp/M6Dc0gLy7e07YEK4qAdNGOvVY=
github.com/charmbracelet/x/conpty v0.2.0/go.mod h1:fexgUnVrZgw8scD49f6VSi0Ggj9GWYIrpedRthAwW/8=
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ=
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca h1:QQoyQLgUzojMNWHVHToN6d9qTvT0KWtxUKIRPx/Ox5o=
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
github.com/charmbracelet/x/input v0.3.7 h1:UzVbkt1vgM9dBQ+K+uRolBlN6IF2oLchmPKKo/aucXo=
github.com/charmbracelet/x/input v0.3.7/go.mod h1:ZSS9Cia6Cycf2T6ToKIOxeTBTDwl25AGwArJuGaOBH8=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY=
github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo=
github.com/charmbracelet/x/windows v0.2.2 h1:IofanmuvaxnKHuV04sC0eBy/smG6kIKrWG2/jYn2GuM=
github.com/charmbracelet/x/windows v0.2.2/go.mod h1:/8XtdKZzedat74NQFn0NGlGL4soHB0YQZrETF96h75k=
github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8=
github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0=
github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk=
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=
github.com/charmbracelet/bubbletea v1.3.4/go.mod h1:dtcUCyCGEX3g9tosuYiut3MXgY/Jsv9nKVdibKKRRXo=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/glamour v0.8.0 h1:tPrjL3aRcQbn++7t18wOpgLyl8wrOHUEDS7IZ68QtZs=
github.com/charmbracelet/glamour v0.8.0/go.mod h1:ViRgmKkf3u5S7uakt2czJ272WSg2ZenlYEZXT2x7Bjw=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/go-logfmt/logfmt v0.6.1 h1:4hvbpePJKnIzH1B+8OR/JPbTx37NktoI9LE2QZBBkvE=
github.com/go-logfmt/logfmt v0.6.1/go.mod h1:EV2pOAQoZaT1ZXZbqDl5hrymndi4SY9ED9/z6CO0XAk=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
@@ -94,47 +60,35 @@ github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs=
github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 h1:jiDhWWeC7jfWqR9c/uplMOqJ0sbNlNWv0UkzE0vX1MA=
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90/go.mod h1:xE1HEv6b+1SCZ5/uscMRjUBKtIxworgEcEi+/n9NQDQ=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg=
github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark-emoji v1.0.3 h1:aLRkLHOuBR2czCY4R8olwMjID+tENfhyFDMCRhbIQY4=
github.com/yuin/goldmark-emoji v1.0.3/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -34,7 +34,8 @@ class CustomBuildHook(BuildHookInterface):
# Build the Go binary (always rebuild to ensure correct version injection)
if not os.path.exists(binary_name):
print(f"Building Go binary '{binary_name}'...")
ldflags = f"-X main.version={tag} -X main.commit={commit} -s -w"
pkg = "github.com/onyx-dot-app/onyx/cli/cmd"
ldflags = f"-X {pkg}.version={tag}" f" -X {pkg}.commit={commit}" " -s -w"
subprocess.check_call( # noqa: S603
["go", "build", f"-ldflags={ldflags}", "-o", binary_name],
)

View File

@@ -270,17 +270,6 @@ func (c *Client) UploadFile(ctx context.Context, filePath string) (*models.FileD
}, nil
}
// GetBackendVersion fetches the backend version string from /api/version.
func (c *Client) GetBackendVersion(ctx context.Context) (string, error) {
var resp struct {
BackendVersion string `json:"backend_version"`
}
if err := c.doJSON(ctx, "GET", "/api/version", nil, &resp); err != nil {
return "", err
}
return resp.BackendVersion, nil
}
// StopChatSession sends a stop signal for a streaming session (best-effort).
func (c *Client) StopChatSession(ctx context.Context, sessionID string) {
req, err := c.newRequest(ctx, "POST", "/api/chat/stop-chat-session/"+sessionID, nil)

View File

@@ -9,10 +9,9 @@ import (
)
const (
EnvServerURL = "ONYX_SERVER_URL"
EnvAPIKey = "ONYX_API_KEY"
EnvServerURL = "ONYX_SERVER_URL"
EnvAPIKey = "ONYX_API_KEY"
EnvAgentID = "ONYX_PERSONA_ID"
EnvSSHHostKey = "ONYX_SSH_HOST_KEY"
)
// OnyxCliConfig holds the CLI configuration.
@@ -36,8 +35,8 @@ func (c OnyxCliConfig) IsConfigured() bool {
return c.APIKey != ""
}
// ConfigDir returns ~/.config/onyx-cli
func ConfigDir() string {
// configDir returns ~/.config/onyx-cli
func configDir() string {
if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
return filepath.Join(xdg, "onyx-cli")
}
@@ -50,7 +49,7 @@ func ConfigDir() string {
// ConfigFilePath returns the full path to the config file.
func ConfigFilePath() string {
return filepath.Join(ConfigDir(), "config.json")
return filepath.Join(configDir(), "config.json")
}
// ConfigExists checks if the config file exists on disk.
@@ -88,7 +87,7 @@ func Load() OnyxCliConfig {
// Save writes the config to disk, creating parent directories if needed.
func Save(cfg OnyxCliConfig) error {
dir := ConfigDir()
dir := configDir()
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}

View File

@@ -1,58 +0,0 @@
// Package version provides semver parsing and compatibility checks.
package version
import (
"strconv"
"strings"
)
// Semver holds parsed semantic version components.
type Semver struct {
Major int
Minor int
Patch int
}
// minServer is the minimum backend version required by this CLI.
var minServer = Semver{Major: 3, Minor: 0, Patch: 0}
// MinServer returns the minimum backend version required by this CLI.
func MinServer() Semver { return minServer }
// Parse extracts major, minor, patch from a version string like "3.1.2" or "v3.1.2".
// Returns ok=false if the string is not valid semver.
func Parse(v string) (Semver, bool) {
v = strings.TrimPrefix(v, "v")
// Strip any pre-release suffix (e.g. "-beta.1") and build metadata (e.g. "+build.1")
if idx := strings.IndexAny(v, "-+"); idx != -1 {
v = v[:idx]
}
parts := strings.SplitN(v, ".", 3)
if len(parts) != 3 {
return Semver{}, false
}
major, err := strconv.Atoi(parts[0])
if err != nil {
return Semver{}, false
}
minor, err := strconv.Atoi(parts[1])
if err != nil {
return Semver{}, false
}
patch, err := strconv.Atoi(parts[2])
if err != nil {
return Semver{}, false
}
return Semver{Major: major, Minor: minor, Patch: patch}, true
}
// LessThan reports whether s is strictly less than other.
func (s Semver) LessThan(other Semver) bool {
if s.Major != other.Major {
return s.Major < other.Major
}
if s.Minor != other.Minor {
return s.Minor < other.Minor
}
return s.Patch < other.Patch
}

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["hatchling==1.29.0", "go-bin~=1.26.1", "manygo==0.2.0"]
requires = ["hatchling", "go-bin~=1.24.11", "manygo"]
build-backend = "hatchling.build"
[project]

View File

@@ -39,22 +39,6 @@ server {
# Conditionally include MCP location configuration
include /etc/nginx/conf.d/mcp.conf.inc;
location ~ ^/scim(/.*)?$ {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_buffering off;
proxy_redirect off;
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT}s;
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}s;
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}s;
proxy_pass http://api_server;
}
# Match both /api/* and /openapi.json in a single rule
location ~ ^/(api|openapi.json)(/.*)?$ {
# Rewrite /api prefixed matched paths

View File

@@ -39,20 +39,6 @@ server {
# Conditionally include MCP location configuration
include /etc/nginx/conf.d/mcp.conf.inc;
location ~ ^/scim(/.*)?$ {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_buffering off;
proxy_redirect off;
proxy_pass http://api_server;
}
# Match both /api/* and /openapi.json in a single rule
location ~ ^/(api|openapi.json)(/.*)?$ {
# Rewrite /api prefixed matched paths

View File

@@ -39,23 +39,6 @@ server {
# Conditionally include MCP location configuration
include /etc/nginx/conf.d/mcp.conf.inc;
location ~ ^/scim(/.*)?$ {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_buffering off;
proxy_redirect off;
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT}s;
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}s;
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}s;
proxy_pass http://api_server;
}
# Match both /api/* and /openapi.json in a single rule
location ~ ^/(api|openapi.json)(/.*)?$ {
# Rewrite /api prefixed matched paths

View File

@@ -66,3 +66,10 @@ DB_READONLY_PASSWORD=password
# Show extra/uncommon connectors
# See https://docs.onyx.app/admins/connectors/overview for a full list of connectors
SHOW_EXTRA_CONNECTORS=False
# User File Upload Configuration
# Skip the token count threshold check (100,000 tokens) for uploaded files
# For self-hosted: set to true to skip for all users
#SKIP_USERFILE_THRESHOLD=false
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
#SKIP_USERFILE_THRESHOLD_TENANT_IDS=

View File

@@ -35,10 +35,6 @@ USER_AUTH_SECRET=""
## Chat Configuration
# HARD_DELETE_CHATS=
# MAX_ALLOWED_UPLOAD_SIZE_MB=250
# Default per-user upload size limit (MB) when no admin value is set.
# Automatically clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at runtime.
# DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=100
## Base URL for redirects
# WEB_DOMAIN=
@@ -46,6 +42,13 @@ USER_AUTH_SECRET=""
## Enterprise Features, requires a paid plan and licenses
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=false
## User File Upload Configuration
# Skip the token count threshold check (100,000 tokens) for uploaded files
# For self-hosted: set to true to skip for all users
# SKIP_USERFILE_THRESHOLD=false
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
# SKIP_USERFILE_THRESHOLD_TENANT_IDS=
################################################################################
## SERVICES CONFIGURATIONS

View File

@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.4.38
version: 0.4.36
appVersion: latest
annotations:
category: Productivity

View File

@@ -1,26 +0,0 @@
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_docfetching.replicaCount) 0) }}
apiVersion: v1
kind: Service
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-docfetching-metrics
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- if .Values.celery_worker_docfetching.deploymentLabels }}
{{- toYaml .Values.celery_worker_docfetching.deploymentLabels | nindent 4 }}
{{- end }}
metrics: "true"
spec:
type: ClusterIP
ports:
- port: 9092
targetPort: metrics
protocol: TCP
name: metrics
selector:
{{- include "onyx.selectorLabels" . | nindent 4 }}
{{- if .Values.celery_worker_docfetching.deploymentLabels }}
{{- toYaml .Values.celery_worker_docfetching.deploymentLabels | nindent 4 }}
{{- end }}
{{- end }}

View File

@@ -73,10 +73,6 @@ spec:
"-Q",
"connector_doc_fetching",
]
ports:
- name: metrics
containerPort: 9092
protocol: TCP
resources:
{{- toYaml .Values.celery_worker_docfetching.resources | nindent 12 }}
envFrom:

View File

@@ -1,26 +0,0 @@
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_docprocessing.replicaCount) 0) }}
apiVersion: v1
kind: Service
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-docprocessing-metrics
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- if .Values.celery_worker_docprocessing.deploymentLabels }}
{{- toYaml .Values.celery_worker_docprocessing.deploymentLabels | nindent 4 }}
{{- end }}
metrics: "true"
spec:
type: ClusterIP
ports:
- port: 9093
targetPort: metrics
protocol: TCP
name: metrics
selector:
{{- include "onyx.selectorLabels" . | nindent 4 }}
{{- if .Values.celery_worker_docprocessing.deploymentLabels }}
{{- toYaml .Values.celery_worker_docprocessing.deploymentLabels | nindent 4 }}
{{- end }}
{{- end }}

View File

@@ -73,10 +73,6 @@ spec:
"-Q",
"docprocessing",
]
ports:
- name: metrics
containerPort: 9093
protocol: TCP
resources:
{{- toYaml .Values.celery_worker_docprocessing.resources | nindent 12 }}
envFrom:

View File

@@ -1,26 +0,0 @@
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_monitoring.replicaCount) 0) }}
apiVersion: v1
kind: Service
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-monitoring-metrics
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- if .Values.celery_worker_monitoring.deploymentLabels }}
{{- toYaml .Values.celery_worker_monitoring.deploymentLabels | nindent 4 }}
{{- end }}
metrics: "true"
spec:
type: ClusterIP
ports:
- port: 9096
targetPort: metrics
protocol: TCP
name: metrics
selector:
{{- include "onyx.selectorLabels" . | nindent 4 }}
{{- if .Values.celery_worker_monitoring.deploymentLabels }}
{{- toYaml .Values.celery_worker_monitoring.deploymentLabels | nindent 4 }}
{{- end }}
{{- end }}

View File

@@ -70,10 +70,6 @@ spec:
"-Q",
"monitoring",
]
ports:
- name: metrics
containerPort: 9096
protocol: TCP
resources:
{{- toYaml .Values.celery_worker_monitoring.resources | nindent 12 }}
envFrom:

View File

@@ -63,22 +63,6 @@ data:
}
{{- end }}
location ~ ^/scim(/.*)?$ {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_buffering off;
proxy_redirect off;
# timeout settings
proxy_connect_timeout {{ .Values.nginx.timeouts.connect }}s;
proxy_send_timeout {{ .Values.nginx.timeouts.send }}s;
proxy_read_timeout {{ .Values.nginx.timeouts.read }}s;
proxy_pass http://api_server;
}
location ~ ^/(api|openapi\.json)(/.*)?$ {
rewrite ^/api(/.*)$ $1 break;
proxy_set_header X-Real-IP $remote_addr;

View File

@@ -282,7 +282,7 @@ nginx:
# The ingress-nginx subchart doesn't auto-detect our custom ConfigMap changes.
# Workaround: Helm upgrade will restart if the following annotation value changes.
podAnnotations:
onyx.app/nginx-config-version: "3"
onyx.app/nginx-config-version: "2"
# Propagate DOMAIN into nginx so server_name continues to use the same env var
extraEnvs:
@@ -1285,5 +1285,11 @@ configMap:
DOMAIN: "localhost"
# Chat Configs
HARD_DELETE_CHATS: ""
MAX_ALLOWED_UPLOAD_SIZE_MB: ""
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB: ""
# User File Upload Configuration
# Skip the token count threshold check (100,000 tokens) for uploaded files
# For self-hosted: set to true to skip for all users
SKIP_USERFILE_THRESHOLD: ""
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
SKIP_USERFILE_THRESHOLD_TENANT_IDS: ""
# Maximum user upload file size in MB for chat/projects uploads
USER_FILE_MAX_UPLOAD_SIZE_MB: ""

Some files were not shown because too many files have changed in this diff Show More