Compare commits

..

22 Commits

Author SHA1 Message Date
Justin Tahara
4fa3deedb9 fix(image): Cap Uploaded File Image Count (#10298) 2026-04-16 21:34:18 -07:00
Nikolas Garza
593ccbcc66 fix(scim): add advisory lock to prevent seat limit race condition (#10048) to release v3.1 (#10067) 2026-04-10 12:43:08 -07:00
Nikolas Garza
9910487f37 feat(federated): full thread replies + direct URL fetch in Slack search (#9940) to release v3.1 (#10051) 2026-04-09 18:24:08 -07:00
Justin Tahara
d158639844 fix(llm): Azure custom model support + Mistral tool call message ordering (#9729) 2026-04-09 13:58:30 -07:00
Jamison Lahman
6d2bd97412 fix: Custom LLM Provider requires a Provider Name (#10000) 2026-04-08 10:55:58 -07:00
Jamison Lahman
3d48b6a63e fix: LM Studio API key field mismatch (#9991) to release v3.1 (#9992)
Co-authored-by: Raunak Bhagat <r@rabh.io>
2026-04-08 10:21:38 -07:00
Jamison Lahman
2a7b7c9187 fix: onboarding LLM Provider configuration fixes (#9972) to release v3.1 (#9989)
Co-authored-by: Raunak Bhagat <r@rabh.io>
2026-04-08 10:07:01 -07:00
github-actions[bot]
c348d1855d feat: generic OpenAI Compatible LLM Provider setup (#9968) to release v3.1 (#9975)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-04-07 13:20:41 -07:00
github-actions[bot]
b4579a1365 fix(indexing, powerpoint files): Patch markitdown _convert_chart_to_markdown to no-op (#9970) to release v3.1 (#9979)
Co-authored-by: acaprau <48705707+acaprau@users.noreply.github.com>
2026-04-07 13:02:34 -07:00
Justin Tahara
893c094aed fix(groups): Global Curator Permissions (#9974) 2026-04-07 13:01:38 -07:00
Wenxi
f8a55712d2 fix: set correct ee mode for mcp server (#9933) 2026-04-07 09:13:23 -07:00
github-actions[bot]
591afd4fb1 fix: stop falsely rejecting owner-password-only PDFs as protected (#9953) to release v3.1 (#9962)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-06 21:23:13 -07:00
github-actions[bot]
9328070dc0 fix(federated): prevent masked credentials from corrupting stored secrets (#9868) to release v3.1 (#9928) 2026-04-05 16:18:14 -07:00
Jamison Lahman
6163521126 Revert "chore(deps): bump litellm from 1.81.6 to 1.83.0 (#9898) to release v3.1" (#9910) 2026-04-03 18:32:10 -07:00
Jamison Lahman
d42c5616b0 chore(deps): bump litellm from 1.81.6 to 1.83.0 (#9898) to release v3.1 (#9902)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-03 16:18:19 -07:00
Justin Tahara
aeb4fdd6c1 fix(db): remove unnecessary selectinload(User.memories) from auth paths (#9838) 2026-04-01 17:16:57 -07:00
Nikolas Garza
c673959714 fix(celery): use broker connection pool to prevent Redis connection leak (#9682) 2026-03-31 18:40:07 -07:00
Justin Tahara
cb36562802 fix(perf): optimize chat sessions query to prevent DB cascading failures (#9802) 2026-03-31 18:37:38 -07:00
Jessica Singh
efc424bf3e feat(voice): VAD auto-stop only when auto-send is enabled (#9809) 2026-03-31 17:46:28 -07:00
Evan Lohn
e0baaf85e5 fix: Anthropic litellm thinking workaround (#9713) 2026-03-27 14:12:15 -07:00
github-actions[bot]
a0ffd47e2c chore(playwright): deflake settings_pages.spec.ts (#9684) to release v3.1 (#9702)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 09:08:51 -07:00
Jamison Lahman
d0396a1337 fix(fe): Popover content doesnt overflow on small screens (#9612) to release v3.1 (#9700) 2026-03-27 08:43:53 -07:00
155 changed files with 4452 additions and 7185 deletions

View File

@@ -47,8 +47,7 @@ jobs:
done
- name: Publish Helm charts to gh-pages
# NOTE: HEAD of https://github.com/stefanprodan/helm-gh-pages/pull/43
uses: stefanprodan/helm-gh-pages@ad32ad3b8720abfeaac83532fd1e9bdfca5bbe27 # zizmor: ignore[impostor-commit]
uses: stefanprodan/helm-gh-pages@0ad2bb377311d61ac04ad9eb6f252fb68e207260 # ratchet:stefanprodan/helm-gh-pages@v1.7.0
with:
token: ${{ secrets.GITHUB_TOKEN }}
charts_dir: deployment/helm/charts

View File

@@ -13,6 +13,15 @@ jobs:
permissions:
id-token: write
timeout-minutes: 10
strategy:
matrix:
os-arch:
- { goos: "linux", goarch: "amd64" }
- { goos: "linux", goarch: "arm64" }
- { goos: "windows", goarch: "amd64" }
- { goos: "windows", goarch: "arm64" }
- { goos: "darwin", goarch: "amd64" }
- { goos: "darwin", goarch: "arm64" }
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
@@ -22,11 +31,9 @@ jobs:
enable-cache: false
version: "0.9.9"
- run: |
for goos in linux windows darwin; do
for goarch in amd64 arm64; do
GOOS="$goos" GOARCH="$goarch" uv build --wheel
done
done
GOOS="${{ matrix.os-arch.goos }}" \
GOARCH="${{ matrix.os-arch.goarch }}" \
uv build --wheel
working-directory: cli
- run: uv publish
working-directory: cli

View File

@@ -24,16 +24,6 @@ When hardcoding a boolean variable to a constant value, remove the variable enti
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
## Nginx Routing — New Backend Routes
Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
- `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt)
Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`.
## Full vs Lite Deployments
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.

View File

@@ -122,7 +122,7 @@ repos:
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
hooks:
- id: golangci-lint
language_version: "1.26.1"
language_version: "1.26.0"
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
- repo: https://github.com/astral-sh/ruff-pre-commit

View File

@@ -35,7 +35,7 @@ Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep
> [!TIP]
> Run Onyx with one command (or see deployment section below):
> ```
> curl -fsSL https://onyx.app/install_onyx.sh | bash
> curl -fsSL https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.sh > install.sh && chmod +x install.sh && ./install.sh
> ```
****

View File

@@ -474,8 +474,6 @@ def connector_permission_sync_generator_task(
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_connector=True,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(

View File

@@ -8,7 +8,6 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import HierarchyNode
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call
@@ -106,11 +105,9 @@ def _get_slack_document_access(
slack_connector: SlackConnector,
channel_permissions: dict[str, ExternalAccess], # noqa: ARG001
callback: IndexingHeartbeatInterface | None,
indexing_start: SecondsSinceUnixEpoch | None = None,
) -> Generator[DocExternalAccess, None, None]:
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
callback=callback,
start=indexing_start,
callback=callback
)
for doc_metadata_batch in slim_doc_generator:
@@ -183,15 +180,9 @@ def slack_doc_sync(
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.set_credentials_provider(provider)
indexing_start_ts: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)
yield from _get_slack_document_access(
slack_connector=slack_connector,
slack_connector,
channel_permissions=channel_permissions,
callback=callback,
indexing_start=indexing_start_ts,
)

View File

@@ -6,7 +6,6 @@ from onyx.access.models import ElementExternalAccess
from onyx.access.models import ExternalAccess
from onyx.access.models import NodeExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
@@ -41,19 +40,10 @@ def generic_doc_sync(
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
indexing_start: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)
newly_fetched_doc_ids: set[str] = set()
logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(
start=indexing_start,
callback=callback,
):
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
if callback:

View File

@@ -11,6 +11,8 @@ require a valid SCIM bearer token.
from __future__ import annotations
import hashlib
import struct
from uuid import UUID
from fastapi import APIRouter
@@ -22,6 +24,7 @@ from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -59,9 +62,25 @@ from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# Group names reserved for system default groups (seeded by migration).
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
# Namespace prefix for the seat-allocation advisory lock. Hashed together
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
# never block each other) and cannot collide with unrelated advisory locks.
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
return struct.unpack("q", digest[:8])[0]
class ScimJSONResponse(JSONResponse):
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
@@ -200,12 +219,37 @@ def _apply_exclusions(
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
"""Return an error message if seat limit is reached, else None.
Acquires a transaction-scoped advisory lock so that concurrent
SCIM requests are serialized. IdPs like Okta send provisioning
requests in parallel batches — without serialization the check is
vulnerable to a TOCTOU race where N concurrent requests each see
"seats available", all insert, and the tenant ends up over its
seat limit.
The lock is held until the caller's next COMMIT or ROLLBACK, which
means the seat count cannot change between the check here and the
subsequent INSERT/UPDATE. Each call site in this module follows
the pattern: _check_seat_availability → write → dal.commit()
(which releases the lock for the next waiting request).
"""
check_fn = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)
if check_fn is None:
return None
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
# The lock id is derived from the tenant so unrelated tenants never block
# each other, and from a namespace string so it cannot collide with
# unrelated advisory locks elsewhere in the codebase.
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
dal.session.execute(
text("SELECT pg_advisory_xact_lock(:lock_id)"),
{"lock_id": lock_id},
)
result = check_fn(dal.session, seats_needed=1)
if not result.available:
return result.error_message or "Seat limit reached"

View File

@@ -44,31 +44,6 @@ SEND_USER_METADATA_TO_LLM_PROVIDER = (
# User Facing Features Configs
#####
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
# Hard ceiling for the admin-configurable file upload size (in MB).
# Self-hosted customers can raise or lower this via the environment variable.
_raw_max_upload_size_mb = int(os.environ.get("MAX_ALLOWED_UPLOAD_SIZE_MB", "250"))
if _raw_max_upload_size_mb < 0:
logger.warning(
"MAX_ALLOWED_UPLOAD_SIZE_MB=%d is negative; falling back to 250",
_raw_max_upload_size_mb,
)
_raw_max_upload_size_mb = 250
MAX_ALLOWED_UPLOAD_SIZE_MB = _raw_max_upload_size_mb
# Default fallback for the per-user file upload size limit (in MB) when no
# admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at
# runtime so this never silently exceeds the hard ceiling.
_raw_default_upload_size_mb = int(
os.environ.get("DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", "100")
)
if _raw_default_upload_size_mb < 0:
logger.warning(
"DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=%d is negative; falling back to 100",
_raw_default_upload_size_mb,
)
_raw_default_upload_size_mb = 100
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _raw_default_upload_size_mb
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
) # 1 day
@@ -86,6 +61,17 @@ CACHE_BACKEND = CacheBackendType(
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
)
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
# Defaults to 100k tokens (or 10M when vector DB is disabled).
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
FILE_TOKEN_COUNT_THRESHOLD = int(
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
)
# Maximum upload size for a single user file (chat/projects) in MB.
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
# If set to true, will show extra/uncommon connectors in the "Other" category
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
@@ -830,6 +816,29 @@ MAX_FILE_SIZE_BYTES = int(
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
) # 2GB in bytes
# Maximum embedded images allowed in a single file. PDFs (and other formats)
# with thousands of embedded images can OOM the user-file-processing worker
# because every image is decoded with PIL and then sent to the vision LLM.
# Enforced both at upload time (rejects the file) and during extraction
# (defense-in-depth: caps the number of images materialized).
#
# Clamped to >= 0; a negative env value would turn upload validation into
# always-fail and extraction into always-stop, which is never desired. 0
# disables image extraction entirely, which is a valid (if aggressive) setting.
MAX_EMBEDDED_IMAGES_PER_FILE = max(
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_FILE") or 500)
)
# Maximum embedded images allowed across all files in a single upload batch.
# Protects against the scenario where a user uploads many files that each
# fall under MAX_EMBEDDED_IMAGES_PER_FILE but aggregate to enough work
# (serial-ish celery fan-out plus per-image vision-LLM calls) to OOM the
# worker under concurrency or run up surprise latency/cost. Also clamped
# to >= 0.
MAX_EMBEDDED_IMAGES_PER_UPLOAD = max(
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_UPLOAD") or 1000)
)
# Use document summary for contextual rag
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
# Use chunk summary for contextual rag

View File

@@ -12,6 +12,11 @@ SLACK_USER_TOKEN_PREFIX = "xoxp-"
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
# The mask_string() function in encryption.py uses "•" (U+2022 BULLET) to mask secrets.
MASK_CREDENTIAL_CHAR = "\u2022"
# Pattern produced by mask_string for strings >= 14 chars: "abcd...wxyz" (exactly 11 chars)
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.{3}.{4}$")
SOURCE_TYPE = "source_type"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed

View File

@@ -890,8 +890,8 @@ class ConfluenceConnector(
def _retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
@@ -915,8 +915,8 @@ class ConfluenceConnector(
self.confluence_client, doc_id, restrictions, ancestors
) or space_level_access_info.get(page_space_key)
# Query pages (with optional time filtering for indexing_start)
page_query = self._construct_page_cql_query(start, end)
# Query pages
page_query = self.base_cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
@@ -950,9 +950,7 @@ class ConfluenceConnector(
# Query attachments for each page
page_hierarchy_node_yielded = False
attachment_query = self._construct_attachment_query(
_get_page_id(page), start, end
)
attachment_query = self._construct_attachment_query(_get_page_id(page))
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_query,
expand=restrictions_expand,

View File

@@ -1765,11 +1765,7 @@ class SharepointConnector(
checkpoint.current_drive_delta_next_link = None
checkpoint.seen_document_ids.clear()
def _fetch_slim_documents_from_sharepoint(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateSlimDocumentOutput:
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
site_descriptors = self._filter_excluded_sites(
self.site_descriptors or self.fetch_sites()
)
@@ -1790,9 +1786,7 @@ class SharepointConnector(
# Process site documents if flag is True
if self.include_site_documents:
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
site_descriptor=site_descriptor,
start=start,
end=end,
site_descriptor=site_descriptor
):
if self._is_driveitem_excluded(driveitem):
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
@@ -1847,9 +1841,7 @@ class SharepointConnector(
# Process site pages if flag is True
if self.include_site_pages:
site_pages = self._fetch_site_pages(
site_descriptor, start=start, end=end
)
site_pages = self._fetch_site_pages(site_descriptor)
for site_page in site_pages:
logger.debug(
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
@@ -2573,22 +2565,12 @@ class SharepointConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
) -> GenerateSlimDocumentOutput:
start_dt = (
datetime.fromtimestamp(start, tz=timezone.utc)
if start is not None
else None
)
end_dt = (
datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
)
yield from self._fetch_slim_documents_from_sharepoint(
start=start_dt,
end=end_dt,
)
yield from self._fetch_slim_documents_from_sharepoint()
if __name__ == "__main__":

View File

@@ -516,8 +516,6 @@ def _get_all_doc_ids(
] = default_msg_filter,
callback: IndexingHeartbeatInterface | None = None,
workspace_url: str | None = None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
"""
Get all document ids in the workspace, channel by channel
@@ -548,8 +546,6 @@ def _get_all_doc_ids(
client=client,
channel=channel,
callback=callback,
oldest=str(start) if start else None, # 0.0 -> None intentionally
latest=str(end) if end is not None else None,
)
for message_batch in channel_message_batches:
@@ -851,8 +847,8 @@ class SlackConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
if self.client is None:
@@ -865,8 +861,6 @@ class SlackConnector(
msg_filter_func=self.msg_filter_func,
callback=callback,
workspace_url=self._workspace_url,
start=start,
end=end,
)
def _load_from_checkpoint(

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
from datetime import datetime
from typing import TypedDict
@@ -6,6 +7,14 @@ from pydantic import BaseModel
from onyx.onyxbot.slack.models import ChannelType
@dataclass(frozen=True)
class DirectThreadFetch:
"""Request to fetch a Slack thread directly by channel and timestamp."""
channel_id: str
thread_ts: str
class ChannelMetadata(TypedDict):
"""Type definition for cached channel metadata."""

View File

@@ -19,6 +19,7 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import TextSection
from onyx.context.search.federated.models import ChannelMetadata
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.federated.models import SlackMessage
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
@@ -49,7 +50,6 @@ from onyx.server.federated.models import FederatedConnectorDetail
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
logger = setup_logger()
@@ -58,7 +58,6 @@ HIGHLIGHT_END_CHAR = "\ue001"
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
@@ -421,6 +420,94 @@ class SlackQueryResult(BaseModel):
filtered_channels: list[str] # Channels filtered out during this query
def _fetch_thread_from_url(
thread_fetch: DirectThreadFetch,
access_token: str,
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
"""Fetch a thread directly from a Slack URL via conversations.replies."""
channel_id = thread_fetch.channel_id
thread_ts = thread_fetch.thread_ts
slack_client = WebClient(token=access_token)
try:
response = slack_client.conversations_replies(
channel=channel_id,
ts=thread_ts,
)
response.validate()
messages: list[dict[str, Any]] = response.get("messages", [])
except SlackApiError as e:
logger.warning(
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
)
return SlackQueryResult(messages=[], filtered_channels=[])
if not messages:
logger.warning(
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
)
return SlackQueryResult(messages=[], filtered_channels=[])
# Build thread text from all messages
thread_text = _build_thread_text(messages, access_token, None, slack_client)
# Get channel name from metadata cache or API
channel_name = "unknown"
if channel_metadata_dict and channel_id in channel_metadata_dict:
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
else:
try:
ch_response = slack_client.conversations_info(channel=channel_id)
ch_response.validate()
channel_info: dict[str, Any] = ch_response.get("channel", {})
channel_name = channel_info.get("name", "unknown")
except SlackApiError:
pass
# Build the SlackMessage
parent_msg = messages[0]
message_ts = parent_msg.get("ts", thread_ts)
username = parent_msg.get("user", "unknown_user")
parent_text = parent_msg.get("text", "")
snippet = (
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
).replace("\n", " ")
doc_time = datetime.fromtimestamp(float(message_ts))
decay_factor = DOC_TIME_DECAY
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
permalink = (
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
)
slack_message = SlackMessage(
document_id=f"{channel_id}_{message_ts}",
channel_id=channel_id,
message_id=message_ts,
thread_id=None, # Prevent double-enrichment in thread context fetch
link=permalink,
metadata={
"channel": channel_name,
"time": doc_time.isoformat(),
},
timestamp=doc_time,
recency_bias=recency_bias,
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
text=thread_text,
highlighted_texts=set(),
slack_score=100000.0, # High priority — user explicitly asked for this thread
)
logger.info(
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
)
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
def query_slack(
query_string: str,
access_token: str,
@@ -432,7 +519,6 @@ def query_slack(
available_channels: list[str] | None = None,
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
# Check if query has channel override (user specified channels in query)
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
@@ -662,7 +748,6 @@ def _fetch_thread_context(
"""
channel_id = message.channel_id
thread_id = message.thread_id
message_id = message.message_id
# If not a thread, return original text as success
if thread_id is None:
@@ -695,62 +780,37 @@ def _fetch_thread_context(
if len(messages) <= 1:
return ThreadContextResult.success(message.text)
# Build thread text from thread starter + context window around matched message
thread_text = _build_thread_text(
messages, message_id, thread_id, access_token, team_id, slack_client
)
# Build thread text from thread starter + all replies
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
return ThreadContextResult.success(thread_text)
def _build_thread_text(
messages: list[dict[str, Any]],
message_id: str,
thread_id: str,
access_token: str,
team_id: str | None,
slack_client: WebClient,
) -> str:
"""Build the thread text from messages."""
"""Build thread text including all replies.
Includes the thread parent message followed by all replies in order.
"""
msg_text = messages[0].get("text", "")
msg_sender = messages[0].get("user", "")
thread_text = f"<@{msg_sender}>: {msg_text}"
# All messages after index 0 are replies
replies = messages[1:]
if not replies:
return thread_text
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
thread_text += "\n\nReplies:"
if thread_id == message_id:
message_id_idx = 0
else:
message_id_idx = next(
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
)
if not message_id_idx:
return thread_text
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
if start_idx > 1:
thread_text += "\n..."
for i in range(start_idx, message_id_idx):
msg_text = messages[i].get("text", "")
msg_sender = messages[i].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
msg_text = messages[message_id_idx].get("text", "")
msg_sender = messages[message_id_idx].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# Add following replies
len_replies = 0
for msg in messages[message_id_idx + 1 :]:
for msg in replies:
msg_text = msg.get("text", "")
msg_sender = msg.get("user", "")
reply = f"\n\n<@{msg_sender}>: {msg_text}"
thread_text += reply
len_replies += len(reply)
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
thread_text += "\n..."
break
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# Replace user IDs with names using cached lookups
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
@@ -976,7 +1036,16 @@ def slack_retrieval(
# Query slack with entity filtering
llm = get_default_llm()
query_strings = build_slack_queries(query, llm, entities, available_channels)
query_items = build_slack_queries(query, llm, entities, available_channels)
# Partition into direct thread fetches and search query strings
direct_fetches: list[DirectThreadFetch] = []
query_strings: list[str] = []
for item in query_items:
if isinstance(item, DirectThreadFetch):
direct_fetches.append(item)
else:
query_strings.append(item)
# Determine filtering based on entities OR context (bot)
include_dm = False
@@ -993,8 +1062,16 @@ def slack_retrieval(
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
)
# Build search tasks
search_tasks = [
# Build search tasks — direct thread fetches + keyword searches
search_tasks: list[tuple] = [
(
_fetch_thread_from_url,
(fetch, access_token, channel_metadata_dict),
)
for fetch in direct_fetches
]
search_tasks.extend(
(
query_slack,
(
@@ -1010,7 +1087,7 @@ def slack_retrieval(
),
)
for query_string in query_strings
]
)
# If include_dm is True AND we're not already searching all channels,
# add additional searches without channel filters.

View File

@@ -10,6 +10,7 @@ from pydantic import ValidationError
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
from onyx.context.search.federated.models import ChannelMetadata
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.models import ChunkIndexRequest
from onyx.federated_connectors.slack.models import SlackEntities
from onyx.llm.interfaces import LLM
@@ -638,12 +639,38 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
return [query_text]
SLACK_URL_PATTERN = re.compile(
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
)
def extract_slack_message_urls(
query_text: str,
) -> list[tuple[str, str]]:
"""Extract Slack message URLs from query text.
Parses URLs like:
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
Returns list of (channel_id, thread_ts) tuples.
The 16-digit timestamp is converted to Slack ts format (with dot).
"""
results = []
for match in SLACK_URL_PATTERN.finditer(query_text):
channel_id = match.group(1)
raw_ts = match.group(2)
# Convert p1775491616524769 -> 1775491616.524769
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
results.append((channel_id, thread_ts))
return results
def build_slack_queries(
query: ChunkIndexRequest,
llm: LLM,
entities: dict[str, Any] | None = None,
available_channels: list[str] | None = None,
) -> list[str]:
) -> list[str | DirectThreadFetch]:
"""Build Slack query strings with date filtering and query expansion."""
default_search_days = 30
if entities:
@@ -668,6 +695,15 @@ def build_slack_queries(
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
# Check for Slack message URLs — if found, add direct fetch requests
url_fetches: list[DirectThreadFetch] = []
slack_urls = extract_slack_message_urls(query.query)
for channel_id, thread_ts in slack_urls:
url_fetches.append(
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
)
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
# ALWAYS extract channel references from the query (not just for recency queries)
channel_references = extract_channel_references_from_query(query.query)
@@ -684,7 +720,9 @@ def build_slack_queries(
# If valid channels detected, use ONLY those channels with NO keywords
# Return query with ONLY time filter + channel filter (no keywords)
return [build_channel_override_query(channel_references, time_filter)]
return url_fetches + [
build_channel_override_query(channel_references, time_filter)
]
except ValueError as e:
# If validation fails, log the error and continue with normal flow
logger.warning(f"Channel reference validation failed: {e}")
@@ -702,7 +740,8 @@ def build_slack_queries(
rephrased_queries = expand_query_with_llm(query.query, llm)
# Build final query strings with time filters
return [
search_queries = [
rephrased_query.strip() + time_filter
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
]
return url_fetches + search_queries

View File

@@ -4,7 +4,6 @@ from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.api_key import ApiKeyDescriptor
@@ -55,7 +54,6 @@ async def fetch_user_for_api_key(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
.options(selectinload(User.memories))
)

View File

@@ -13,7 +13,6 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@@ -98,11 +97,6 @@ async def get_user_count(only_admin_users: bool = False) -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def _get_user(self, statement: Select) -> UP | None:
statement = statement.options(selectinload(User.memories))
results = await self.session.execute(statement)
return results.unique().scalar_one_or_none()
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -8,7 +8,6 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
@@ -132,32 +131,47 @@ def get_chat_sessions_by_user(
if before is not None:
stmt = stmt.where(ChatSession.time_updated < before)
if limit:
stmt = stmt.limit(limit)
if project_id is not None:
stmt = stmt.where(ChatSession.project_id == project_id)
elif only_non_project_chats:
stmt = stmt.where(ChatSession.project_id.is_(None))
if not include_failed_chats:
non_system_message_exists_subq = (
exists()
.where(ChatMessage.chat_session_id == ChatSession.id)
.where(ChatMessage.message_type != MessageType.SYSTEM)
.correlate(ChatSession)
)
# Leeway for newly created chats that don't have messages yet
time = datetime.now(timezone.utc) - timedelta(minutes=5)
recently_created = ChatSession.time_created >= time
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
# When filtering out failed chats, we apply the limit in Python after
# filtering rather than in SQL, since the post-filter may remove rows.
if limit and include_failed_chats:
stmt = stmt.limit(limit)
result = db_session.execute(stmt)
chat_sessions = result.scalars().all()
chat_sessions = list(result.scalars().all())
return list(chat_sessions)
if not include_failed_chats and chat_sessions:
# Filter out "failed" sessions (those with only SYSTEM messages)
# using a separate efficient query instead of a correlated EXISTS
# subquery, which causes full sequential scans of chat_message.
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
if session_ids:
valid_session_ids_stmt = (
select(ChatMessage.chat_session_id)
.where(ChatMessage.chat_session_id.in_(session_ids))
.where(ChatMessage.message_type != MessageType.SYSTEM)
.distinct()
)
valid_session_ids = set(
db_session.execute(valid_session_ids_stmt).scalars().all()
)
chat_sessions = [
cs
for cs in chat_sessions
if cs.time_created >= leeway or cs.id in valid_session_ids
]
if limit:
chat_sessions = chat_sessions[:limit]
return chat_sessions
def delete_orphaned_search_docs(db_session: Session) -> None:

View File

@@ -8,6 +8,8 @@ from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.constants import FederatedConnectorSource
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
from onyx.configs.constants import MASK_CREDENTIAL_LONG_RE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector
@@ -45,6 +47,23 @@ def fetch_all_federated_connectors_parallel() -> list[FederatedConnector]:
return fetch_all_federated_connectors(db_session)
def _reject_masked_credentials(credentials: dict[str, Any]) -> None:
"""Raise if any credential string value contains mask placeholder characters.
mask_string() has two output formats:
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
Both must be rejected.
"""
for key, val in credentials.items():
if isinstance(val, str) and (
MASK_CREDENTIAL_CHAR in val or MASK_CREDENTIAL_LONG_RE.match(val)
):
raise ValueError(
f"Credential field '{key}' contains masked placeholder characters. Please provide the actual credential value."
)
def validate_federated_connector_credentials(
source: FederatedConnectorSource,
credentials: dict[str, Any],
@@ -66,6 +85,8 @@ def create_federated_connector(
config: dict[str, Any] | None = None,
) -> FederatedConnector:
"""Create a new federated connector with credential and config validation."""
_reject_masked_credentials(credentials)
# Validate credentials before creating
if not validate_federated_connector_credentials(source, credentials):
raise ValueError(
@@ -277,6 +298,8 @@ def update_federated_connector(
)
if credentials is not None:
_reject_masked_credentials(credentials)
# Validate credentials before updating
if not validate_federated_connector_credentials(
federated_connector.source, credentials

View File

@@ -8,7 +8,6 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.pat import build_displayable_pat
@@ -47,7 +46,6 @@ async def fetch_user_for_pat(
(PersonalAccessToken.expires_at.is_(None))
| (PersonalAccessToken.expires_at > now)
)
.options(selectinload(User.memories))
)
if not user:
return None

View File

@@ -229,7 +229,9 @@ def get_memories_for_user(
user_id: UUID,
db_session: Session,
) -> Sequence[Memory]:
return db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all()
return db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.desc())
).all()
def update_user_pinned_assistants(

View File

@@ -5,7 +5,6 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
out against a nonexistent Vespa/OpenSearch instance.
"""
from collections.abc import Iterable
from typing import Any
from onyx.context.search.models import IndexFilters
@@ -67,7 +66,7 @@ class DisabledDocumentIndex(DocumentIndex):
# ------------------------------------------------------------------
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
index_batch_params: IndexBatchParams, # noqa: ARG002
) -> set[DocumentInsertionRecord]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)

View File

@@ -1,5 +1,4 @@
import abc
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@@ -207,7 +206,7 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[DocumentInsertionRecord]:
"""
@@ -227,8 +226,8 @@ class Indexable(abc.ABC):
it is done automatically outside of this code.
Parameters:
- chunks: Document chunks with all of the information needed for
indexing to the document index.
- chunks: Document chunks with all of the information needed for indexing to the document
index.
- tenant_id: The tenant id of the user whose chunks are being indexed
- large_chunks_enabled: Whether large chunks are enabled

View File

@@ -1,5 +1,4 @@
import abc
from collections.abc import Iterable
from typing import Self
from pydantic import BaseModel
@@ -210,10 +209,10 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes an iterable of document chunks into the document index.
"""Indexes a list of document chunks into the document index.
This is often a batch operation including chunks from multiple
documents.

View File

@@ -1,5 +1,5 @@
import json
from collections.abc import Iterable
from collections import defaultdict
from typing import Any
import httpx
@@ -351,7 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
@@ -647,10 +647,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata, # noqa: ARG002
) -> list[DocumentInsertionRecord]:
"""Indexes an iterable of document chunks into the document index.
"""Indexes a list of document chunks into the document index.
Groups chunks by document ID and for each document, deletes existing
chunks and indexes the new chunks in bulk.
@@ -673,34 +673,29 @@ class OpenSearchDocumentIndex(DocumentIndex):
document is newly indexed or had already existed and was just
updated.
"""
total_chunks = sum(
cc.new_chunk_cnt
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
# Group chunks by document ID.
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
list
)
for chunk in chunks:
doc_id_to_chunks[chunk.source_document.id].append(chunk)
logger.debug(
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
f"documents for index {self._index_name}."
)
document_indexing_results: list[DocumentInsertionRecord] = []
deleted_doc_ids: set[str] = set()
# Buffer chunks per document as they arrive from the iterable.
# When the document ID changes flush the buffered chunks.
current_doc_id: str | None = None
current_chunks: list[DocMetadataAwareIndexChunk] = []
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
assert len(doc_chunks) > 0, "doc_chunks is empty"
# Try to index per-document.
for _, chunks in doc_id_to_chunks.items():
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
# Since we are doing this in batches, an error occurring midway
# can result in a state where chunks are deleted and not all the
# new chunks have been indexed.
# Do this before deleting existing chunks to reduce the amount of
# time the document index has no content for a given document, and
# to reduce the chance of entering a state where we delete chunks,
# then some error happens, and never successfully index new chunks.
chunk_batch: list[DocumentChunk] = [
_convert_onyx_chunk_to_opensearch_document(chunk)
for chunk in doc_chunks
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
]
onyx_document: Document = doc_chunks[0].source_document
onyx_document: Document = chunks[0].source_document
# First delete the doc's chunks from the index. This is so that
# there are no dangling chunks in the index, in the event that the
# new document's content contains fewer chunks than the previous
@@ -709,40 +704,22 @@ class OpenSearchDocumentIndex(DocumentIndex):
# if the chunk count has actually decreased. This assumes that
# overlapping chunks are perfectly overwritten. If we can't
# guarantee that then we need the code as-is.
if onyx_document.id not in deleted_doc_ids:
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
deleted_doc_ids.add(onyx_document.id)
# If we see that chunks were deleted we assume the doc already
# existed. We record the result before bulk_index_documents
# runs. If indexing raises, this entire result list is discarded
# by the caller's retry logic, so early recording is safe.
document_indexing_results.append(
DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
)
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
# If we see that chunks were deleted we assume the doc already
# existed.
document_insertion_record = DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
for chunk in chunks:
doc_id = chunk.source_document.id
if doc_id != current_doc_id:
if current_chunks:
_flush_chunks(current_chunks)
current_doc_id = doc_id
current_chunks = [chunk]
else:
current_chunks.append(chunk)
if current_chunks:
_flush_chunks(current_chunks)
document_indexing_results.append(document_insertion_record)
return document_indexing_results

View File

@@ -6,7 +6,6 @@ import re
import time
import urllib
import zipfile
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
@@ -462,7 +461,7 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""

View File

@@ -1,8 +1,6 @@
import concurrent.futures
import logging
import random
from collections.abc import Generator
from collections.abc import Iterable
from typing import Any
from uuid import UUID
@@ -320,7 +318,7 @@ class VespaDocumentIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
@@ -340,31 +338,22 @@ class VespaDocumentIndex(DocumentIndex):
# Vespa has restrictions on valid characters, yet document IDs come from
# external w.r.t. this class. We need to sanitize them.
#
# Instead of materializing all cleaned chunks upfront, we stream them
# through a generator that cleans IDs and builds the original-ID mapping
# incrementally as chunks flow into Vespa.
def _clean_and_track(
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
id_map: dict[str, str],
seen_ids: set[str],
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
"""Cleans chunk IDs and builds the original-ID mapping
incrementally as chunks flow through, avoiding a separate
materialization pass."""
for chunk in chunks_iter:
original_id = chunk.source_document.id
cleaned = clean_chunk_id_copy(chunk)
cleaned_id = cleaned.source_document.id
# Needed so the final DocumentInsertionRecord returned can have
# the original document ID. cleaned_chunks might not contain IDs
# exactly as callers supplied them.
id_map[cleaned_id] = original_id
seen_ids.add(cleaned_id)
yield cleaned
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
clean_chunk_id_copy(chunk) for chunk in chunks
]
assert len(cleaned_chunks) == len(
chunks
), "Bug: Cleaned chunks and input chunks have different lengths."
new_document_id_to_original_document_id: dict[str, str] = {}
all_cleaned_doc_ids: set[str] = set()
# Needed so the final DocumentInsertionRecord returned can have the
# original document ID. cleaned_chunks might not contain IDs exactly as
# callers supplied them.
new_document_id_to_original_document_id: dict[str, str] = dict()
for i, cleaned_chunk in enumerate(cleaned_chunks):
old_chunk = chunks[i]
new_document_id_to_original_document_id[
cleaned_chunk.source_document.id
] = old_chunk.source_document.id
existing_docs: set[str] = set()
@@ -420,13 +409,7 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
# Insert new Vespa documents, streaming through the cleaning
# pipeline so chunks are never fully materialized.
cleaned_chunks = _clean_and_track(
chunks,
new_document_id_to_original_document_id,
all_cleaned_doc_ids,
)
# Insert new Vespa documents.
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
batch_index_vespa_chunks(
chunks=chunk_batch,
@@ -436,6 +419,10 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
all_cleaned_doc_ids: set[str] = {
chunk.source_document.id for chunk in cleaned_chunks
}
return [
DocumentInsertionRecord(
document_id=new_document_id_to_original_document_id[cleaned_doc_id],

View File

@@ -21,6 +21,7 @@ import chardet
import openpyxl
from PIL import Image
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.constants import ONYX_METADATA_FILENAME
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.file_processing.file_types import OnyxFileExtensions
@@ -44,15 +45,26 @@ KNOWN_OPENPYXL_BUGS = [
"Value must be either numerical or a string containing a wildcard",
"File contains no valid workbook part",
"Unable to read workbook: could not read stylesheet from None",
"Colors must be aRGB hex values",
]
def get_markitdown_converter() -> "MarkItDown":
global _MARKITDOWN_CONVERTER
from markitdown import MarkItDown
if _MARKITDOWN_CONVERTER is None:
from markitdown import MarkItDown
# Patch this function to effectively no-op because we were seeing this
# module take an inordinate amount of time to convert charts to markdown,
# making some powerpoint files with many or complicated charts nearly
# unindexable.
from markitdown.converters._pptx_converter import PptxConverter
setattr(
PptxConverter,
"_convert_chart_to_markdown",
lambda self, chart: "\n\n[chart omitted]\n\n", # noqa: ARG005
)
_MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
return _MARKITDOWN_CONVERTER
@@ -177,6 +189,56 @@ def read_text_file(
return file_content_raw, metadata
def count_pdf_embedded_images(file: IO[Any], cap: int) -> int:
"""Return the number of embedded images in a PDF, short-circuiting at cap+1.
Used to reject PDFs whose image count would OOM the user-file-processing
worker during indexing. Returns a value > cap as a sentinel once the count
exceeds the cap, so callers do not iterate thousands of image objects just
to report a number. Returns 0 if the PDF cannot be parsed.
Owner-password-only PDFs (permission restrictions but no open password) are
counted normally — they decrypt with an empty string. Truly password-locked
PDFs are skipped (return 0) since we can't inspect them; the caller should
ensure the password-protected check runs first.
Always restores the file pointer to its original position before returning.
"""
from pypdf import PdfReader
try:
start_pos = file.tell()
except Exception:
start_pos = None
try:
if start_pos is not None:
file.seek(0)
reader = PdfReader(file)
if reader.is_encrypted:
# Try empty password first (owner-password-only PDFs); give up if that fails.
try:
if reader.decrypt("") == 0:
return 0
except Exception:
return 0
count = 0
for page in reader.pages:
for _ in page.images:
count += 1
if count > cap:
return count
return count
except Exception:
logger.warning("Failed to count embedded images in PDF", exc_info=True)
return 0
finally:
if start_pos is not None:
try:
file.seek(start_pos)
except Exception:
pass
def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
"""
Extract text from a PDF. For embedded images, a more complex approach is needed.
@@ -203,18 +265,26 @@ def read_pdf_file(
try:
pdf_reader = PdfReader(file)
if pdf_reader.is_encrypted and pdf_pass is not None:
if pdf_reader.is_encrypted:
# Try the explicit password first, then fall back to an empty
# string. Owner-password-only PDFs (permission restrictions but
# no open password) decrypt successfully with "".
# See https://github.com/onyx-dot-app/onyx/issues/9754
passwords = [p for p in [pdf_pass, ""] if p is not None]
decrypt_success = False
try:
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
except Exception:
logger.error("Unable to decrypt pdf")
for pw in passwords:
try:
if pdf_reader.decrypt(pw) != 0:
decrypt_success = True
break
except Exception:
pass
if not decrypt_success:
logger.error(
"Encrypted PDF could not be decrypted, returning empty text."
)
return "", metadata, []
elif pdf_reader.is_encrypted:
logger.warning("No Password for an encrypted PDF, returning empty text.")
return "", metadata, []
# Basic PDF metadata
if pdf_reader.metadata is not None:
@@ -232,8 +302,27 @@ def read_pdf_file(
)
if extract_images:
image_cap = MAX_EMBEDDED_IMAGES_PER_FILE
images_processed = 0
cap_reached = False
for page_num, page in enumerate(pdf_reader.pages):
if cap_reached:
break
for image_file_object in page.images:
if images_processed >= image_cap:
# Defense-in-depth backstop. Upload-time validation
# should have rejected files exceeding the cap, but
# we also break here so a single oversized file can
# never pin a worker.
logger.warning(
"PDF embedded image cap reached (%d). "
"Skipping remaining images on page %d and beyond.",
image_cap,
page_num + 1,
)
cap_reached = True
break
image = Image.open(io.BytesIO(image_file_object.data))
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format=image.format)
@@ -246,6 +335,7 @@ def read_pdf_file(
image_callback(img_bytes, image_name)
else:
extracted_images.append((img_bytes, image_name))
images_processed += 1
return text, metadata, extracted_images

View File

@@ -33,8 +33,20 @@ def is_pdf_protected(file: IO[Any]) -> bool:
with preserve_position(file):
reader = PdfReader(file)
if not reader.is_encrypted:
return False
return bool(reader.is_encrypted)
# PDFs with only an owner password (permission restrictions like
# print/copy disabled) use an empty user password — any viewer can open
# them without prompting. decrypt("") returns 0 only when a real user
# password is required. See https://github.com/onyx-dot-app/onyx/issues/9754
try:
return reader.decrypt("") == 0
except Exception:
logger.exception(
"Failed to evaluate PDF encryption; treating as password protected"
)
return True
def is_docx_protected(file: IO[Any]) -> bool:

View File

@@ -29,7 +29,6 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.llm.factory import get_default_llm
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
@@ -174,10 +173,8 @@ class UserFileIndexingAdapter:
[chunk.content for chunk in user_file_chunks]
)
user_file_id_to_raw_text[str(user_file_id)] = combined_content
token_count: int = (
count_tokens(combined_content, llm_tokenizer)
if llm_tokenizer
else 0
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
)
user_file_id_to_token_count[str(user_file_id)] = token_count
else:

View File

@@ -26,6 +26,7 @@ class LlmProviderNames(str, Enum):
MISTRAL = "mistral"
LITELLM_PROXY = "litellm_proxy"
BIFROST = "bifrost"
OPENAI_COMPATIBLE = "openai_compatible"
def __str__(self) -> str:
"""Needed so things like:
@@ -46,6 +47,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
LlmProviderNames.LM_STUDIO,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
]
@@ -64,6 +66,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
LlmProviderNames.LM_STUDIO: "LM Studio",
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
LlmProviderNames.BIFROST: "Bifrost",
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI Compatible",
"groq": "Groq",
"anyscale": "Anyscale",
"deepseek": "DeepSeek",
@@ -116,6 +119,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
LlmProviderNames.AZURE,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
}
# Model family name mappings for display name generation

View File

@@ -175,6 +175,28 @@ def _strip_tool_content_from_messages(
return result
def _fix_tool_user_message_ordering(
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Insert a synthetic assistant message between tool and user messages.
Some models (e.g. Mistral on Azure) require strict message ordering where
a user message cannot immediately follow a tool message. This function
inserts a minimal assistant message to bridge the gap.
"""
if len(messages) < 2:
return messages
result: list[dict[str, Any]] = [messages[0]]
for msg in messages[1:]:
prev_role = result[-1].get("role")
curr_role = msg.get("role")
if prev_role == "tool" and curr_role == "user":
result.append({"role": "assistant", "content": "Noted. Continuing."})
result.append(msg)
return result
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
"""Check if any messages contain tool-related content blocks."""
for msg in messages:
@@ -305,12 +327,19 @@ class LitellmLLM(LLM):
):
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
# Bifrost: OpenAI-compatible proxy that expects model names in
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
# We route through LiteLLM's openai provider with the Bifrost base URL,
# and ensure /v1 is appended.
if model_provider == LlmProviderNames.BIFROST:
# Bifrost and OpenAI-compatible: OpenAI-compatible proxies that send
# model names directly to the endpoint. We route through LiteLLM's
# openai provider with the server's base URL, and ensure /v1 is appended.
if model_provider in (
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
):
self._custom_llm_provider = "openai"
# LiteLLM's OpenAI client requires an api_key to be set.
# Many OpenAI-compatible servers don't need auth, so supply a
# placeholder to prevent LiteLLM from raising AuthenticationError.
if not self._api_key:
model_kwargs.setdefault("api_key", "not-needed")
if self._api_base is not None:
base = self._api_base.rstrip("/")
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
@@ -427,17 +456,20 @@ class LitellmLLM(LLM):
optional_kwargs: dict[str, Any] = {}
# Model name
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
is_openai_compatible_proxy = self._model_provider in (
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
)
model_provider = (
f"{self.config.model_provider}/responses"
if is_openai_model # Uses litellm's completions -> responses bridge
else self.config.model_provider
)
if is_bifrost:
# Bifrost expects model names in provider/model format
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
# so LiteLLM doesn't try to route based on the provider prefix.
if is_openai_compatible_proxy:
# OpenAI-compatible proxies (Bifrost, generic OpenAI-compatible
# servers) expect model names sent directly to their endpoint.
# We use custom_llm_provider="openai" so LiteLLM doesn't try
# to route based on the provider prefix.
model = self.config.deployment_name or self.config.model_name
else:
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
@@ -528,7 +560,10 @@ class LitellmLLM(LLM):
if structured_response_format:
optional_kwargs["response_format"] = structured_response_format
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
if (
not (is_claude_model or is_ollama or is_mistral)
or is_openai_compatible_proxy
):
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
# However, this param breaks Anthropic and Mistral models,
# so it must be conditionally included unless the request is
@@ -576,6 +611,18 @@ class LitellmLLM(LLM):
):
messages = _strip_tool_content_from_messages(messages)
# Some models (e.g. Mistral) reject a user message
# immediately after a tool message. Insert a synthetic
# assistant bridge message to satisfy the ordering
# constraint. Check both the provider and the deployment/
# model name to catch Mistral hosted on Azure.
model_or_deployment = (
self._deployment_name or self._model_version or ""
).lower()
is_mistral_model = is_mistral or "mistral" in model_or_deployment
if is_mistral_model:
messages = _fix_tool_user_message_ordering(messages)
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
# reject requests where tool_choice is explicitly null.
if tools and tool_choice is not None:

View File

@@ -15,6 +15,8 @@ LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
BIFROST_PROVIDER_NAME = "bifrost"
OPENAI_COMPATIBLE_PROVIDER_NAME = "openai_compatible"
# Providers that use optional Bearer auth from custom_config
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,

View File

@@ -19,6 +19,7 @@ from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_COMPATIBLE_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
@@ -51,6 +52,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
OPENAI_COMPATIBLE_PROVIDER_NAME: [], # Dynamic - fetched from OpenAI-compatible API
}
@@ -336,6 +338,7 @@ def get_provider_display_name(provider_name: str) -> str:
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
OPENROUTER_PROVIDER_NAME: "OpenRouter",
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI Compatible",
}
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:

View File

@@ -6,6 +6,7 @@ from onyx.configs.app_configs import MCP_SERVER_ENABLED
from onyx.configs.app_configs import MCP_SERVER_HOST
from onyx.configs.app_configs import MCP_SERVER_PORT
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
logger = setup_logger()
@@ -16,6 +17,7 @@ def main() -> None:
logger.info("MCP server is disabled (MCP_SERVER_ENABLED=false)")
return
set_is_ee_based_on_env_variable()
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
from onyx.mcp_server.api import mcp_app

View File

@@ -175,32 +175,6 @@ def get_tokenizer(
return _check_tokenizer_cache(provider_type, model_name)
# Max characters per encode() call.
_ENCODE_CHUNK_SIZE = 500_000
def count_tokens(
text: str,
tokenizer: BaseTokenizer,
token_limit: int | None = None,
) -> int:
"""Count tokens, chunking the input to avoid tiktoken stack overflow.
If token_limit is provided and the text is large enough to require
multiple chunks (> 500k chars), stops early once the count exceeds it.
When early-exiting, the returned value exceeds token_limit but may be
less than the true full token count.
"""
if len(text) <= _ENCODE_CHUNK_SIZE:
return len(tokenizer.encode(text))
total = 0
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
total += len(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
if token_limit is not None and total > token_limit:
return total # Already over — skip remaining chunks
return total
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:

View File

@@ -40,6 +40,8 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.client import app as celery_app
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -50,6 +52,9 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentMetadata
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILE_SIZE_BYTES
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILES_PER_UPLOAD
from onyx.server.features.build.configs import USER_LIBRARY_MAX_TOTAL_SIZE_BYTES
@@ -127,6 +132,49 @@ class DeleteFileResponse(BaseModel):
# =============================================================================
def _looks_like_pdf(filename: str, content_type: str | None) -> bool:
"""True if either the filename or the content-type indicates a PDF.
Client-supplied ``content_type`` can be spoofed (e.g. a PDF uploaded with
``Content-Type: application/octet-stream``), so we also fall back to
extension-based detection via ``mimetypes.guess_type`` on the filename.
"""
if content_type == "application/pdf":
return True
guessed, _ = mimetypes.guess_type(filename)
return guessed == "application/pdf"
def _check_pdf_image_caps(
filename: str, content: bytes, content_type: str | None, batch_total: int
) -> int:
"""Enforce per-file and per-batch embedded-image caps for PDFs.
Returns the number of embedded images in this file (0 for non-PDFs) so
callers can update their running batch total. Raises OnyxError(INVALID_INPUT)
if either cap is exceeded.
"""
if not _looks_like_pdf(filename, content_type):
return 0
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
# Short-circuit at the larger cap so we get a useful count for both checks.
count = count_pdf_embedded_images(BytesIO(content), max(file_cap, batch_cap))
if count > file_cap:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"PDF '{filename}' contains too many embedded images "
f"(more than {file_cap}). Try splitting the document into smaller files.",
)
if batch_total + count > batch_cap:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"Upload would exceed the {batch_cap}-image limit across all "
f"files in this batch. Try uploading fewer image-heavy files at once.",
)
return count
def _sanitize_path(path: str) -> str:
"""Sanitize a file path, removing traversal attempts and normalizing.
@@ -355,6 +403,7 @@ async def upload_files(
uploaded_entries: list[LibraryEntryResponse] = []
total_size = 0
batch_image_total = 0
now = datetime.now(timezone.utc)
# Sanitize the base path
@@ -374,6 +423,14 @@ async def upload_files(
detail=f"File '{file.filename}' exceeds maximum size of {USER_LIBRARY_MAX_FILE_SIZE_BYTES // (1024 * 1024)}MB",
)
# Reject PDFs with an unreasonable per-file or per-batch image count
batch_image_total += _check_pdf_image_caps(
filename=file.filename or "unnamed",
content=content,
content_type=file.content_type,
batch_total=batch_image_total,
)
# Validate cumulative storage (existing + this upload batch)
total_size += file_size
if existing_usage + total_size > USER_LIBRARY_MAX_TOTAL_SIZE_BYTES:
@@ -472,6 +529,7 @@ async def upload_zip(
uploaded_entries: list[LibraryEntryResponse] = []
total_size = 0
batch_image_total = 0
# Extract zip contents into a subfolder named after the zip file
zip_name = api_sanitize_filename(file.filename or "upload")
@@ -510,6 +568,36 @@ async def upload_zip(
logger.warning(f"Skipping '{zip_info.filename}' - exceeds max size")
continue
# Skip PDFs that would trip the per-file or per-batch image
# cap (would OOM the user-file-processing worker). Matches
# /upload behavior but uses skip-and-warn to stay consistent
# with the zip path's handling of oversized files.
zip_file_name = zip_info.filename.split("/")[-1]
zip_content_type, _ = mimetypes.guess_type(zip_file_name)
if zip_content_type == "application/pdf":
image_count = count_pdf_embedded_images(
BytesIO(file_content),
max(
MAX_EMBEDDED_IMAGES_PER_FILE,
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
),
)
if image_count > MAX_EMBEDDED_IMAGES_PER_FILE:
logger.warning(
"Skipping '%s' - exceeds %d per-file embedded-image cap",
zip_info.filename,
MAX_EMBEDDED_IMAGES_PER_FILE,
)
continue
if batch_image_total + image_count > MAX_EMBEDDED_IMAGES_PER_UPLOAD:
logger.warning(
"Skipping '%s' - would exceed %d per-batch embedded-image cap",
zip_info.filename,
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
)
continue
batch_image_total += image_count
total_size += file_size
# Validate cumulative storage

View File

@@ -3844,9 +3844,9 @@
}
},
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
"version": "5.0.5",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.5.tgz",
"integrity": "sha512-VZznLgtwhn+Mact9tfiwx64fA9erHH/MCXEUfB/0bX/6Fz6ny5EGTXYltMocqg4xFAQZtnO3DHWWXi8RiuN7cQ==",
"version": "5.0.3",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
"license": "MIT",
"dependencies": {
"balanced-match": "^4.0.2"
@@ -4224,9 +4224,9 @@
}
},
"node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": {
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz",
"integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==",
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -5007,9 +5007,9 @@
}
},
"node_modules/brace-expansion": {
"version": "1.1.13",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz",
"integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==",
"version": "1.1.12",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
"dev": true,
"license": "MIT",
"dependencies": {

View File

@@ -123,8 +123,9 @@ def _validate_endpoint(
(not reachable — indicates the api_key is invalid).
Timeout handling:
- Any httpx.TimeoutException (ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout) →
timeout (operator should consider increasing timeout_seconds).
- ConnectTimeout: TCP handshake never completed → cannot_connect.
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
(operator should consider increasing timeout_seconds).
- All other exceptions → cannot_connect.
"""
_check_ssrf_safety(endpoint_url)

View File

@@ -9,15 +9,23 @@ from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.password_validation import is_file_password_protected
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SKIP_USERFILE_THRESHOLD
from shared_configs.configs import SKIP_USERFILE_THRESHOLD_TENANT_LIST
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -156,8 +164,8 @@ def categorize_uploaded_files(
document formats (.pdf, .docx, …) and falls back to a text-detection
heuristic for unknown extensions (.py, .js, .rs, …).
- Uses default tokenizer to compute token length.
- If token length exceeds the admin-configured threshold, reject file.
- If extension unsupported or text cannot be extracted, reject file.
- If token length > threshold, reject file (unless threshold skip is enabled).
- If text cannot be extracted, reject file.
- Otherwise marked as acceptable.
"""
@@ -168,33 +176,41 @@ def categorize_uploaded_files(
provider_type = default_model.llm_provider.provider if default_model else None
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
# Derive limits from admin-configurable settings.
# For upload size: load_settings() resolves 0/None to a positive default.
# For token threshold: 0 means "no limit" (converted to None below).
settings = load_settings()
max_upload_size_mb = (
settings.user_file_max_upload_size_mb
) # always positive after load_settings()
max_upload_size_bytes = (
max_upload_size_mb * 1024 * 1024 if max_upload_size_mb else None
)
token_threshold_k = settings.file_token_count_threshold_k
token_threshold = (
token_threshold_k * 1000 if token_threshold_k else None
) # 0 → None = no limit
# Check if threshold checks should be skipped
skip_threshold = False
# Check global skip flag (works for both single-tenant and multi-tenant)
if SKIP_USERFILE_THRESHOLD:
skip_threshold = True
logger.info("Skipping userfile threshold check (global setting)")
# Check tenant-specific skip list (only applicable in multi-tenant)
elif MULTI_TENANT and SKIP_USERFILE_THRESHOLD_TENANT_LIST:
try:
current_tenant_id = get_current_tenant_id()
skip_threshold = current_tenant_id in SKIP_USERFILE_THRESHOLD_TENANT_LIST
if skip_threshold:
logger.info(
f"Skipping userfile threshold check for tenant: {current_tenant_id}"
)
except RuntimeError as e:
logger.warning(f"Failed to get current tenant ID: {str(e)}")
# Running total of embedded images across PDFs in this batch. Once the
# aggregate cap is reached, subsequent PDFs in the same upload are
# rejected even if they'd individually fit under MAX_EMBEDDED_IMAGES_PER_FILE.
batch_image_total = 0
for upload in files:
try:
filename = get_safe_filename(upload)
# Size limit is a hard safety cap.
if max_upload_size_bytes is not None and is_upload_too_large(
upload, max_upload_size_bytes
):
# Size limit is a hard safety cap and is enforced even when token
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {max_upload_size_mb} MB file size limit",
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
)
)
continue
@@ -216,11 +232,11 @@ def categorize_uploaded_files(
)
continue
if token_threshold is not None and token_count > token_threshold:
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {token_threshold_k}K token limit",
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
)
)
else:
@@ -245,6 +261,47 @@ def categorize_uploaded_files(
)
continue
# Reject PDFs with an unreasonable number of embedded images
# (either per-file or accumulated across this upload batch).
# A PDF with thousands of embedded images can OOM the
# user-file-processing celery worker because every image is
# decoded with PIL and then sent to the vision LLM.
if extension == ".pdf":
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
# Use the larger of the two caps as the short-circuit
# threshold so we get a useful count for both checks.
# count_pdf_embedded_images restores the stream position.
count = count_pdf_embedded_images(
upload.file, max(file_cap, batch_cap)
)
if count > file_cap:
results.rejected.append(
RejectedFile(
filename=filename,
reason=(
f"PDF contains too many embedded images "
f"(more than {file_cap}). Try splitting "
f"the document into smaller files."
),
)
)
continue
if batch_image_total + count > batch_cap:
results.rejected.append(
RejectedFile(
filename=filename,
reason=(
f"Upload would exceed the "
f"{batch_cap}-image limit across all "
f"files in this batch. Try uploading "
f"fewer image-heavy files at once."
),
)
)
continue
batch_image_total += count
text_content = extract_file_text(
file=upload.file,
file_name=filename,
@@ -261,14 +318,12 @@ def categorize_uploaded_files(
)
continue
token_count = count_tokens(
text_content, tokenizer, token_limit=token_threshold
)
if token_threshold is not None and token_count > token_threshold:
token_count = len(tokenizer.encode(text_content))
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {token_threshold_k}K token limit",
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
)
)
else:

View File

@@ -74,6 +74,8 @@ from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
from onyx.server.manage.llm.models import OpenAICompatibleFinalModelResponse
from onyx.server.manage.llm.models import OpenAICompatibleModelsRequest
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
from onyx.server.manage.llm.models import OpenRouterModelDetails
from onyx.server.manage.llm.models import OpenRouterModelsRequest
@@ -1575,3 +1577,95 @@ def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> d
source_name="Bifrost",
api_key=api_key,
)
@admin_router.post("/openai-compatible/available-models")
def get_openai_compatible_server_available_models(
request: OpenAICompatibleModelsRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[OpenAICompatibleFinalModelResponse]:
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
response_json = _get_openai_compatible_server_response(
api_base=request.api_base, api_key=request.api_key
)
models = response_json.get("data", [])
if not isinstance(models, list) or len(models) == 0:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your OpenAI-compatible endpoint",
)
results: list[OpenAICompatibleFinalModelResponse] = []
for model in models:
try:
model_id = model.get("id", "")
model_name = model.get("name", model_id)
if not model_id:
continue
# Skip embedding models
if is_embedding_model(model_id):
continue
results.append(
OpenAICompatibleFinalModelResponse(
name=model_id,
display_name=model_name,
max_input_tokens=model.get("context_length"),
supports_image_input=infer_vision_support(model_id),
supports_reasoning=is_reasoning_model(model_id, model_name),
)
)
except Exception as e:
logger.warning(
"Failed to parse OpenAI-compatible model entry",
extra={"error": str(e), "item": str(model)[:1000]},
)
if not results:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No compatible models found from OpenAI-compatible endpoint",
)
sorted_results = sorted(results, key=lambda m: m.name.lower())
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
for r in sorted_results
],
source_label="OpenAI Compatible",
)
return sorted_results
def _get_openai_compatible_server_response(
api_base: str, api_key: str | None = None
) -> dict:
"""Perform GET to an OpenAI-compatible /v1/models and return parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
# Ensure we hit /v1/models
if cleaned_api_base.endswith("/v1"):
url = f"{cleaned_api_base}/models"
else:
url = f"{cleaned_api_base}/v1/models"
return _get_openai_compatible_models_response(
url=url,
source_name="OpenAI Compatible",
api_key=api_key,
)

View File

@@ -464,3 +464,18 @@ class BifrostFinalModelResponse(BaseModel):
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool
# OpenAI Compatible dynamic models fetch
class OpenAICompatibleModelsRequest(BaseModel):
api_base: str
api_key: str | None = None
provider_name: str | None = None # Optional: to save models to existing provider
class OpenAICompatibleFinalModelResponse(BaseModel):
name: str # Model ID (e.g. "meta-llama/Llama-3-8B-Instruct")
display_name: str # Human-readable name from API
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool

View File

@@ -26,6 +26,7 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
}
)

View File

@@ -147,6 +147,7 @@ class UserInfo(BaseModel):
is_anonymous_user: bool | None = None,
tenant_info: TenantInfo | None = None,
assistant_specific_configs: UserSpecificAssistantPreferences | None = None,
memories: list[MemoryItem] | None = None,
) -> "UserInfo":
return cls(
id=str(user.id),
@@ -191,10 +192,7 @@ class UserInfo(BaseModel):
role=user.personal_role or "",
use_memories=user.use_memories,
enable_memory_tool=user.enable_memory_tool,
memories=[
MemoryItem(id=memory.id, content=memory.memory_text)
for memory in (user.memories or [])
],
memories=memories or [],
user_preferences=user.user_preferences or "",
),
)

View File

@@ -57,6 +57,7 @@ from onyx.db.user_preferences import activate_user
from onyx.db.user_preferences import deactivate_user
from onyx.db.user_preferences import get_all_user_assistant_specific_configs
from onyx.db.user_preferences import get_latest_access_token_for_user
from onyx.db.user_preferences import get_memories_for_user
from onyx.db.user_preferences import update_assistant_preferences
from onyx.db.user_preferences import update_user_assistant_visibility
from onyx.db.user_preferences import update_user_auto_scroll
@@ -823,6 +824,11 @@ def verify_user_logged_in(
[],
),
)
memories = [
MemoryItem(id=memory.id, content=memory.memory_text)
for memory in get_memories_for_user(user.id, db_session)
]
user_info = UserInfo.from_model(
user,
current_token_created_at=token_created_at,
@@ -833,6 +839,7 @@ def verify_user_logged_in(
new_tenant=new_tenant,
invitation=tenant_invitation,
),
memories=memories,
)
return user_info
@@ -930,7 +937,8 @@ def update_user_personalization_api(
else user.enable_memory_tool
)
existing_memories = [
MemoryItem(id=memory.id, content=memory.memory_text) for memory in user.memories
MemoryItem(id=memory.id, content=memory.memory_text)
for memory in get_memories_for_user(user.id, db_session)
]
new_memories = (
request.memories if request.memories is not None else existing_memories

View File

@@ -9,9 +9,7 @@ from onyx import __version__ as onyx_version
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import is_user_admin
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.configs.constants import NotificationType
from onyx.db.engine.sql_engine import get_session
@@ -19,16 +17,10 @@ from onyx.db.models import User
from onyx.db.notification import dismiss_all_notifications
from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Notification
from onyx.server.settings.models import Settings
from onyx.server.settings.models import UserSettings
@@ -49,15 +41,6 @@ basic_router = APIRouter(prefix="/settings")
def admin_put_settings(
settings: Settings, _: User = Depends(current_admin_user)
) -> None:
if (
settings.user_file_max_upload_size_mb is not None
and settings.user_file_max_upload_size_mb > 0
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
):
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"File upload size limit cannot exceed {MAX_ALLOWED_UPLOAD_SIZE_MB} MB",
)
store_settings(settings)
@@ -100,16 +83,6 @@ def fetch_settings(
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=HOOKS_AVAILABLE,
version=onyx_version,
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
default_user_file_max_upload_size_mb=min(
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
MAX_ALLOWED_UPLOAD_SIZE_MB,
),
default_file_token_count_threshold_k=(
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
),
)

View File

@@ -2,19 +2,12 @@ from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import Field
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.constants import NotificationType
from onyx.configs.constants import QueryHistoryType
from onyx.db.models import Notification as NotificationDBModel
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB = 200
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB = 10000
class PageType(str, Enum):
CHAT = "chat"
@@ -85,12 +78,7 @@ class Settings(BaseModel):
# User Knowledge settings
user_knowledge_enabled: bool | None = True
user_file_max_upload_size_mb: int | None = Field(
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
)
file_token_count_threshold_k: int | None = Field(
default=None, ge=0 # thousands of tokens; None = context-aware default
)
user_file_max_upload_size_mb: int | None = None
# Connector settings
show_extra_connectors: bool | None = True
@@ -120,14 +108,3 @@ class UserSettings(Settings):
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None
# Hard ceiling for user_file_max_upload_size_mb, derived from env var.
max_allowed_upload_size_mb: int = MAX_ALLOWED_UPLOAD_SIZE_MB
# Factory defaults so the frontend can show a "restore default" button.
default_user_file_max_upload_size_mb: int = DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
default_file_token_count_threshold_k: int = Field(
default_factory=lambda: (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
)

View File

@@ -1,19 +1,13 @@
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -57,36 +51,9 @@ def load_settings() -> Settings:
if DISABLE_USER_KNOWLEDGE:
settings.user_knowledge_enabled = False
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
# Resolve context-aware defaults for token threshold.
# None = admin hasn't set a value yet → use context-aware default.
# 0 = admin explicitly chose "no limit" → preserve as-is.
if settings.file_token_count_threshold_k is None:
settings.file_token_count_threshold_k = (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
# Upload size: 0 and None are treated as "unset" (not "no limit") →
# fall back to min(configured default, hard ceiling).
if not settings.user_file_max_upload_size_mb:
settings.user_file_max_upload_size_mb = min(
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
MAX_ALLOWED_UPLOAD_SIZE_MB,
)
# Clamp to env ceiling so stale KV values are capped even if the
# operator lowered MAX_ALLOWED_UPLOAD_SIZE_MB after a higher value
# was already saved (api.py only guards new writes).
if (
settings.user_file_max_upload_size_mb > 0
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
):
settings.user_file_max_upload_size_mb = MAX_ALLOWED_UPLOAD_SIZE_MB
return settings

View File

@@ -187,7 +187,7 @@ coloredlogs==15.0.1
# via onnxruntime
courlan==1.3.2
# via trafilatura
cryptography==46.0.6
cryptography==46.0.5
# via
# authlib
# google-auth
@@ -449,7 +449,7 @@ kombu==5.5.4
# via celery
kubernetes==31.0.0
# via onyx
langchain-core==1.2.22
langchain-core==1.2.11
# via onyx
langdetect==1.0.9
# via unstructured

View File

@@ -97,7 +97,7 @@ comm==0.2.3
# via ipykernel
contourpy==1.3.3
# via matplotlib
cryptography==46.0.6
cryptography==46.0.5
# via
# google-auth
# pyjwt
@@ -263,7 +263,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.7.2
onyx-devtools==0.7.1
# via onyx
openai==2.14.0
# via

View File

@@ -76,7 +76,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.6
cryptography==46.0.5
# via
# google-auth
# pyjwt

View File

@@ -92,7 +92,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.6
cryptography==46.0.5
# via
# google-auth
# pyjwt

View File

@@ -191,6 +191,25 @@ IGNORED_SYNCING_TENANT_LIST = (
else None
)
# Global flag to skip userfile threshold for all users/tenants
SKIP_USERFILE_THRESHOLD = (
os.environ.get("SKIP_USERFILE_THRESHOLD", "").lower() == "true"
)
# Comma-separated list of specific tenant IDs to skip threshold (multi-tenant only)
SKIP_USERFILE_THRESHOLD_TENANT_IDS = os.environ.get(
"SKIP_USERFILE_THRESHOLD_TENANT_IDS"
)
SKIP_USERFILE_THRESHOLD_TENANT_LIST = (
[
tenant.strip()
for tenant in SKIP_USERFILE_THRESHOLD_TENANT_IDS.split(",")
if tenant.strip()
]
if SKIP_USERFILE_THRESHOLD_TENANT_IDS
else None
)
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"

View File

@@ -1,6 +1,4 @@
import time
from datetime import datetime
from datetime import timezone
import pytest
@@ -19,10 +17,6 @@ PRIVATE_CHANNEL_USERS = [
"test_user_2@onyx-test.com",
]
# Predates any test workspace messages, so the result set should match
# the "no start time" case while exercising the oldest= parameter.
OLDEST_TS_2016 = datetime(2016, 1, 1, tzinfo=timezone.utc).timestamp()
pytestmark = pytest.mark.usefixtures("enable_ee")
@@ -111,17 +105,15 @@ def test_load_from_checkpoint_access__private_channel(
],
indirect=True,
)
@pytest.mark.parametrize("start_ts", [None, OLDEST_TS_2016])
def test_slim_documents_access__public_channel(
slack_connector: SlackConnector,
start_ts: float | None,
) -> None:
"""Test that retrieve_all_slim_docs_perm_sync returns correct access information for slim documents."""
if not slack_connector.client:
raise RuntimeError("Web client must be defined")
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
start=start_ts,
start=0.0,
end=time.time(),
)
@@ -157,7 +149,7 @@ def test_slim_documents_access__private_channel(
raise RuntimeError("Web client must be defined")
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
start=None,
start=0.0,
end=time.time(),
)

View File

@@ -6,7 +6,6 @@ These tests assume Vespa and OpenSearch are running.
import time
import uuid
from collections.abc import Generator
from collections.abc import Iterator
import httpx
import pytest
@@ -22,7 +21,6 @@ from onyx.document_index.opensearch.opensearch_document_index import (
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import DocMetadataAwareIndexChunk
from tests.external_dependency_unit.constants import TEST_TENANT_ID
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
from tests.external_dependency_unit.document_index.conftest import make_chunk
@@ -203,25 +201,3 @@ class TestDocumentIndexNew:
assert len(result_map) == 2
assert result_map[existing_doc] is True
assert result_map[new_doc] is False
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndexNew],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk(doc_id, chunk_id=i)
results = document_index.index(
chunks=chunk_gen(), indexing_metadata=metadata
)
assert len(results) == 1
assert results[0].document_id == doc_id
assert results[0].already_existed is False

View File

@@ -5,7 +5,6 @@ These tests assume Vespa and OpenSearch are running.
import time
from collections.abc import Generator
from collections.abc import Iterator
import pytest
@@ -167,29 +166,3 @@ class TestDocumentIndexOld:
batch_retrieval=True,
)
assert len(inference_chunks) == 0
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndex],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk("test_doc_gen", chunk_id=i)
index_batch_params = IndexBatchParams(
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
tenant_id=get_current_tenant_id(),
large_chunks_enabled=False,
)
results = document_index.index(chunk_gen(), index_batch_params)
assert len(results) == 1
record = results.pop()
assert record.document_id == "test_doc_gen"
assert record.already_existed is False

View File

@@ -0,0 +1,58 @@
import pytest
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
from onyx.db.federated import _reject_masked_credentials
class TestRejectMaskedCredentials:
"""Verify that masked credential values are never accepted for DB writes.
mask_string() has two output formats:
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
_reject_masked_credentials must catch both.
"""
def test_rejects_fully_masked_value(self) -> None:
masked = MASK_CREDENTIAL_CHAR * 12 # "••••••••••••"
with pytest.raises(ValueError, match="masked placeholder"):
_reject_masked_credentials({"client_id": masked})
def test_rejects_long_string_masked_value(self) -> None:
"""mask_string returns 'first4...last4' for long strings — the real
format used for OAuth credentials like client_id and client_secret."""
with pytest.raises(ValueError, match="masked placeholder"):
_reject_masked_credentials({"client_id": "1234...7890"})
def test_rejects_when_any_field_is_masked(self) -> None:
"""Even if client_id is real, a masked client_secret must be caught."""
with pytest.raises(ValueError, match="client_secret"):
_reject_masked_credentials(
{
"client_id": "1234567890.1234567890",
"client_secret": MASK_CREDENTIAL_CHAR * 12,
}
)
def test_accepts_real_credentials(self) -> None:
# Should not raise
_reject_masked_credentials(
{
"client_id": "1234567890.1234567890",
"client_secret": "test_client_secret_value",
}
)
def test_accepts_empty_dict(self) -> None:
# Should not raise — empty credentials are handled elsewhere
_reject_masked_credentials({})
def test_ignores_non_string_values(self) -> None:
# Non-string values (None, bool, int) should pass through
_reject_masked_credentials(
{
"client_id": "real_value",
"redirect_uri": None,
"some_flag": True,
}
)

View File

@@ -1,5 +1,3 @@
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -33,7 +31,6 @@ def mock_jira_cc_pair(
"jira_base_url": jira_base_url,
"project_key": project_key,
}
mock_cc_pair.connector.indexing_start = None
return mock_cc_pair
@@ -68,75 +65,3 @@ def test_jira_permission_sync(
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
):
print(doc)
def test_jira_doc_sync_passes_indexing_start(
jira_connector: JiraConnector,
mock_jira_cc_pair: MagicMock,
mock_fetch_all_existing_docs_fn: MagicMock,
mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
"""Verify that generic_doc_sync derives indexing_start from cc_pair
and forwards it to retrieve_all_slim_docs_perm_sync."""
indexing_start_dt = datetime(2025, 6, 1, tzinfo=timezone.utc)
mock_jira_cc_pair.connector.indexing_start = indexing_start_dt
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
mock_build_client.return_value = jira_connector._jira_client
assert jira_connector._jira_client is not None
jira_connector._jira_client._options = MagicMock()
jira_connector._jira_client._options.return_value = {
"rest_api_version": JIRA_SERVER_API_VERSION
}
with patch.object(
type(jira_connector),
"retrieve_all_slim_docs_perm_sync",
return_value=iter([]),
) as mock_retrieve:
list(
jira_doc_sync(
cc_pair=mock_jira_cc_pair,
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
)
)
mock_retrieve.assert_called_once()
call_kwargs = mock_retrieve.call_args
assert call_kwargs.kwargs["start"] == indexing_start_dt.timestamp()
def test_jira_doc_sync_passes_none_when_no_indexing_start(
jira_connector: JiraConnector,
mock_jira_cc_pair: MagicMock,
mock_fetch_all_existing_docs_fn: MagicMock,
mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
"""Verify that indexing_start is None when the connector has no indexing_start set."""
mock_jira_cc_pair.connector.indexing_start = None
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
mock_build_client.return_value = jira_connector._jira_client
assert jira_connector._jira_client is not None
jira_connector._jira_client._options = MagicMock()
jira_connector._jira_client._options.return_value = {
"rest_api_version": JIRA_SERVER_API_VERSION
}
with patch.object(
type(jira_connector),
"retrieve_all_slim_docs_perm_sync",
return_value=iter([]),
) as mock_retrieve:
list(
jira_doc_sync(
cc_pair=mock_jira_cc_pair,
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
)
)
mock_retrieve.assert_called_once()
call_kwargs = mock_retrieve.call_args
assert call_kwargs.kwargs["start"] is None

View File

@@ -0,0 +1,67 @@
"""Tests for _build_thread_text function."""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.context.search.federated.slack_search import _build_thread_text
def _make_msg(user: str, text: str, ts: str) -> dict[str, str]:
return {"user": user, "text": text, "ts": ts}
class TestBuildThreadText:
"""Verify _build_thread_text includes full thread replies up to cap."""
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
def test_includes_all_replies(self, mock_profiles: MagicMock) -> None:
"""All replies within cap are included in output."""
mock_profiles.return_value = {}
messages = [
_make_msg("U1", "parent msg", "1000.0"),
_make_msg("U2", "reply 1", "1001.0"),
_make_msg("U3", "reply 2", "1002.0"),
_make_msg("U4", "reply 3", "1003.0"),
]
result = _build_thread_text(messages, "token", "T123", MagicMock())
assert "parent msg" in result
assert "reply 1" in result
assert "reply 2" in result
assert "reply 3" in result
assert "..." not in result
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
def test_non_thread_returns_parent_only(self, mock_profiles: MagicMock) -> None:
"""Single message (no replies) returns just the parent text."""
mock_profiles.return_value = {}
messages = [_make_msg("U1", "just a message", "1000.0")]
result = _build_thread_text(messages, "token", "T123", MagicMock())
assert "just a message" in result
assert "Replies:" not in result
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
def test_parent_always_first(self, mock_profiles: MagicMock) -> None:
"""Thread parent message is always the first line of output."""
mock_profiles.return_value = {}
messages = [
_make_msg("U1", "I am the parent", "1000.0"),
_make_msg("U2", "I am a reply", "1001.0"),
]
result = _build_thread_text(messages, "token", "T123", MagicMock())
parent_pos = result.index("I am the parent")
reply_pos = result.index("I am a reply")
assert parent_pos < reply_pos
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
def test_user_profiles_resolved(self, mock_profiles: MagicMock) -> None:
"""User IDs in thread text are replaced with display names."""
mock_profiles.return_value = {"U1": "Alice", "U2": "Bob"}
messages = [
_make_msg("U1", "hello", "1000.0"),
_make_msg("U2", "world", "1001.0"),
]
result = _build_thread_text(messages, "token", "T123", MagicMock())
assert "Alice" in result
assert "Bob" in result
assert "<@U1>" not in result
assert "<@U2>" not in result

View File

@@ -0,0 +1,108 @@
"""Tests for Slack URL parsing and direct thread fetch via URL override."""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.federated.slack_search import _fetch_thread_from_url
from onyx.context.search.federated.slack_search_utils import extract_slack_message_urls
class TestExtractSlackMessageUrls:
"""Verify URL parsing extracts channel_id and timestamp correctly."""
def test_standard_url(self) -> None:
query = "summarize https://mycompany.slack.com/archives/C097NBWMY8Y/p1775491616524769"
results = extract_slack_message_urls(query)
assert len(results) == 1
assert results[0] == ("C097NBWMY8Y", "1775491616.524769")
def test_multiple_urls(self) -> None:
query = (
"compare https://co.slack.com/archives/C111/p1234567890123456 "
"and https://co.slack.com/archives/C222/p9876543210987654"
)
results = extract_slack_message_urls(query)
assert len(results) == 2
assert results[0] == ("C111", "1234567890.123456")
assert results[1] == ("C222", "9876543210.987654")
def test_no_urls(self) -> None:
query = "what happened in #general last week?"
results = extract_slack_message_urls(query)
assert len(results) == 0
def test_non_slack_url_ignored(self) -> None:
query = "check https://google.com/archives/C111/p1234567890123456"
results = extract_slack_message_urls(query)
assert len(results) == 0
def test_timestamp_conversion(self) -> None:
"""p prefix removed, dot inserted after 10th digit."""
query = "https://x.slack.com/archives/CABC123/p1775491616524769"
results = extract_slack_message_urls(query)
channel_id, ts = results[0]
assert channel_id == "CABC123"
assert ts == "1775491616.524769"
assert not ts.startswith("p")
assert "." in ts
class TestFetchThreadFromUrl:
"""Verify _fetch_thread_from_url calls conversations.replies and returns SlackMessage."""
@patch("onyx.context.search.federated.slack_search._build_thread_text")
@patch("onyx.context.search.federated.slack_search.WebClient")
def test_successful_fetch(
self, mock_webclient_cls: MagicMock, mock_build_thread: MagicMock
) -> None:
mock_client = MagicMock()
mock_webclient_cls.return_value = mock_client
# Mock conversations_replies
mock_response = MagicMock()
mock_response.get.return_value = [
{"user": "U1", "text": "parent", "ts": "1775491616.524769"},
{"user": "U2", "text": "reply 1", "ts": "1775491617.000000"},
{"user": "U3", "text": "reply 2", "ts": "1775491618.000000"},
]
mock_client.conversations_replies.return_value = mock_response
# Mock channel info
mock_ch_response = MagicMock()
mock_ch_response.get.return_value = {"name": "general"}
mock_client.conversations_info.return_value = mock_ch_response
mock_build_thread.return_value = (
"U1: parent\n\nReplies:\n\nU2: reply 1\n\nU3: reply 2"
)
fetch = DirectThreadFetch(
channel_id="C097NBWMY8Y", thread_ts="1775491616.524769"
)
result = _fetch_thread_from_url(fetch, "xoxp-token")
assert len(result.messages) == 1
msg = result.messages[0]
assert msg.channel_id == "C097NBWMY8Y"
assert msg.thread_id is None # Prevents double-enrichment
assert msg.slack_score == 100000.0
assert "parent" in msg.text
mock_client.conversations_replies.assert_called_once_with(
channel="C097NBWMY8Y", ts="1775491616.524769"
)
@patch("onyx.context.search.federated.slack_search.WebClient")
def test_api_error_returns_empty(self, mock_webclient_cls: MagicMock) -> None:
from slack_sdk.errors import SlackApiError
mock_client = MagicMock()
mock_webclient_cls.return_value = mock_client
mock_client.conversations_replies.side_effect = SlackApiError(
message="channel_not_found",
response=MagicMock(status_code=404),
)
fetch = DirectThreadFetch(channel_id="CBAD", thread_ts="1234567890.123456")
result = _fetch_thread_from_url(fetch, "xoxp-token")
assert len(result.messages) == 0

View File

@@ -0,0 +1,225 @@
"""Tests for get_chat_sessions_by_user filtering behavior.
Verifies that failed chat sessions (those with only SYSTEM messages) are
correctly filtered out while preserving recently created sessions, matching
the behavior specified in PR #7233.
"""
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from unittest.mock import MagicMock
from uuid import UUID
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.models import ChatSession
def _make_session(
user_id: UUID,
time_created: datetime | None = None,
time_updated: datetime | None = None,
description: str = "",
) -> MagicMock:
"""Create a mock ChatSession with the given attributes."""
session = MagicMock(spec=ChatSession)
session.id = uuid4()
session.user_id = user_id
session.time_created = time_created or datetime.now(timezone.utc)
session.time_updated = time_updated or session.time_created
session.description = description
session.deleted = False
session.onyxbot_flow = False
session.project_id = None
return session
@pytest.fixture
def user_id() -> UUID:
return uuid4()
@pytest.fixture
def old_time() -> datetime:
"""A timestamp well outside the 5-minute leeway window."""
return datetime.now(timezone.utc) - timedelta(hours=1)
@pytest.fixture
def recent_time() -> datetime:
"""A timestamp within the 5-minute leeway window."""
return datetime.now(timezone.utc) - timedelta(minutes=2)
class TestGetChatSessionsByUser:
"""Tests for the failed chat filtering logic in get_chat_sessions_by_user."""
def test_filters_out_failed_sessions(
self, user_id: UUID, old_time: datetime
) -> None:
"""Sessions with only SYSTEM messages should be excluded."""
valid_session = _make_session(user_id, time_created=old_time)
failed_session = _make_session(user_id, time_created=old_time)
db_session = MagicMock(spec=Session)
# First execute: returns all sessions
# Second execute: returns only the valid session's ID (has non-system msgs)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [
valid_session,
failed_session,
]
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = [valid_session.id]
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert len(result) == 1
assert result[0].id == valid_session.id
def test_keeps_recent_sessions_without_messages(
self, user_id: UUID, recent_time: datetime
) -> None:
"""Recently created sessions should be kept even without messages."""
recent_session = _make_session(user_id, time_created=recent_time)
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [recent_session]
db_session.execute.side_effect = [mock_result_1]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert len(result) == 1
assert result[0].id == recent_session.id
# Should only have been called once — no second query needed
# because the recent session is within the leeway window
assert db_session.execute.call_count == 1
def test_include_failed_chats_skips_filtering(
self, user_id: UUID, old_time: datetime
) -> None:
"""When include_failed_chats=True, no filtering should occur."""
session_a = _make_session(user_id, time_created=old_time)
session_b = _make_session(user_id, time_created=old_time)
db_session = MagicMock(spec=Session)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [session_a, session_b]
db_session.execute.side_effect = [mock_result]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=True,
)
assert len(result) == 2
# Only one DB call — no second query for message validation
assert db_session.execute.call_count == 1
def test_limit_applied_after_filtering(
self, user_id: UUID, old_time: datetime
) -> None:
"""Limit should be applied after filtering, not before."""
sessions = [_make_session(user_id, time_created=old_time) for _ in range(5)]
valid_ids = [s.id for s in sessions[:3]]
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = sessions
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = valid_ids
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
limit=2,
)
assert len(result) == 2
# Should be the first 2 valid sessions (order preserved)
assert result[0].id == sessions[0].id
assert result[1].id == sessions[1].id
def test_mixed_recent_and_old_sessions(
self, user_id: UUID, old_time: datetime, recent_time: datetime
) -> None:
"""Mix of recent and old sessions should filter correctly."""
old_valid = _make_session(user_id, time_created=old_time)
old_failed = _make_session(user_id, time_created=old_time)
recent_no_msgs = _make_session(user_id, time_created=recent_time)
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [
old_valid,
old_failed,
recent_no_msgs,
]
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = [old_valid.id]
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
result_ids = {cs.id for cs in result}
assert old_valid.id in result_ids
assert recent_no_msgs.id in result_ids
assert old_failed.id not in result_ids
def test_empty_result(self, user_id: UUID) -> None:
"""No sessions should return empty list without errors."""
db_session = MagicMock(spec=Session)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
db_session.execute.side_effect = [mock_result]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert result == []
assert db_session.execute.call_count == 1

View File

@@ -0,0 +1,76 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer <1083d595b1>
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 1
/Kids [ 4 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 5 0 R
/Parent 2 0 R
>>
endobj
5 0 obj
<<
/Length 42
>>
stream
,N<><6~<7E>)<29><><EFBFBD><EFBFBD><EFBFBD>u<EFBFBD> <0C><><EFBFBD>Zc'<27><>>8g<38><67><EFBFBD>n<EFBFBD><6E><EFBFBD><EFBFBD><EFBFBD>9"
endstream
endobj
6 0 obj
<<
/V 2
/R 3
/Length 128
/P 4294967292
/Filter /Standard
/O <6a340a292629053da84a6d8b19a5d505953b8b3fdac3d2d389fde0e354528d44>
/U <d6f0dc91c7b9de264a8d708515468e6528bf4e5e4e758a4164004e56fffa0108>
>>
endobj
xref
0 7
0000000000 65535 f
0000000015 00000 n
0000000059 00000 n
0000000118 00000 n
0000000167 00000 n
0000000348 00000 n
0000000440 00000 n
trailer
<<
/Size 7
/Root 3 0 R
/Info 1 0 R
/ID [ <6364336635356135633239323638353039306635656133623165313637366430> <6364336635356135633239323638353039306635656133623165313637366430> ]
/Encrypt 6 0 R
>>
startxref
655
%%EOF

View File

@@ -12,6 +12,10 @@ dependency on pypdf internals (pypdf.generic).
from io import BytesIO
from pathlib import Path
import pytest
from onyx.file_processing import extract_file_text
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
from onyx.file_processing.extract_file_text import pdf_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.password_validation import is_pdf_protected
@@ -54,6 +58,12 @@ class TestReadPdfFile:
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="wrong")
assert text == ""
def test_owner_password_only_pdf_extracts_text(self) -> None:
"""A PDF encrypted with only an owner password (no user password)
should still yield its text content. Regression for #9754."""
text, _, _ = read_pdf_file(_load("owner_protected.pdf"))
assert "Hello World" in text
def test_empty_pdf(self) -> None:
text, _, _ = read_pdf_file(_load("empty.pdf"))
assert text.strip() == ""
@@ -90,6 +100,80 @@ class TestReadPdfFile:
# Returned list is empty when callback is used
assert images == []
def test_image_cap_skips_images_above_limit(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""When the embedded-image cap is exceeded, remaining images are skipped.
The cap protects the user-file-processing worker from OOMing on PDFs
with thousands of embedded images. Setting the cap to 0 should yield
zero extracted images even though the fixture has one.
"""
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
assert images == []
def test_image_cap_at_limit_extracts_up_to_cap(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""A cap >= image count behaves identically to the uncapped path."""
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 100)
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
assert len(images) == 1
def test_image_cap_with_callback_stops_streaming_at_limit(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""The cap also short-circuits the streaming callback path."""
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
collected: list[tuple[bytes, str]] = []
def callback(data: bytes, name: str) -> None:
collected.append((data, name))
read_pdf_file(
_load("with_image.pdf"), extract_images=True, image_callback=callback
)
assert collected == []
# ── count_pdf_embedded_images ────────────────────────────────────────────
class TestCountPdfEmbeddedImages:
def test_returns_count_for_normal_pdf(self) -> None:
assert count_pdf_embedded_images(_load("with_image.pdf"), cap=10) == 1
def test_short_circuits_above_cap(self) -> None:
# with_image.pdf has 1 image. cap=0 means "anything > 0 is over cap" —
# function returns on first increment as the over-cap sentinel.
assert count_pdf_embedded_images(_load("with_image.pdf"), cap=0) == 1
def test_returns_zero_for_pdf_without_images(self) -> None:
assert count_pdf_embedded_images(_load("simple.pdf"), cap=10) == 0
def test_returns_zero_for_invalid_pdf(self) -> None:
assert count_pdf_embedded_images(BytesIO(b"not a pdf"), cap=10) == 0
def test_returns_zero_for_password_locked_pdf(self) -> None:
# encrypted.pdf has an open password; we can't inspect without it, so
# the helper returns 0 — callers rely on the password-protected check
# that runs earlier in the upload pipeline.
assert count_pdf_embedded_images(_load("encrypted.pdf"), cap=10) == 0
def test_inspects_owner_password_only_pdf(self) -> None:
# owner_protected.pdf is encrypted but has no open password. It should
# decrypt with an empty string and count images normally. The fixture
# has zero images, so 0 is a real count (not the "bail on encrypted"
# path).
assert count_pdf_embedded_images(_load("owner_protected.pdf"), cap=10) == 0
def test_preserves_file_position(self) -> None:
pdf = _load("with_image.pdf")
pdf.seek(42)
count_pdf_embedded_images(pdf, cap=10)
assert pdf.tell() == 42
# ── pdf_to_text ──────────────────────────────────────────────────────────
@@ -117,6 +201,12 @@ class TestIsPdfProtected:
def test_protected_pdf(self) -> None:
assert is_pdf_protected(_load("encrypted.pdf")) is True
def test_owner_password_only_is_not_protected(self) -> None:
"""A PDF with only an owner password (permission restrictions) but no
user password should NOT be considered protected — any viewer can open
it without prompting for a password."""
assert is_pdf_protected(_load("owner_protected.pdf")) is False
def test_preserves_file_position(self) -> None:
pdf = _load("simple.pdf")
pdf.seek(42)

View File

@@ -0,0 +1,79 @@
import io
from pptx import Presentation # type: ignore[import-untyped]
from pptx.chart.data import CategoryChartData # type: ignore[import-untyped]
from pptx.enum.chart import XL_CHART_TYPE # type: ignore[import-untyped]
from pptx.util import Inches # type: ignore[import-untyped]
from onyx.file_processing.extract_file_text import pptx_to_text
def _make_pptx_with_chart() -> io.BytesIO:
"""Create an in-memory pptx with one text slide and one chart slide."""
prs = Presentation()
# Slide 1: text only
slide1 = prs.slides.add_slide(prs.slide_layouts[1])
slide1.shapes.title.text = "Introduction"
slide1.placeholders[1].text = "This is the first slide."
# Slide 2: chart
slide2 = prs.slides.add_slide(prs.slide_layouts[5]) # Blank layout
chart_data = CategoryChartData()
chart_data.categories = ["Q1", "Q2", "Q3"]
chart_data.add_series("Revenue", (100, 200, 300))
slide2.shapes.add_chart(
XL_CHART_TYPE.COLUMN_CLUSTERED,
Inches(1),
Inches(1),
Inches(6),
Inches(4),
chart_data,
)
buf = io.BytesIO()
prs.save(buf)
buf.seek(0)
return buf
def _make_pptx_without_chart() -> io.BytesIO:
"""Create an in-memory pptx with a single text-only slide."""
prs = Presentation()
slide = prs.slides.add_slide(prs.slide_layouts[1])
slide.shapes.title.text = "Hello World"
slide.placeholders[1].text = "Some content here."
buf = io.BytesIO()
prs.save(buf)
buf.seek(0)
return buf
class TestPptxToText:
def test_chart_is_omitted(self) -> None:
# Precondition
pptx_file = _make_pptx_with_chart()
# Under test
result = pptx_to_text(pptx_file)
# Postcondition
assert "Introduction" in result
assert "first slide" in result
assert "[chart omitted]" in result
# The actual chart data should NOT appear in the output.
assert "Revenue" not in result
assert "Q1" not in result
def test_text_only_pptx(self) -> None:
# Precondition
pptx_file = _make_pptx_without_chart()
# Under test
result = pptx_to_text(pptx_file)
# Postcondition
assert "Hello World" in result
assert "Some content" in result
assert "[chart omitted]" not in result

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
@@ -9,7 +10,9 @@ from uuid import uuid4
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.server.scim.api import _check_seat_availability
from ee.onyx.server.scim.api import _scim_name_to_str
from ee.onyx.server.scim.api import _seat_lock_id_for_tenant
from ee.onyx.server.scim.api import create_user
from ee.onyx.server.scim.api import delete_user
from ee.onyx.server.scim.api import get_user
@@ -741,3 +744,80 @@ class TestEmailCasePreservation:
resource = parse_scim_user(result)
assert resource.userName == "Alice@Example.COM"
assert resource.emails[0].value == "Alice@Example.COM"
class TestSeatLock:
"""Tests for the advisory lock in _check_seat_availability."""
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_abc")
def test_acquires_advisory_lock_before_checking(
self,
_mock_tenant: MagicMock,
mock_dal: MagicMock,
) -> None:
"""The advisory lock must be acquired before the seat check runs."""
call_order: list[str] = []
def track_execute(stmt: Any, _params: Any = None) -> None:
if "pg_advisory_xact_lock" in str(stmt):
call_order.append("lock")
mock_dal.session.execute.side_effect = track_execute
with patch(
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop"
) as mock_fetch:
mock_result = MagicMock()
mock_result.available = True
mock_fn = MagicMock(return_value=mock_result)
mock_fetch.return_value = mock_fn
def track_check(*_args: Any, **_kwargs: Any) -> Any:
call_order.append("check")
return mock_result
mock_fn.side_effect = track_check
_check_seat_availability(mock_dal)
assert call_order == ["lock", "check"]
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_xyz")
def test_lock_uses_tenant_scoped_key(
self,
_mock_tenant: MagicMock,
mock_dal: MagicMock,
) -> None:
"""The lock id must be derived from the tenant via _seat_lock_id_for_tenant."""
mock_result = MagicMock()
mock_result.available = True
mock_check = MagicMock(return_value=mock_result)
with patch(
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
return_value=mock_check,
):
_check_seat_availability(mock_dal)
mock_dal.session.execute.assert_called_once()
params = mock_dal.session.execute.call_args[0][1]
assert params["lock_id"] == _seat_lock_id_for_tenant("tenant_xyz")
def test_seat_lock_id_is_stable_and_tenant_scoped(self) -> None:
"""Lock id must be deterministic and differ across tenants."""
assert _seat_lock_id_for_tenant("t1") == _seat_lock_id_for_tenant("t1")
assert _seat_lock_id_for_tenant("t1") != _seat_lock_id_for_tenant("t2")
def test_no_lock_when_ee_absent(
self,
mock_dal: MagicMock,
) -> None:
"""No advisory lock should be acquired when the EE check is absent."""
with patch(
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
return_value=None,
):
result = _check_seat_availability(mock_dal)
assert result is None
mock_dal.session.execute.assert_not_called()

View File

@@ -4,23 +4,13 @@ from unittest.mock import MagicMock
import pytest
from fastapi import UploadFile
from onyx.natural_language_processing import utils as nlp_utils
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.server.features.projects import projects_file_utils as utils
from onyx.server.settings.models import Settings
class _Tokenizer(BaseTokenizer):
class _Tokenizer:
def encode(self, text: str) -> list[int]:
return [1] * len(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
class _NonSeekableFile(BytesIO):
def tell(self) -> int:
@@ -39,26 +29,10 @@ def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
return UploadFile(filename=filename, file=BytesIO(content), size=None)
def _make_settings(upload_size_mb: int = 1, token_threshold_k: int = 100) -> Settings:
return Settings(
user_file_max_upload_size_mb=upload_size_mb,
file_token_count_threshold_k=token_threshold_k,
)
def _patch_common_dependencies(
monkeypatch: pytest.MonkeyPatch,
upload_size_mb: int = 1,
token_threshold_k: int = 100,
) -> None:
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
monkeypatch.setattr(
utils,
"load_settings",
lambda: _make_settings(upload_size_mb, token_threshold_k),
)
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
@@ -102,8 +76,9 @@ def test_is_upload_too_large_logs_warning_when_size_unknown(
def test_categorize_uploaded_files_accepts_size_under_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# upload_size_mb=1 → max_bytes = 1*1024*1024; file size 99 is well under
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("small.png", size=99)
@@ -116,7 +91,9 @@ def test_categorize_uploaded_files_accepts_size_under_limit(
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload_no_size("small.png", content=b"x" * 99)
@@ -129,11 +106,12 @@ def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
def test_categorize_uploaded_files_accepts_size_at_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
# 1 MB = 1048576 bytes; file at exactly that boundary should be accepted
upload = _make_upload("edge.png", size=1048576)
upload = _make_upload("edge.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
@@ -143,10 +121,12 @@ def test_categorize_uploaded_files_accepts_size_at_limit(
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("large.png", size=1048577) # 1 byte over 1 MB
upload = _make_upload("large.png", size=101)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
@@ -157,11 +137,13 @@ def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
small = _make_upload("small.png", size=50)
large = _make_upload("large.png", size=1048577)
large = _make_upload("large.png", size=101)
result = utils.categorize_uploaded_files([small, large], MagicMock())
@@ -171,12 +153,15 @@ def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_enforces_size_limit_always(
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
upload = _make_upload("oversized.pdf", size=1048577)
upload = _make_upload("oversized.pdf", size=101)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
@@ -187,12 +172,14 @@ def test_categorize_uploaded_files_enforces_size_limit_always(
def test_categorize_uploaded_files_checks_size_before_text_extraction(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
extract_mock = MagicMock(return_value="this should not run")
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
oversized_doc = _make_upload("oversized.pdf", size=1048577)
oversized_doc = _make_upload("oversized.pdf", size=101)
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
extract_mock.assert_not_called()
@@ -201,219 +188,40 @@ def test_categorize_uploaded_files_checks_size_before_text_extraction(
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_enforces_size_limit_when_upload_size_mb_is_positive(
def test_categorize_uploaded_files_accepts_python_file(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A positive upload_size_mb is always enforced."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
upload = _make_upload("huge.png", size=1048577, content=b"x")
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
def test_categorize_enforces_token_limit_when_threshold_k_is_positive(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A positive token_threshold_k is always enforced."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=5)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
upload = _make_upload("big_image.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
def test_categorize_no_token_limit_when_threshold_k_is_zero(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""token_threshold_k=0 means no token limit; high-token files are accepted."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=0)
py_source = b'def hello():\n print("world")\n'
monkeypatch.setattr(
utils, "estimate_image_tokens_for_upload", lambda _upload: 999_999
utils, "extract_file_text", lambda **_kwargs: py_source.decode()
)
upload = _make_upload("huge_image.png", size=100)
upload = _make_upload("script.py", size=len(py_source), content=py_source)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.rejected) == 0
assert len(result.acceptable) == 1
assert result.acceptable[0].filename == "script.py"
assert len(result.rejected) == 0
def test_categorize_both_limits_enforced(
def test_categorize_uploaded_files_rejects_binary_file(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Both positive limits are enforced; file exceeding token limit is rejected."""
_patch_common_dependencies(monkeypatch, upload_size_mb=10, token_threshold_k=5)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
upload = _make_upload("over_tokens.png", size=100)
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: "")
binary_content = bytes(range(256)) * 4
upload = _make_upload("data.bin", size=len(binary_content), content=binary_content)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
assert result.rejected[0].reason == "Exceeds 5K token limit"
def test_categorize_rejection_reason_contains_dynamic_values(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Rejection reasons reflect the admin-configured limits, not hardcoded values."""
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 8000)
# File within size limit but over token limit
upload = _make_upload("tokens.png", size=100)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert result.rejected[0].reason == "Exceeds 7K token limit"
# File over size limit
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
oversized = _make_upload("big.png", size=42 * 1024 * 1024 + 1)
result2 = utils.categorize_uploaded_files([oversized], MagicMock())
assert result2.rejected[0].reason == "Exceeds 42 MB file size limit"
# --- count_tokens tests ---
def test_count_tokens_small_text() -> None:
"""Small text should be encoded in a single call and return correct count."""
tokenizer = _Tokenizer()
text = "hello world"
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
def test_count_tokens_chunked_matches_single_call() -> None:
"""Chunked encoding should produce the same result as single-call for small text."""
tokenizer = _Tokenizer()
text = "a" * 1000
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
def test_count_tokens_large_text_is_chunked(monkeypatch: pytest.MonkeyPatch) -> None:
"""Text exceeding _ENCODE_CHUNK_SIZE should be split into multiple encode calls."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 250
# _Tokenizer returns 1 token per char, so total should be 250
assert count_tokens(text, tokenizer) == 250
def test_count_tokens_with_token_limit_exits_early(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When token_limit is set and exceeded, count_tokens should stop early."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
encode_call_count = 0
original_tokenizer = _Tokenizer()
class _CountingTokenizer(BaseTokenizer):
def encode(self, text: str) -> list[int]:
nonlocal encode_call_count
encode_call_count += 1
return original_tokenizer.encode(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
tokenizer = _CountingTokenizer()
# 500 chars → 5 chunks of 100; limit=150 → should stop after 2 chunks
text = "a" * 500
result = count_tokens(text, tokenizer, token_limit=150)
assert result == 200 # 2 chunks × 100 tokens each
assert encode_call_count == 2, "Should have stopped after 2 chunks"
def test_count_tokens_with_token_limit_not_exceeded(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When token_limit is set but not exceeded, all chunks are encoded."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 250
result = count_tokens(text, tokenizer, token_limit=1000)
assert result == 250
def test_count_tokens_no_limit_encodes_all_chunks(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Without token_limit, all chunks are encoded regardless of count."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 500
result = count_tokens(text, tokenizer)
assert result == 500
# --- early exit via token_limit in categorize tests ---
def test_categorize_early_exits_tokenization_for_large_text(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Large text files should be rejected via early-exit tokenization
without encoding all chunks."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
# token_threshold = 1000; _ENCODE_CHUNK_SIZE = 100 → text of 500 chars = 5 chunks
# Should stop after 2nd chunk (200 tokens > 1000? No... need 1 token per char)
# With _Tokenizer: 1 token per char. threshold=1000, chunk=100 → need 11 chunks
# Let's use a bigger text
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
large_text = "x" * 5000 # 5000 tokens, threshold 1000
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: large_text)
encode_call_count = 0
original_tokenizer = _Tokenizer()
class _CountingTokenizer(BaseTokenizer):
def encode(self, text: str) -> list[int]:
nonlocal encode_call_count
encode_call_count += 1
return original_tokenizer.encode(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _CountingTokenizer())
upload = _make_upload("big.txt", size=5000, content=large_text.encode())
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.rejected) == 1
assert "token limit" in result.rejected[0].reason
# 5000 chars / 100 chunk_size = 50 chunks total; should stop well before all 50
assert (
encode_call_count < 50
), f"Expected early exit but encoded {encode_call_count} chunks out of 50"
def test_categorize_text_under_token_limit_accepted(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Text files under the token threshold should be accepted with exact count."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
small_text = "x" * 500 # 500 tokens < 1000 threshold
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: small_text)
upload = _make_upload("ok.txt", size=500, content=small_text.encode())
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
assert result.acceptable_file_to_token_count["ok.txt"] == 500
assert result.rejected[0].filename == "data.bin"
assert "Unsupported file type" in result.rejected[0].reason

View File

@@ -1,23 +1,12 @@
import pytest
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings import store as settings_store
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Settings
class _FakeKvStore:
def __init__(self, data: dict | None = None) -> None:
self._data = data
def load(self, _key: str) -> dict:
if self._data is None:
raise KvKeyNotFoundError()
return self._data
raise KvKeyNotFoundError()
class _FakeCache:
@@ -31,140 +20,13 @@ class _FakeCache:
self._vals[key] = value.encode("utf-8")
def test_load_settings_uses_model_defaults_when_no_stored_value(
def test_load_settings_includes_user_file_max_upload_size_mb(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When no settings are stored (vector DB enabled), load_settings() should
resolve the default token threshold to 200."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", False)
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
assert (
settings.file_token_count_threshold_k
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
def test_load_settings_uses_high_token_default_when_vector_db_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When vector DB is disabled and no settings are stored, the token
threshold should default to 10000 (10M tokens)."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
assert (
settings.file_token_count_threshold_k
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
)
def test_load_settings_preserves_explicit_value_when_vector_db_disabled(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When vector DB is disabled but admin explicitly set a token threshold,
that value should be preserved (not overridden by the 10000 default)."""
stored = Settings(file_token_count_threshold_k=500).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.file_token_count_threshold_k == 500
def test_load_settings_preserves_zero_token_threshold(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 means 'no limit' and should be preserved."""
stored = Settings(file_token_count_threshold_k=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.file_token_count_threshold_k == 0
def test_load_settings_resolves_zero_upload_size_to_default(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 should be treated as unset and resolved to the default."""
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
def test_load_settings_clamps_upload_size_to_env_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the stored upload size exceeds MAX_ALLOWED_UPLOAD_SIZE_MB, it should
be clamped to the env-configured maximum."""
stored = Settings(user_file_max_upload_size_mb=500).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 250
def test_load_settings_preserves_upload_size_within_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the stored upload size is within MAX_ALLOWED_UPLOAD_SIZE_MB, it should
be preserved unchanged."""
stored = Settings(user_file_max_upload_size_mb=150).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 150
def test_load_settings_zero_upload_size_resolves_to_default(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 should be treated as unset and resolved to the default,
clamped to MAX_ALLOWED_UPLOAD_SIZE_MB."""
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 100)
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 100
def test_load_settings_default_clamped_to_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB exceeds MAX_ALLOWED_UPLOAD_SIZE_MB,
the effective default should be min(DEFAULT, MAX)."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 50)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 50
assert settings.user_file_max_upload_size_mb == 77

1
cli/.gitignore vendored
View File

@@ -1,4 +1,3 @@
onyx-cli
cli
onyx.cli
__pycache__

View File

@@ -63,31 +63,6 @@ onyx-cli agents
onyx-cli agents --json
```
### Serve over SSH
```shell
# Start a public SSH endpoint for the CLI TUI
onyx-cli serve --host 0.0.0.0 --port 2222
# Connect as a client
ssh your-host -p 2222
```
Clients can either:
- paste an API key at the login prompt, or
- skip the prompt by sending `ONYX_API_KEY` over SSH:
```shell
export ONYX_API_KEY=your-key
ssh -o SendEnv=ONYX_API_KEY your-host -p 2222
```
Useful hardening flags:
- `--idle-timeout` (default `15m`)
- `--max-session-timeout` (default `8h`)
- `--rate-limit-per-minute` (default `20`)
- `--rate-limit-burst` (default `40`)
## Commands
| Command | Description |
@@ -95,7 +70,6 @@ Useful hardening flags:
| `chat` | Launch the interactive chat TUI (default) |
| `ask` | Ask a one-shot question (non-interactive) |
| `agents` | List available agents |
| `serve` | Serve the interactive chat TUI over SSH |
| `configure` | Configure server URL and API key |
| `validate-config` | Validate configuration and test connection |

View File

@@ -1,17 +1,7 @@
// Package cmd implements Cobra CLI commands for the Onyx CLI.
package cmd
import (
"context"
"fmt"
"time"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/version"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
import "github.com/spf13/cobra"
// Version and Commit are set via ldflags at build time.
var (
@@ -26,69 +16,15 @@ func fullVersion() string {
return Version
}
func printVersion(cmd *cobra.Command) {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Client version: %s\n", fullVersion())
cfg := config.Load()
if !cfg.IsConfigured() {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: unknown (not configured)\n")
return
}
client := api.NewClient(cfg)
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
defer cancel()
log.Debug("fetching backend version from /api/version")
backendVersion, err := client.GetBackendVersion(ctx)
if err != nil {
log.WithError(err).Debug("could not fetch backend version")
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: unknown (could not reach server)\n")
return
}
if backendVersion == "" {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: unknown (empty response)\n")
return
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: %s\n", backendVersion)
min := version.MinServer()
if sv, ok := version.Parse(backendVersion); ok && sv.LessThan(min) {
log.Warnf("Server version %s is below minimum required %d.%d, please upgrade",
backendVersion, min.Major, min.Minor)
}
}
// Execute creates and runs the root command.
func Execute() error {
opts := struct {
Debug bool
}{}
rootCmd := &cobra.Command{
Use: "onyx-cli",
Short: "Terminal UI for chatting with Onyx",
Long: "Onyx CLI — a terminal interface for chatting with your Onyx agent.",
PersistentPreRun: func(cmd *cobra.Command, args []string) {
if opts.Debug {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.InfoLevel)
}
log.SetFormatter(&log.TextFormatter{
DisableTimestamp: true,
})
},
Use: "onyx-cli",
Short: "Terminal UI for chatting with Onyx",
Long: "Onyx CLI — a terminal interface for chatting with your Onyx agent.",
Version: fullVersion(),
}
rootCmd.PersistentFlags().BoolVar(&opts.Debug, "debug", false, "run in debug mode")
// Custom --version flag instead of Cobra's built-in (which only shows one version string)
var showVersion bool
rootCmd.Flags().BoolVarP(&showVersion, "version", "v", false, "Print client and server version information")
// Register subcommands
chatCmd := newChatCmd()
rootCmd.AddCommand(chatCmd)
@@ -96,16 +32,9 @@ func Execute() error {
rootCmd.AddCommand(newAgentsCmd())
rootCmd.AddCommand(newConfigureCmd())
rootCmd.AddCommand(newValidateConfigCmd())
rootCmd.AddCommand(newServeCmd())
// Default command is chat, but intercept --version first
rootCmd.RunE = func(cmd *cobra.Command, args []string) error {
if showVersion {
printVersion(cmd)
return nil
}
return chatCmd.RunE(cmd, args)
}
// Default command is chat
rootCmd.RunE = chatCmd.RunE
return rootCmd.Execute()
}

View File

@@ -1,450 +0,0 @@
package cmd
import (
"context"
"errors"
"fmt"
"net"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/log"
"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish"
"github.com/charmbracelet/wish/activeterm"
"github.com/charmbracelet/wish/bubbletea"
"github.com/charmbracelet/wish/logging"
"github.com/charmbracelet/wish/ratelimiter"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/tui"
"github.com/spf13/cobra"
"golang.org/x/time/rate"
)
const (
defaultServeIdleTimeout = 15 * time.Minute
defaultServeMaxSessionTimeout = 8 * time.Hour
defaultServeRateLimitPerMinute = 20
defaultServeRateLimitBurst = 40
defaultServeRateLimitCacheSize = 4096
maxAPIKeyLength = 512
apiKeyValidationTimeout = 15 * time.Second
maxAPIKeyRetries = 5
)
func sessionEnv(s ssh.Session, key string) string {
prefix := key + "="
for _, env := range s.Environ() {
if strings.HasPrefix(env, prefix) {
return env[len(prefix):]
}
}
return ""
}
func validateAPIKey(serverURL string, apiKey string) error {
trimmedKey := strings.TrimSpace(apiKey)
if len(trimmedKey) > maxAPIKeyLength {
return fmt.Errorf("API key is too long (max %d characters)", maxAPIKeyLength)
}
cfg := config.OnyxCliConfig{
ServerURL: serverURL,
APIKey: trimmedKey,
}
client := api.NewClient(cfg)
ctx, cancel := context.WithTimeout(context.Background(), apiKeyValidationTimeout)
defer cancel()
return client.TestConnection(ctx)
}
// --- auth prompt (bubbletea model) ---
type authState int
const (
authInput authState = iota
authValidating
authDone
)
type authValidatedMsg struct {
key string
err error
}
type authModel struct {
input textinput.Model
serverURL string
state authState
apiKey string // set on successful validation
errMsg string
retries int
aborted bool
}
func newAuthModel(serverURL, initialErr string) authModel {
ti := textinput.New()
ti.Prompt = " API Key: "
ti.EchoMode = textinput.EchoPassword
ti.EchoCharacter = '•'
ti.CharLimit = maxAPIKeyLength
ti.Width = 80
ti.Focus()
return authModel{
input: ti,
serverURL: serverURL,
errMsg: initialErr,
}
}
func (m authModel) Update(msg tea.Msg) (authModel, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.input.Width = max(msg.Width-14, 20) // account for prompt width
return m, nil
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC, tea.KeyCtrlD:
m.aborted = true
return m, nil
default:
if m.state == authValidating {
return m, nil
}
}
switch msg.Type {
case tea.KeyEnter:
key := strings.TrimSpace(m.input.Value())
if key == "" {
m.errMsg = "No key entered."
m.retries++
if m.retries >= maxAPIKeyRetries {
m.errMsg = "Too many failed attempts. Disconnecting."
m.aborted = true
return m, nil
}
m.input.SetValue("")
return m, nil
}
m.state = authValidating
m.errMsg = ""
serverURL := m.serverURL
return m, func() tea.Msg {
return authValidatedMsg{key: key, err: validateAPIKey(serverURL, key)}
}
}
case authValidatedMsg:
if msg.err != nil {
m.state = authInput
m.errMsg = msg.err.Error()
m.retries++
if m.retries >= maxAPIKeyRetries {
m.errMsg = "Too many failed attempts. Disconnecting."
m.aborted = true
return m, nil
}
m.input.SetValue("")
return m, m.input.Focus()
}
m.apiKey = msg.key
m.state = authDone
return m, nil
}
if m.state == authInput {
var cmd tea.Cmd
m.input, cmd = m.input.Update(msg)
return m, cmd
}
return m, nil
}
func (m authModel) View() string {
settingsURL := strings.TrimRight(m.serverURL, "/") + "/app/settings/accounts-access"
var b strings.Builder
b.WriteString("\n")
b.WriteString(" \x1b[1;35mOnyx CLI\x1b[0m\n")
b.WriteString(" \x1b[90m" + m.serverURL + "\x1b[0m\n")
b.WriteString("\n")
b.WriteString(" Generate an API key at:\n")
b.WriteString(" \x1b[4;34m" + settingsURL + "\x1b[0m\n")
b.WriteString("\n")
b.WriteString(" \x1b[90mTip: skip this prompt by passing your key via SSH:\x1b[0m\n")
b.WriteString(" \x1b[90m export ONYX_API_KEY=<key>\x1b[0m\n")
b.WriteString(" \x1b[90m ssh -o SendEnv=ONYX_API_KEY <host> -p <port>\x1b[0m\n")
b.WriteString("\n")
if m.errMsg != "" {
b.WriteString(" \x1b[1;31m" + m.errMsg + "\x1b[0m\n\n")
}
switch m.state {
case authDone:
b.WriteString(" \x1b[32mAuthenticated.\x1b[0m\n")
case authValidating:
b.WriteString(" \x1b[90mValidating…\x1b[0m\n")
default:
b.WriteString(m.input.View() + "\n")
}
return b.String()
}
// --- serve model (wraps auth → TUI in a single bubbletea program) ---
type serveModel struct {
auth authModel
tui tea.Model
authed bool
serverCfg config.OnyxCliConfig
width int
height int
}
func newServeModel(serverCfg config.OnyxCliConfig, initialErr string) serveModel {
return serveModel{
auth: newAuthModel(serverCfg.ServerURL, initialErr),
serverCfg: serverCfg,
}
}
func (m serveModel) Init() tea.Cmd {
return textinput.Blink
}
func (m serveModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if !m.authed {
if ws, ok := msg.(tea.WindowSizeMsg); ok {
m.width = ws.Width
m.height = ws.Height
}
var cmd tea.Cmd
m.auth, cmd = m.auth.Update(msg)
if m.auth.aborted {
return m, tea.Quit
}
if m.auth.apiKey != "" {
cfg := config.OnyxCliConfig{
ServerURL: m.serverCfg.ServerURL,
APIKey: m.auth.apiKey,
DefaultAgentID: m.serverCfg.DefaultAgentID,
}
m.tui = tui.NewModel(cfg)
m.authed = true
w, h := m.width, m.height
return m, tea.Batch(
tea.EnterAltScreen,
tea.EnableMouseCellMotion,
m.tui.Init(),
func() tea.Msg { return tea.WindowSizeMsg{Width: w, Height: h} },
)
}
return m, cmd
}
var cmd tea.Cmd
m.tui, cmd = m.tui.Update(msg)
return m, cmd
}
func (m serveModel) View() string {
if !m.authed {
return m.auth.View()
}
return m.tui.View()
}
// --- serve command ---
func newServeCmd() *cobra.Command {
var (
host string
port int
keyPath string
idleTimeout time.Duration
maxSessionTimeout time.Duration
rateLimitPerMin int
rateLimitBurst int
rateLimitCache int
)
cmd := &cobra.Command{
Use: "serve",
Short: "Serve the Onyx TUI over SSH",
Long: `Start an SSH server that presents the interactive Onyx chat TUI to
connecting clients. Each SSH session gets its own independent TUI instance.
Clients are prompted for their Onyx API key on connect. The key can also be
provided via the ONYX_API_KEY environment variable to skip the prompt:
ssh -o SendEnv=ONYX_API_KEY host -p port
The server URL is taken from the server operator's config. The server
auto-generates an Ed25519 host key on first run if the key file does not
already exist. The host key path can also be set via the ONYX_SSH_HOST_KEY
environment variable (the --host-key flag takes precedence).
Example:
onyx-cli serve --port 2222
ssh localhost -p 2222`,
RunE: func(cmd *cobra.Command, args []string) error {
serverCfg := config.Load()
if serverCfg.ServerURL == "" {
return fmt.Errorf("server URL is not configured; run 'onyx-cli configure' first")
}
if !cmd.Flags().Changed("host-key") {
if v := os.Getenv(config.EnvSSHHostKey); v != "" {
keyPath = v
}
}
if rateLimitPerMin <= 0 {
return fmt.Errorf("--rate-limit-per-minute must be > 0")
}
if rateLimitBurst <= 0 {
return fmt.Errorf("--rate-limit-burst must be > 0")
}
if rateLimitCache <= 0 {
return fmt.Errorf("--rate-limit-cache must be > 0")
}
addr := net.JoinHostPort(host, fmt.Sprintf("%d", port))
connectionLimiter := ratelimiter.NewRateLimiter(
rate.Limit(float64(rateLimitPerMin)/60.0),
rateLimitBurst,
rateLimitCache,
)
handler := func(s ssh.Session) (tea.Model, []tea.ProgramOption) {
apiKey := strings.TrimSpace(sessionEnv(s, config.EnvAPIKey))
var envErr string
if apiKey != "" {
if err := validateAPIKey(serverCfg.ServerURL, apiKey); err != nil {
envErr = fmt.Sprintf("ONYX_API_KEY from SSH environment is invalid: %s", err.Error())
apiKey = ""
}
}
if apiKey != "" {
// Env key is valid — go straight to the TUI.
cfg := config.OnyxCliConfig{
ServerURL: serverCfg.ServerURL,
APIKey: apiKey,
DefaultAgentID: serverCfg.DefaultAgentID,
}
return tui.NewModel(cfg), []tea.ProgramOption{
tea.WithAltScreen(),
tea.WithMouseCellMotion(),
}
}
// No valid env key — show auth prompt, then transition
// to the TUI within the same bubbletea program.
return newServeModel(serverCfg, envErr), []tea.ProgramOption{
tea.WithMouseCellMotion(),
}
}
serverOptions := []ssh.Option{
wish.WithAddress(addr),
wish.WithHostKeyPath(keyPath),
wish.WithMiddleware(
bubbletea.Middleware(handler),
activeterm.Middleware(),
ratelimiter.Middleware(connectionLimiter),
logging.Middleware(),
),
}
if idleTimeout > 0 {
serverOptions = append(serverOptions, wish.WithIdleTimeout(idleTimeout))
}
if maxSessionTimeout > 0 {
serverOptions = append(serverOptions, wish.WithMaxTimeout(maxSessionTimeout))
}
s, err := wish.NewServer(serverOptions...)
if err != nil {
return fmt.Errorf("could not create SSH server: %w", err)
}
done := make(chan os.Signal, 1)
signal.Notify(done, os.Interrupt, syscall.SIGTERM)
log.Info("Starting Onyx SSH server", "addr", addr)
log.Info("Connect with", "cmd", fmt.Sprintf("ssh %s -p %d", host, port))
errCh := make(chan error, 1)
go func() {
if err := s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
log.Error("SSH server failed", "error", err)
errCh <- err
}
}()
var serverErr error
select {
case <-done:
case serverErr = <-errCh:
}
signal.Stop(done)
log.Info("Shutting down SSH server")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if shutdownErr := s.Shutdown(ctx); shutdownErr != nil {
return errors.Join(serverErr, shutdownErr)
}
return serverErr
},
}
cmd.Flags().StringVar(&host, "host", "localhost", "Host address to bind to")
cmd.Flags().IntVarP(&port, "port", "p", 2222, "Port to listen on")
cmd.Flags().StringVar(&keyPath, "host-key", filepath.Join(config.ConfigDir(), "host_ed25519"),
"Path to SSH host key (auto-generated if missing)")
cmd.Flags().DurationVar(
&idleTimeout,
"idle-timeout",
defaultServeIdleTimeout,
"Disconnect idle clients after this duration (set 0 to disable)",
)
cmd.Flags().DurationVar(
&maxSessionTimeout,
"max-session-timeout",
defaultServeMaxSessionTimeout,
"Maximum lifetime of a client session (set 0 to disable)",
)
cmd.Flags().IntVar(
&rateLimitPerMin,
"rate-limit-per-minute",
defaultServeRateLimitPerMinute,
"Per-IP connection rate limit (new sessions per minute)",
)
cmd.Flags().IntVar(
&rateLimitBurst,
"rate-limit-burst",
defaultServeRateLimitBurst,
"Per-IP burst limit for connection attempts",
)
cmd.Flags().IntVar(
&rateLimitCache,
"rate-limit-cache",
defaultServeRateLimitCacheSize,
"Maximum number of IP limiter entries tracked in memory",
)
return cmd
}

View File

@@ -1,14 +1,10 @@
package cmd
import (
"context"
"fmt"
"time"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/version"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
@@ -39,25 +35,6 @@ func newValidateConfigCmd() *cobra.Command {
}
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Status: connected and authenticated")
// Check backend version compatibility
vCtx, vCancel := context.WithTimeout(cmd.Context(), 5*time.Second)
defer vCancel()
backendVersion, err := client.GetBackendVersion(vCtx)
if err != nil {
log.WithError(err).Debug("could not fetch backend version")
} else if backendVersion == "" {
log.Debug("server returned empty version string")
} else {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Version: %s\n", backendVersion)
min := version.MinServer()
if sv, ok := version.Parse(backendVersion); ok && sv.LessThan(min) {
log.Warnf("Server version %s is below minimum required %d.%d, please upgrade",
backendVersion, min.Major, min.Minor)
}
}
return nil
},
}

View File

@@ -1,63 +1,45 @@
module github.com/onyx-dot-app/onyx/cli
go 1.26.1
go 1.26.0
require (
github.com/charmbracelet/bubbles v1.0.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/glamour v1.0.0
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834
github.com/charmbracelet/log v1.0.0
github.com/charmbracelet/ssh v0.0.0-20250826160808-ebfa259c7309
github.com/charmbracelet/wish v1.4.7
github.com/sirupsen/logrus v1.9.4
github.com/spf13/cobra v1.10.2
golang.org/x/term v0.41.0
golang.org/x/text v0.35.0
golang.org/x/time v0.15.0
github.com/charmbracelet/bubbles v0.20.0
github.com/charmbracelet/bubbletea v1.3.4
github.com/charmbracelet/glamour v0.8.0
github.com/charmbracelet/lipgloss v1.1.0
github.com/spf13/cobra v1.9.1
golang.org/x/term v0.30.0
golang.org/x/text v0.34.0
)
require (
github.com/alecthomas/chroma/v2 v2.23.1 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/alecthomas/chroma/v2 v2.14.0 // indirect
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/charmbracelet/colorprofile v0.4.3 // indirect
github.com/charmbracelet/keygen v0.5.4 // indirect
github.com/charmbracelet/x/ansi v0.11.6 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/conpty v0.2.0 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca // indirect
github.com/charmbracelet/x/input v0.3.7 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/charmbracelet/x/termios v0.1.1 // indirect
github.com/charmbracelet/x/windows v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.11.0 // indirect
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
github.com/creack/pty v1.1.24 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.8.0 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/go-logfmt/logfmt v0.6.1 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.21 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yuin/goldmark v1.8.2 // indirect
github.com/yuin/goldmark-emoji v1.0.6 // indirect
golang.org/x/crypto v0.49.0 // indirect
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/sys v0.42.0 // indirect
github.com/yuin/goldmark v1.7.4 // indirect
github.com/yuin/goldmark-emoji v1.0.3 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.31.0 // indirect
)

View File

@@ -1,89 +1,55 @@
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/chroma/v2 v2.23.1 h1:nv2AVZdTyClGbVQkIzlDm/rnhk1E9bU9nXwmZ/Vk/iY=
github.com/alecthomas/chroma/v2 v2.23.1/go.mod h1:NqVhfBR0lte5Ouh3DcthuUCTUpDC9cxBOfyMbMQPs3o=
github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE=
github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E=
github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3vj1nolY=
github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E=
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.4.3 h1:QPa1IWkYI+AOB+fE+mg/5/4HRMZcaXex9t5KX76i20Q=
github.com/charmbracelet/colorprofile v0.4.3/go.mod h1:/zT4BhpD5aGFpqQQqw7a+VtHCzu+zrQtt1zhMt9mR4Q=
github.com/charmbracelet/glamour v1.0.0 h1:AWMLOVFHTsysl4WV8T8QgkQ0s/ZNZo7CiE4WKhk8l08=
github.com/charmbracelet/glamour v1.0.0/go.mod h1:DSdohgOBkMr2ZQNhw4LZxSGpx3SvpeujNoXrQyH2hxo=
github.com/charmbracelet/keygen v0.5.4 h1:XQYgf6UEaTGgQSSmiPpIQ78WfseNQp4Pz8N/c1OsrdA=
github.com/charmbracelet/keygen v0.5.4/go.mod h1:t4oBRr41bvK7FaJsAaAQhhkUuHslzFXVjOBwA55CZNM=
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA=
github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdRc4=
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
github.com/charmbracelet/ssh v0.0.0-20250826160808-ebfa259c7309 h1:dCVbCRRtg9+tsfiTXTp0WupDlHruAXyp+YoxGVofHHc=
github.com/charmbracelet/ssh v0.0.0-20250826160808-ebfa259c7309/go.mod h1:R9cISUs5kAH4Cq/rguNbSwcR+slE5Dfm8FEs//uoIGE=
github.com/charmbracelet/wish v1.4.7 h1:O+jdLac3s6GaqkOHHSwezejNK04vl6VjO1A+hl8J8Yc=
github.com/charmbracelet/wish v1.4.7/go.mod h1:OBZ8vC62JC5cvbxJLh+bIWtG7Ctmct+ewziuUWK+G14=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/conpty v0.2.0 h1:eKtA2hm34qNfgJCDp/M6Dc0gLy7e07YEK4qAdNGOvVY=
github.com/charmbracelet/x/conpty v0.2.0/go.mod h1:fexgUnVrZgw8scD49f6VSi0Ggj9GWYIrpedRthAwW/8=
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ=
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca h1:QQoyQLgUzojMNWHVHToN6d9qTvT0KWtxUKIRPx/Ox5o=
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
github.com/charmbracelet/x/input v0.3.7 h1:UzVbkt1vgM9dBQ+K+uRolBlN6IF2oLchmPKKo/aucXo=
github.com/charmbracelet/x/input v0.3.7/go.mod h1:ZSS9Cia6Cycf2T6ToKIOxeTBTDwl25AGwArJuGaOBH8=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY=
github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo=
github.com/charmbracelet/x/windows v0.2.2 h1:IofanmuvaxnKHuV04sC0eBy/smG6kIKrWG2/jYn2GuM=
github.com/charmbracelet/x/windows v0.2.2/go.mod h1:/8XtdKZzedat74NQFn0NGlGL4soHB0YQZrETF96h75k=
github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8=
github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0=
github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk=
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=
github.com/charmbracelet/bubbletea v1.3.4/go.mod h1:dtcUCyCGEX3g9tosuYiut3MXgY/Jsv9nKVdibKKRRXo=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/glamour v0.8.0 h1:tPrjL3aRcQbn++7t18wOpgLyl8wrOHUEDS7IZ68QtZs=
github.com/charmbracelet/glamour v0.8.0/go.mod h1:ViRgmKkf3u5S7uakt2czJ272WSg2ZenlYEZXT2x7Bjw=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/go-logfmt/logfmt v0.6.1 h1:4hvbpePJKnIzH1B+8OR/JPbTx37NktoI9LE2QZBBkvE=
github.com/go-logfmt/logfmt v0.6.1/go.mod h1:EV2pOAQoZaT1ZXZbqDl5hrymndi4SY9ED9/z6CO0XAk=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
@@ -94,47 +60,35 @@ github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs=
github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 h1:jiDhWWeC7jfWqR9c/uplMOqJ0sbNlNWv0UkzE0vX1MA=
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90/go.mod h1:xE1HEv6b+1SCZ5/uscMRjUBKtIxworgEcEi+/n9NQDQ=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg=
github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark-emoji v1.0.3 h1:aLRkLHOuBR2czCY4R8olwMjID+tENfhyFDMCRhbIQY4=
github.com/yuin/goldmark-emoji v1.0.3/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -34,7 +34,8 @@ class CustomBuildHook(BuildHookInterface):
# Build the Go binary (always rebuild to ensure correct version injection)
if not os.path.exists(binary_name):
print(f"Building Go binary '{binary_name}'...")
ldflags = f"-X main.version={tag} -X main.commit={commit} -s -w"
pkg = "github.com/onyx-dot-app/onyx/cli/cmd"
ldflags = f"-X {pkg}.version={tag}" f" -X {pkg}.commit={commit}" " -s -w"
subprocess.check_call( # noqa: S603
["go", "build", f"-ldflags={ldflags}", "-o", binary_name],
)

View File

@@ -270,17 +270,6 @@ func (c *Client) UploadFile(ctx context.Context, filePath string) (*models.FileD
}, nil
}
// GetBackendVersion fetches the backend version string from /api/version.
func (c *Client) GetBackendVersion(ctx context.Context) (string, error) {
var resp struct {
BackendVersion string `json:"backend_version"`
}
if err := c.doJSON(ctx, "GET", "/api/version", nil, &resp); err != nil {
return "", err
}
return resp.BackendVersion, nil
}
// StopChatSession sends a stop signal for a streaming session (best-effort).
func (c *Client) StopChatSession(ctx context.Context, sessionID string) {
req, err := c.newRequest(ctx, "POST", "/api/chat/stop-chat-session/"+sessionID, nil)

View File

@@ -9,10 +9,9 @@ import (
)
const (
EnvServerURL = "ONYX_SERVER_URL"
EnvAPIKey = "ONYX_API_KEY"
EnvServerURL = "ONYX_SERVER_URL"
EnvAPIKey = "ONYX_API_KEY"
EnvAgentID = "ONYX_PERSONA_ID"
EnvSSHHostKey = "ONYX_SSH_HOST_KEY"
)
// OnyxCliConfig holds the CLI configuration.
@@ -36,8 +35,8 @@ func (c OnyxCliConfig) IsConfigured() bool {
return c.APIKey != ""
}
// ConfigDir returns ~/.config/onyx-cli
func ConfigDir() string {
// configDir returns ~/.config/onyx-cli
func configDir() string {
if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
return filepath.Join(xdg, "onyx-cli")
}
@@ -50,7 +49,7 @@ func ConfigDir() string {
// ConfigFilePath returns the full path to the config file.
func ConfigFilePath() string {
return filepath.Join(ConfigDir(), "config.json")
return filepath.Join(configDir(), "config.json")
}
// ConfigExists checks if the config file exists on disk.
@@ -88,7 +87,7 @@ func Load() OnyxCliConfig {
// Save writes the config to disk, creating parent directories if needed.
func Save(cfg OnyxCliConfig) error {
dir := ConfigDir()
dir := configDir()
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}

View File

@@ -1,58 +0,0 @@
// Package version provides semver parsing and compatibility checks.
package version
import (
"strconv"
"strings"
)
// Semver holds parsed semantic version components.
type Semver struct {
Major int
Minor int
Patch int
}
// minServer is the minimum backend version required by this CLI.
var minServer = Semver{Major: 3, Minor: 0, Patch: 0}
// MinServer returns the minimum backend version required by this CLI.
func MinServer() Semver { return minServer }
// Parse extracts major, minor, patch from a version string like "3.1.2" or "v3.1.2".
// Returns ok=false if the string is not valid semver.
func Parse(v string) (Semver, bool) {
v = strings.TrimPrefix(v, "v")
// Strip any pre-release suffix (e.g. "-beta.1") and build metadata (e.g. "+build.1")
if idx := strings.IndexAny(v, "-+"); idx != -1 {
v = v[:idx]
}
parts := strings.SplitN(v, ".", 3)
if len(parts) != 3 {
return Semver{}, false
}
major, err := strconv.Atoi(parts[0])
if err != nil {
return Semver{}, false
}
minor, err := strconv.Atoi(parts[1])
if err != nil {
return Semver{}, false
}
patch, err := strconv.Atoi(parts[2])
if err != nil {
return Semver{}, false
}
return Semver{Major: major, Minor: minor, Patch: patch}, true
}
// LessThan reports whether s is strictly less than other.
func (s Semver) LessThan(other Semver) bool {
if s.Major != other.Major {
return s.Major < other.Major
}
if s.Minor != other.Minor {
return s.Minor < other.Minor
}
return s.Patch < other.Patch
}

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["hatchling==1.29.0", "go-bin~=1.26.1", "manygo==0.2.0"]
requires = ["hatchling", "go-bin~=1.24.11", "manygo"]
build-backend = "hatchling.build"
[project]

View File

@@ -39,22 +39,6 @@ server {
# Conditionally include MCP location configuration
include /etc/nginx/conf.d/mcp.conf.inc;
location ~ ^/scim(/.*)?$ {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_buffering off;
proxy_redirect off;
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT}s;
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}s;
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}s;
proxy_pass http://api_server;
}
# Match both /api/* and /openapi.json in a single rule
location ~ ^/(api|openapi.json)(/.*)?$ {
# Rewrite /api prefixed matched paths

View File

@@ -39,20 +39,6 @@ server {
# Conditionally include MCP location configuration
include /etc/nginx/conf.d/mcp.conf.inc;
location ~ ^/scim(/.*)?$ {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_buffering off;
proxy_redirect off;
proxy_pass http://api_server;
}
# Match both /api/* and /openapi.json in a single rule
location ~ ^/(api|openapi.json)(/.*)?$ {
# Rewrite /api prefixed matched paths

View File

@@ -39,23 +39,6 @@ server {
# Conditionally include MCP location configuration
include /etc/nginx/conf.d/mcp.conf.inc;
location ~ ^/scim(/.*)?$ {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_buffering off;
proxy_redirect off;
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT}s;
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}s;
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}s;
proxy_pass http://api_server;
}
# Match both /api/* and /openapi.json in a single rule
location ~ ^/(api|openapi.json)(/.*)?$ {
# Rewrite /api prefixed matched paths

View File

@@ -66,3 +66,10 @@ DB_READONLY_PASSWORD=password
# Show extra/uncommon connectors
# See https://docs.onyx.app/admins/connectors/overview for a full list of connectors
SHOW_EXTRA_CONNECTORS=False
# User File Upload Configuration
# Skip the token count threshold check (100,000 tokens) for uploaded files
# For self-hosted: set to true to skip for all users
#SKIP_USERFILE_THRESHOLD=false
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
#SKIP_USERFILE_THRESHOLD_TENANT_IDS=

View File

@@ -35,10 +35,6 @@ USER_AUTH_SECRET=""
## Chat Configuration
# HARD_DELETE_CHATS=
# MAX_ALLOWED_UPLOAD_SIZE_MB=250
# Default per-user upload size limit (MB) when no admin value is set.
# Automatically clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at runtime.
# DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=100
## Base URL for redirects
# WEB_DOMAIN=
@@ -46,6 +42,13 @@ USER_AUTH_SECRET=""
## Enterprise Features, requires a paid plan and licenses
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=false
## User File Upload Configuration
# Skip the token count threshold check (100,000 tokens) for uploaded files
# For self-hosted: set to true to skip for all users
# SKIP_USERFILE_THRESHOLD=false
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
# SKIP_USERFILE_THRESHOLD_TENANT_IDS=
################################################################################
## SERVICES CONFIGURATIONS

View File

@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.4.38
version: 0.4.36
appVersion: latest
annotations:
category: Productivity

View File

@@ -1,26 +0,0 @@
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_docfetching.replicaCount) 0) }}
apiVersion: v1
kind: Service
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-docfetching-metrics
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- if .Values.celery_worker_docfetching.deploymentLabels }}
{{- toYaml .Values.celery_worker_docfetching.deploymentLabels | nindent 4 }}
{{- end }}
metrics: "true"
spec:
type: ClusterIP
ports:
- port: 9092
targetPort: metrics
protocol: TCP
name: metrics
selector:
{{- include "onyx.selectorLabels" . | nindent 4 }}
{{- if .Values.celery_worker_docfetching.deploymentLabels }}
{{- toYaml .Values.celery_worker_docfetching.deploymentLabels | nindent 4 }}
{{- end }}
{{- end }}

View File

@@ -73,10 +73,6 @@ spec:
"-Q",
"connector_doc_fetching",
]
ports:
- name: metrics
containerPort: 9092
protocol: TCP
resources:
{{- toYaml .Values.celery_worker_docfetching.resources | nindent 12 }}
envFrom:

View File

@@ -1,26 +0,0 @@
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_docprocessing.replicaCount) 0) }}
apiVersion: v1
kind: Service
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-docprocessing-metrics
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- if .Values.celery_worker_docprocessing.deploymentLabels }}
{{- toYaml .Values.celery_worker_docprocessing.deploymentLabels | nindent 4 }}
{{- end }}
metrics: "true"
spec:
type: ClusterIP
ports:
- port: 9093
targetPort: metrics
protocol: TCP
name: metrics
selector:
{{- include "onyx.selectorLabels" . | nindent 4 }}
{{- if .Values.celery_worker_docprocessing.deploymentLabels }}
{{- toYaml .Values.celery_worker_docprocessing.deploymentLabels | nindent 4 }}
{{- end }}
{{- end }}

View File

@@ -73,10 +73,6 @@ spec:
"-Q",
"docprocessing",
]
ports:
- name: metrics
containerPort: 9093
protocol: TCP
resources:
{{- toYaml .Values.celery_worker_docprocessing.resources | nindent 12 }}
envFrom:

View File

@@ -1,26 +0,0 @@
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_monitoring.replicaCount) 0) }}
apiVersion: v1
kind: Service
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-monitoring-metrics
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- if .Values.celery_worker_monitoring.deploymentLabels }}
{{- toYaml .Values.celery_worker_monitoring.deploymentLabels | nindent 4 }}
{{- end }}
metrics: "true"
spec:
type: ClusterIP
ports:
- port: 9096
targetPort: metrics
protocol: TCP
name: metrics
selector:
{{- include "onyx.selectorLabels" . | nindent 4 }}
{{- if .Values.celery_worker_monitoring.deploymentLabels }}
{{- toYaml .Values.celery_worker_monitoring.deploymentLabels | nindent 4 }}
{{- end }}
{{- end }}

View File

@@ -70,10 +70,6 @@ spec:
"-Q",
"monitoring",
]
ports:
- name: metrics
containerPort: 9096
protocol: TCP
resources:
{{- toYaml .Values.celery_worker_monitoring.resources | nindent 12 }}
envFrom:

View File

@@ -63,22 +63,6 @@ data:
}
{{- end }}
location ~ ^/scim(/.*)?$ {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_buffering off;
proxy_redirect off;
# timeout settings
proxy_connect_timeout {{ .Values.nginx.timeouts.connect }}s;
proxy_send_timeout {{ .Values.nginx.timeouts.send }}s;
proxy_read_timeout {{ .Values.nginx.timeouts.read }}s;
proxy_pass http://api_server;
}
location ~ ^/(api|openapi\.json)(/.*)?$ {
rewrite ^/api(/.*)$ $1 break;
proxy_set_header X-Real-IP $remote_addr;

View File

@@ -282,7 +282,7 @@ nginx:
# The ingress-nginx subchart doesn't auto-detect our custom ConfigMap changes.
# Workaround: Helm upgrade will restart if the following annotation value changes.
podAnnotations:
onyx.app/nginx-config-version: "3"
onyx.app/nginx-config-version: "2"
# Propagate DOMAIN into nginx so server_name continues to use the same env var
extraEnvs:
@@ -1285,5 +1285,11 @@ configMap:
DOMAIN: "localhost"
# Chat Configs
HARD_DELETE_CHATS: ""
MAX_ALLOWED_UPLOAD_SIZE_MB: ""
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB: ""
# User File Upload Configuration
# Skip the token count threshold check (100,000 tokens) for uploaded files
# For self-hosted: set to true to skip for all users
SKIP_USERFILE_THRESHOLD: ""
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
SKIP_USERFILE_THRESHOLD_TENANT_IDS: ""
# Maximum user upload file size in MB for chat/projects uploads
USER_FILE_MAX_UPLOAD_SIZE_MB: ""

View File

@@ -66,10 +66,14 @@ backend = [
"jsonref==1.1.0",
"kubernetes==31.0.0",
"trafilatura==1.12.2",
"langchain-core==1.2.22",
"langchain-core==1.2.11",
"lazy_imports==1.0.1",
"lxml==5.3.0",
"Mako==1.2.4",
# NOTE: Do not update without understanding the patching behavior in
# get_markitdown_converter in
# backend/onyx/file_processing/extract_file_text.py and what impacts
# updating might have on this behavior.
"markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2",
"mcp[cli]==1.26.0",
"msal==1.34.0",
@@ -144,7 +148,7 @@ dev = [
"matplotlib==3.10.8",
"mypy-extensions==1.0.0",
"mypy==1.13.0",
"onyx-devtools==0.7.2",
"onyx-devtools==0.7.1",
"openapi-generator-cli==7.17.0",
"pandas-stubs~=2.3.3",
"pre-commit==3.2.2",

View File

@@ -28,7 +28,7 @@ Some commands require external tools to be installed and configured:
- **uv** - Required for `backend` commands
- Install from [docs.astral.sh/uv](https://docs.astral.sh/uv/)
- **GitHub CLI** (`gh`) - Required for `run-ci`, `cherry-pick`, and `trace` commands
- **GitHub CLI** (`gh`) - Required for `run-ci` and `cherry-pick` commands
- Install from [cli.github.com](https://cli.github.com/)
- Authenticate with `gh auth login`
@@ -412,62 +412,6 @@ The `compare` subcommand writes a `summary.json` alongside the report with aggre
counts (changed, added, removed, unchanged). The HTML report is only generated when
visual differences are detected.
### `trace` - View Playwright Traces from CI
Download Playwright trace artifacts from a GitHub Actions run and open them
with `playwright show-trace`. Traces are only generated for failing tests
(`retain-on-failure`).
```shell
ods trace [run-id-or-url]
```
The run can be specified as a numeric run ID, a full GitHub Actions URL, or
omitted to find the latest Playwright run for the current branch.
**Flags:**
| Flag | Default | Description |
|------|---------|-------------|
| `--branch`, `-b` | | Find latest run for this branch |
| `--pr` | | Find latest run for this PR number |
| `--project`, `-p` | | Filter to a specific project (`admin`, `exclusive`, `lite`) |
| `--list`, `-l` | `false` | List available traces without opening |
| `--no-open` | `false` | Download traces but don't open them |
When multiple traces are found, an interactive picker lets you select which
traces to open. Use arrow keys or `j`/`k` to navigate, `space` to toggle,
`a` to select all, `n` to deselect all, and `enter` to open. Falls back to a
plain-text prompt when no TTY is available.
Downloaded artifacts are cached in `/tmp/ods-traces/<run-id>/` so repeated
invocations for the same run are instant.
**Examples:**
```shell
# Latest run for the current branch
ods trace
# Specific run ID
ods trace 12345678
# Full GitHub Actions URL
ods trace https://github.com/onyx-dot-app/onyx/actions/runs/12345678
# Latest run for a PR
ods trace --pr 9500
# Latest run for a specific branch
ods trace --branch main
# Only download admin project traces
ods trace --project admin
# List traces without opening
ods trace --list
```
### Testing Changes Locally (Dry Run)
Both `run-ci` and `cherry-pick` support `--dry-run` to test without making remote changes:

View File

@@ -55,7 +55,6 @@ func NewRootCommand() *cobra.Command {
cmd.AddCommand(NewWebCommand())
cmd.AddCommand(NewLatestStableTagCommand())
cmd.AddCommand(NewWhoisCommand())
cmd.AddCommand(NewTraceCommand())
return cmd
}

View File

@@ -1,556 +0,0 @@
package cmd
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/onyx-dot-app/onyx/tools/ods/internal/git"
"github.com/onyx-dot-app/onyx/tools/ods/internal/paths"
"github.com/onyx-dot-app/onyx/tools/ods/internal/tui"
)
const playwrightWorkflow = "Run Playwright Tests"
// TraceOptions holds options for the trace command
type TraceOptions struct {
Branch string
PR string
Project string
List bool
NoOpen bool
}
// traceInfo describes a single trace.zip found in the downloaded artifacts.
type traceInfo struct {
Path string // absolute path to trace.zip
Project string // project group extracted from artifact dir (e.g. "admin", "admin-shard-1")
TestDir string // test directory name (human-readable-ish)
}
// NewTraceCommand creates a new trace command
func NewTraceCommand() *cobra.Command {
opts := &TraceOptions{}
cmd := &cobra.Command{
Use: "trace [run-id-or-url]",
Short: "Download and view Playwright traces from GitHub Actions",
Long: `Download Playwright trace artifacts from a GitHub Actions run and open them
with 'playwright show-trace'.
The run can be specified as:
- A GitHub Actions run ID (numeric)
- A full GitHub Actions run URL
- Omitted, to find the latest Playwright run for the current branch
You can also look up the latest run by branch name or PR number.
Examples:
ods trace # latest run for current branch
ods trace 12345678 # specific run ID
ods trace https://github.com/onyx-dot-app/onyx/actions/runs/12345678
ods trace --pr 9500 # latest run for PR #9500
ods trace --branch main # latest run for main branch
ods trace --project admin # only download admin project traces
ods trace --list # list available traces without opening`,
Args: cobra.MaximumNArgs(1),
Run: func(cmd *cobra.Command, args []string) {
runTrace(args, opts)
},
}
cmd.Flags().StringVarP(&opts.Branch, "branch", "b", "", "Find latest run for this branch")
cmd.Flags().StringVar(&opts.PR, "pr", "", "Find latest run for this PR number")
cmd.Flags().StringVarP(&opts.Project, "project", "p", "", "Filter to a specific project (admin, exclusive, lite)")
cmd.Flags().BoolVarP(&opts.List, "list", "l", false, "List available traces without opening")
cmd.Flags().BoolVar(&opts.NoOpen, "no-open", false, "Download traces but don't open them")
return cmd
}
// ghRun represents a GitHub Actions workflow run from `gh run list`
type ghRun struct {
DatabaseID int64 `json:"databaseId"`
Status string `json:"status"`
Conclusion string `json:"conclusion"`
HeadBranch string `json:"headBranch"`
URL string `json:"url"`
}
func runTrace(args []string, opts *TraceOptions) {
git.CheckGitHubCLI()
runID, err := resolveRunID(args, opts)
if err != nil {
log.Fatalf("Failed to resolve run: %v", err)
}
log.Infof("Using run ID: %s", runID)
destDir, err := downloadTraceArtifacts(runID, opts.Project)
if err != nil {
log.Fatalf("Failed to download artifacts: %v", err)
}
traces, err := findTraceInfos(destDir, runID)
if err != nil {
log.Fatalf("Failed to find traces: %v", err)
}
if len(traces) == 0 {
log.Info("No trace files found in the downloaded artifacts.")
log.Info("Traces are only generated for failing tests (retain-on-failure).")
return
}
projects := groupByProject(traces)
if opts.List || opts.NoOpen {
printTraceList(traces, projects)
fmt.Printf("\nTraces downloaded to: %s\n", destDir)
return
}
if len(traces) == 1 {
openTraces(traces)
return
}
for {
selected := selectTraces(traces, projects)
if len(selected) == 0 {
return
}
openTraces(selected)
}
}
// resolveRunID determines the run ID from the provided arguments and options.
func resolveRunID(args []string, opts *TraceOptions) (string, error) {
if len(args) == 1 {
return parseRunIDFromArg(args[0])
}
if opts.PR != "" {
return findLatestRunForPR(opts.PR)
}
branch := opts.Branch
if branch == "" {
var err error
branch, err = git.GetCurrentBranch()
if err != nil {
return "", fmt.Errorf("failed to get current branch: %w", err)
}
if branch == "" {
return "", fmt.Errorf("detached HEAD; specify a --branch, --pr, or run ID")
}
log.Infof("Using current branch: %s", branch)
}
return findLatestRunForBranch(branch)
}
var runURLPattern = regexp.MustCompile(`/actions/runs/(\d+)`)
// parseRunIDFromArg extracts a run ID from either a numeric string or a full URL.
func parseRunIDFromArg(arg string) (string, error) {
if matched, _ := regexp.MatchString(`^\d+$`, arg); matched {
return arg, nil
}
matches := runURLPattern.FindStringSubmatch(arg)
if matches != nil {
return matches[1], nil
}
return "", fmt.Errorf("could not parse run ID from %q; expected a numeric ID or GitHub Actions URL", arg)
}
// findLatestRunForBranch finds the most recent Playwright workflow run for a branch.
func findLatestRunForBranch(branch string) (string, error) {
log.Infof("Looking up latest Playwright run for branch: %s", branch)
cmd := exec.Command("gh", "run", "list",
"--workflow", playwrightWorkflow,
"--branch", branch,
"--limit", "1",
"--json", "databaseId,status,conclusion,headBranch,url",
)
output, err := cmd.Output()
if err != nil {
return "", ghError(err, "gh run list failed")
}
var runs []ghRun
if err := json.Unmarshal(output, &runs); err != nil {
return "", fmt.Errorf("failed to parse run list: %w", err)
}
if len(runs) == 0 {
return "", fmt.Errorf("no Playwright runs found for branch %q", branch)
}
run := runs[0]
log.Infof("Found run: %s (status: %s, conclusion: %s)", run.URL, run.Status, run.Conclusion)
return fmt.Sprintf("%d", run.DatabaseID), nil
}
// findLatestRunForPR finds the most recent Playwright workflow run for a PR.
func findLatestRunForPR(prNumber string) (string, error) {
log.Infof("Looking up branch for PR #%s", prNumber)
cmd := exec.Command("gh", "pr", "view", prNumber,
"--json", "headRefName",
"--jq", ".headRefName",
)
output, err := cmd.Output()
if err != nil {
return "", ghError(err, "gh pr view failed")
}
branch := strings.TrimSpace(string(output))
if branch == "" {
return "", fmt.Errorf("could not determine branch for PR #%s", prNumber)
}
log.Infof("PR #%s is on branch: %s", prNumber, branch)
return findLatestRunForBranch(branch)
}
// downloadTraceArtifacts downloads playwright trace artifacts for a run.
// Returns the path to the download directory.
func downloadTraceArtifacts(runID string, project string) (string, error) {
cacheKey := runID
if project != "" {
cacheKey = runID + "-" + project
}
destDir := filepath.Join(os.TempDir(), "ods-traces", cacheKey)
// Reuse a previous download if traces exist
if info, err := os.Stat(destDir); err == nil && info.IsDir() {
traces, _ := findTraces(destDir)
if len(traces) > 0 {
log.Infof("Using cached download at %s", destDir)
return destDir, nil
}
_ = os.RemoveAll(destDir)
}
if err := os.MkdirAll(destDir, 0755); err != nil {
return "", fmt.Errorf("failed to create directory %s: %w", destDir, err)
}
ghArgs := []string{"run", "download", runID, "--dir", destDir}
if project != "" {
ghArgs = append(ghArgs, "--pattern", fmt.Sprintf("playwright-test-results-%s-*", project))
} else {
ghArgs = append(ghArgs, "--pattern", "playwright-test-results-*")
}
log.Infof("Downloading trace artifacts...")
log.Debugf("Running: gh %s", strings.Join(ghArgs, " "))
cmd := exec.Command("gh", ghArgs...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
_ = os.RemoveAll(destDir)
return "", fmt.Errorf("gh run download failed: %w\nMake sure the run ID is correct and the artifacts haven't expired (30 day retention)", err)
}
return destDir, nil
}
// findTraces recursively finds all trace.zip files under a directory.
func findTraces(root string) ([]string, error) {
var traces []string
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && info.Name() == "trace.zip" {
traces = append(traces, path)
}
return nil
})
return traces, err
}
// findTraceInfos walks the download directory and returns structured trace info.
// Expects: destDir/{artifact-dir}/{test-dir}/trace.zip
func findTraceInfos(destDir, runID string) ([]traceInfo, error) {
var traces []traceInfo
err := filepath.Walk(destDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() || info.Name() != "trace.zip" {
return nil
}
rel, _ := filepath.Rel(destDir, path)
parts := strings.SplitN(rel, string(filepath.Separator), 3)
artifactDir := ""
testDir := filepath.Base(filepath.Dir(path))
if len(parts) >= 2 {
artifactDir = parts[0]
testDir = parts[1]
}
traces = append(traces, traceInfo{
Path: path,
Project: extractProject(artifactDir, runID),
TestDir: testDir,
})
return nil
})
sort.Slice(traces, func(i, j int) bool {
pi, pj := projectSortKey(traces[i].Project), projectSortKey(traces[j].Project)
if pi != pj {
return pi < pj
}
return traces[i].TestDir < traces[j].TestDir
})
return traces, err
}
// extractProject derives a project group from an artifact directory name.
// e.g. "playwright-test-results-admin-12345" -> "admin"
//
// "playwright-test-results-admin-shard-1-12345" -> "admin-shard-1"
func extractProject(artifactDir, runID string) string {
name := strings.TrimPrefix(artifactDir, "playwright-test-results-")
name = strings.TrimSuffix(name, "-"+runID)
if name == "" {
return artifactDir
}
return name
}
// projectSortKey returns a sort-friendly key that orders admin < exclusive < lite.
func projectSortKey(project string) string {
switch {
case strings.HasPrefix(project, "admin"):
return "0-" + project
case strings.HasPrefix(project, "exclusive"):
return "1-" + project
case strings.HasPrefix(project, "lite"):
return "2-" + project
default:
return "3-" + project
}
}
// groupByProject returns an ordered list of unique project names found in traces.
func groupByProject(traces []traceInfo) []string {
seen := map[string]bool{}
var projects []string
for _, t := range traces {
if !seen[t.Project] {
seen[t.Project] = true
projects = append(projects, t.Project)
}
}
sort.Slice(projects, func(i, j int) bool {
return projectSortKey(projects[i]) < projectSortKey(projects[j])
})
return projects
}
// printTraceList displays traces grouped by project.
func printTraceList(traces []traceInfo, projects []string) {
fmt.Printf("\nFound %d trace(s) across %d project(s):\n", len(traces), len(projects))
idx := 1
for _, proj := range projects {
count := 0
for _, t := range traces {
if t.Project == proj {
count++
}
}
fmt.Printf("\n %s (%d):\n", proj, count)
for _, t := range traces {
if t.Project == proj {
fmt.Printf(" [%2d] %s\n", idx, t.TestDir)
idx++
}
}
}
}
// selectTraces tries the TUI picker first, falling back to a plain-text
// prompt when the terminal cannot be initialised (e.g. piped output).
func selectTraces(traces []traceInfo, projects []string) []traceInfo {
// Build picker groups in the same order as the sorted traces slice.
var groups []tui.PickerGroup
for _, proj := range projects {
var items []string
for _, t := range traces {
if t.Project == proj {
items = append(items, t.TestDir)
}
}
groups = append(groups, tui.PickerGroup{Label: proj, Items: items})
}
indices, err := tui.Pick(groups)
if err != nil {
// Terminal not available — fall back to text prompt
log.Debugf("TUI picker unavailable: %v", err)
printTraceList(traces, projects)
return promptTraceSelection(traces, projects)
}
if indices == nil {
return nil // user cancelled
}
selected := make([]traceInfo, len(indices))
for i, idx := range indices {
selected[i] = traces[idx]
}
return selected
}
// promptTraceSelection asks the user which traces to open via plain text.
// Accepts numbers (1,3,5), ranges (1-5), "all", or a project name.
func promptTraceSelection(traces []traceInfo, projects []string) []traceInfo {
fmt.Printf("\nOpen which traces? (e.g. 1,3,5 | 1-5 | all | %s): ", strings.Join(projects, " | "))
reader := bufio.NewReader(os.Stdin)
input, err := reader.ReadString('\n')
if err != nil {
log.Fatalf("Failed to read input: %v", err)
}
input = strings.TrimSpace(input)
if input == "" || strings.EqualFold(input, "all") {
return traces
}
// Check if input matches a project name
for _, proj := range projects {
if strings.EqualFold(input, proj) {
var selected []traceInfo
for _, t := range traces {
if t.Project == proj {
selected = append(selected, t)
}
}
return selected
}
}
// Parse as number/range selection
indices := parseTraceSelection(input, len(traces))
if len(indices) == 0 {
log.Warn("No valid selection; opening all traces")
return traces
}
selected := make([]traceInfo, len(indices))
for i, idx := range indices {
selected[i] = traces[idx]
}
return selected
}
// parseTraceSelection parses a comma-separated list of numbers and ranges into
// 0-based indices. Input is 1-indexed (matches display). Out-of-range values
// are silently ignored.
func parseTraceSelection(input string, max int) []int {
var result []int
seen := map[int]bool{}
for _, part := range strings.Split(input, ",") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if idx := strings.Index(part, "-"); idx > 0 {
lo, err1 := strconv.Atoi(strings.TrimSpace(part[:idx]))
hi, err2 := strconv.Atoi(strings.TrimSpace(part[idx+1:]))
if err1 != nil || err2 != nil {
continue
}
for i := lo; i <= hi; i++ {
zi := i - 1
if zi >= 0 && zi < max && !seen[zi] {
result = append(result, zi)
seen[zi] = true
}
}
} else {
n, err := strconv.Atoi(part)
if err != nil {
continue
}
zi := n - 1
if zi >= 0 && zi < max && !seen[zi] {
result = append(result, zi)
seen[zi] = true
}
}
}
return result
}
// openTraces opens the selected traces with playwright show-trace,
// running npx from the web/ directory to use the project's Playwright version.
func openTraces(traces []traceInfo) {
tracePaths := make([]string, len(traces))
for i, t := range traces {
tracePaths[i] = t.Path
}
args := append([]string{"playwright", "show-trace"}, tracePaths...)
log.Infof("Opening %d trace(s) with playwright show-trace...", len(traces))
cmd := exec.Command("npx", args...)
// Run from web/ to pick up the locally-installed Playwright version
if root, err := paths.GitRoot(); err == nil {
cmd.Dir = filepath.Join(root, "web")
}
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Stdin = os.Stdin
if err := cmd.Run(); err != nil {
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
// Normal exit (e.g. user closed the window) — just log and return
// so the picker loop can continue.
log.Debugf("playwright exited with code %d", exitErr.ExitCode())
return
}
log.Errorf("playwright show-trace failed: %v\nMake sure Playwright is installed (npx playwright install)", err)
}
}
// ghError wraps a gh CLI error with stderr output.
func ghError(err error, msg string) error {
if exitErr, ok := err.(*exec.ExitError); ok {
return fmt.Errorf("%s: %w: %s", msg, err, string(exitErr.Stderr))
}
return fmt.Errorf("%s: %w", msg, err)
}

View File

@@ -1,21 +1,15 @@
module github.com/onyx-dot-app/onyx/tools/ods
go 1.26.1
go 1.26.0
require (
github.com/gdamore/tcell/v2 v2.13.8
github.com/jmelahman/tag v0.5.2
github.com/sirupsen/logrus v1.9.4
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.10.2
github.com/spf13/pflag v1.0.10
)
require (
github.com/gdamore/encoding v1.0.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/term v0.41.0 // indirect
golang.org/x/text v0.35.0 // indirect
golang.org/x/sys v0.39.0 // indirect
)

Some files were not shown because too many files have changed in this diff Show More