mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-27 18:42:43 +00:00
Compare commits
20 Commits
release/v3
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6326c7f0b9 | ||
|
|
40420fc4e6 | ||
|
|
1a2b6a66cc | ||
|
|
d1b1529ccf | ||
|
|
fedd9c76e5 | ||
|
|
0b34b40b79 | ||
|
|
fe82ddb1b9 | ||
|
|
32d3d70525 | ||
|
|
40b9e10890 | ||
|
|
e21b204b8a | ||
|
|
2f672b3a4f | ||
|
|
cf19d0df4f | ||
|
|
86a6a4c134 | ||
|
|
146b5449d2 | ||
|
|
b66991b5c5 | ||
|
|
9cb76dc027 | ||
|
|
f66891d19e | ||
|
|
c07c952ad5 | ||
|
|
be7f40a28a | ||
|
|
26f941b5da |
3
.github/workflows/helm-chart-releases.yml
vendored
3
.github/workflows/helm-chart-releases.yml
vendored
@@ -47,7 +47,8 @@ jobs:
|
||||
done
|
||||
|
||||
- name: Publish Helm charts to gh-pages
|
||||
uses: stefanprodan/helm-gh-pages@0ad2bb377311d61ac04ad9eb6f252fb68e207260 # ratchet:stefanprodan/helm-gh-pages@v1.7.0
|
||||
# NOTE: HEAD of https://github.com/stefanprodan/helm-gh-pages/pull/43
|
||||
uses: stefanprodan/helm-gh-pages@ad32ad3b8720abfeaac83532fd1e9bdfca5bbe27 # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
charts_dir: deployment/helm/charts
|
||||
|
||||
@@ -24,6 +24,16 @@ When hardcoding a boolean variable to a constant value, remove the variable enti
|
||||
|
||||
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
|
||||
|
||||
## Nginx Routing — New Backend Routes
|
||||
|
||||
Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
|
||||
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
|
||||
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
|
||||
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
|
||||
- `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt)
|
||||
|
||||
Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`.
|
||||
|
||||
## Full vs Lite Deployments
|
||||
|
||||
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.
|
||||
|
||||
@@ -35,7 +35,7 @@ Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep
|
||||
> [!TIP]
|
||||
> Run Onyx with one command (or see deployment section below):
|
||||
> ```
|
||||
> curl -fsSL https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.sh > install.sh && chmod +x install.sh && ./install.sh
|
||||
> curl -fsSL https://onyx.app/install_onyx.sh | bash
|
||||
> ```
|
||||
|
||||
****
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -187,7 +188,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
@@ -227,6 +227,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_permission_sync_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
@@ -473,6 +474,8 @@ def connector_permission_sync_generator_task(
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
eager_load_connector=True,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -29,6 +29,7 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
@@ -162,7 +163,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
@@ -221,6 +221,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_external_group_sync_fences(
|
||||
tenant_id, self.app, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call
|
||||
@@ -105,9 +106,11 @@ def _get_slack_document_access(
|
||||
slack_connector: SlackConnector,
|
||||
channel_permissions: dict[str, ExternalAccess], # noqa: ARG001
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
indexing_start: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
callback=callback
|
||||
callback=callback,
|
||||
start=indexing_start,
|
||||
)
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
@@ -180,9 +183,15 @@ def slack_doc_sync(
|
||||
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.set_credentials_provider(provider)
|
||||
indexing_start_ts: SecondsSinceUnixEpoch | None = (
|
||||
cc_pair.connector.indexing_start.timestamp()
|
||||
if cc_pair.connector.indexing_start is not None
|
||||
else None
|
||||
)
|
||||
|
||||
yield from _get_slack_document_access(
|
||||
slack_connector,
|
||||
slack_connector=slack_connector,
|
||||
channel_permissions=channel_permissions,
|
||||
callback=callback,
|
||||
indexing_start=indexing_start_ts,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from onyx.access.models import ElementExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.models import NodeExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -40,10 +41,19 @@ def generic_doc_sync(
|
||||
|
||||
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
|
||||
|
||||
indexing_start: SecondsSinceUnixEpoch | None = (
|
||||
cc_pair.connector.indexing_start.timestamp()
|
||||
if cc_pair.connector.indexing_start is not None
|
||||
else None
|
||||
)
|
||||
|
||||
newly_fetched_doc_ids: set[str] = set()
|
||||
|
||||
logger.info(f"Fetching all slim documents from {doc_source}")
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=indexing_start,
|
||||
callback=callback,
|
||||
):
|
||||
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
|
||||
|
||||
if callback:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -7,7 +8,59 @@ from celery import Celery
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
|
||||
|
||||
_broker_client: Redis | None = None
|
||||
_broker_url: str | None = None
|
||||
_broker_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def celery_get_broker_client(app: Celery) -> Redis:
|
||||
"""Return a shared Redis client connected to the Celery broker DB.
|
||||
|
||||
Uses a module-level singleton so all tasks on a worker share one
|
||||
connection instead of creating a new one per call. The client
|
||||
connects directly to the broker Redis DB (parsed from the broker URL).
|
||||
|
||||
Thread-safe via lock — safe for use in Celery thread-pool workers.
|
||||
|
||||
Usage:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
length = celery_get_queue_length(queue, r_celery)
|
||||
"""
|
||||
global _broker_client, _broker_url
|
||||
with _broker_client_lock:
|
||||
url = app.conf.broker_url
|
||||
if _broker_client is not None and _broker_url == url:
|
||||
try:
|
||||
_broker_client.ping()
|
||||
return _broker_client
|
||||
except Exception:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
elif _broker_client is not None:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
|
||||
_broker_url = url
|
||||
_broker_client = Redis.from_url(
|
||||
url,
|
||||
decode_responses=False,
|
||||
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
|
||||
socket_keepalive=True,
|
||||
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
return _broker_client
|
||||
|
||||
|
||||
def celery_get_unacked_length(r: Redis) -> int:
|
||||
|
||||
@@ -14,6 +14,7 @@ from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
@@ -132,7 +133,6 @@ def revoke_tasks_blocking_deletion(
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -149,6 +149,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
|
||||
# clear fences that don't have associated celery tasks in progress
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_connector_deletion_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -449,7 +450,7 @@ def check_indexing_completion(
|
||||
):
|
||||
# Check if the task exists in the celery queue
|
||||
# This handles the case where Redis dies after task creation but before task execution
|
||||
redis_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
redis_celery = celery_get_broker_client(task.app)
|
||||
task_exists = celery_find_task(
|
||||
attempt.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
@@ -19,6 +18,7 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -698,31 +698,27 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Get Redis client for Celery broker
|
||||
redis_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_std = get_redis_client()
|
||||
|
||||
# Define metric collection functions and their dependencies
|
||||
metric_functions: list[Callable[[], list[Metric]]] = [
|
||||
lambda: _collect_queue_metrics(redis_celery),
|
||||
lambda: _collect_connector_metrics(db_session, redis_std),
|
||||
lambda: _collect_sync_metrics(db_session, redis_std),
|
||||
]
|
||||
# Collect queue metrics with broker connection
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_metrics = _collect_queue_metrics(r_celery)
|
||||
|
||||
# Collect and log each metric
|
||||
# Collect remaining metrics (no broker connection needed)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
# double check to make sure we aren't double-emitting metrics
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
all_metrics: list[Metric] = queue_metrics
|
||||
all_metrics.extend(_collect_connector_metrics(db_session, redis_std))
|
||||
all_metrics.extend(_collect_sync_metrics(db_session, redis_std))
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
for metric in all_metrics:
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
|
||||
task_logger.info("Successfully collected background metrics")
|
||||
except SoftTimeLimitExceeded:
|
||||
@@ -890,7 +886,7 @@ def monitor_celery_queues_helper(
|
||||
) -> None:
|
||||
"""A task to monitor all celery queue lengths."""
|
||||
|
||||
r_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(task.app)
|
||||
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
|
||||
n_docfetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
@@ -1080,7 +1076,7 @@ def cloud_monitor_celery_pidbox(
|
||||
num_deleted = 0
|
||||
|
||||
MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
for key in r_celery.scan_iter("*.reply.celery.pidbox"):
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -203,7 +204,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
@@ -261,6 +261,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating pruning fences")
|
||||
|
||||
@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import build_access_for_user_files
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
@@ -105,7 +106,7 @@ def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
|
||||
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
|
||||
redis_celery = celery_get_broker_client(celery_app)
|
||||
return celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
|
||||
)
|
||||
@@ -238,7 +239,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
@@ -591,7 +592,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
# NOTE: must use the broker's Redis client (not redis_client) because
|
||||
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
|
||||
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
|
||||
@@ -44,6 +44,31 @@ SEND_USER_METADATA_TO_LLM_PROVIDER = (
|
||||
# User Facing Features Configs
|
||||
#####
|
||||
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
|
||||
|
||||
# Hard ceiling for the admin-configurable file upload size (in MB).
|
||||
# Self-hosted customers can raise or lower this via the environment variable.
|
||||
_raw_max_upload_size_mb = int(os.environ.get("MAX_ALLOWED_UPLOAD_SIZE_MB", "250"))
|
||||
if _raw_max_upload_size_mb < 0:
|
||||
logger.warning(
|
||||
"MAX_ALLOWED_UPLOAD_SIZE_MB=%d is negative; falling back to 250",
|
||||
_raw_max_upload_size_mb,
|
||||
)
|
||||
_raw_max_upload_size_mb = 250
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB = _raw_max_upload_size_mb
|
||||
|
||||
# Default fallback for the per-user file upload size limit (in MB) when no
|
||||
# admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at
|
||||
# runtime so this never silently exceeds the hard ceiling.
|
||||
_raw_default_upload_size_mb = int(
|
||||
os.environ.get("DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", "100")
|
||||
)
|
||||
if _raw_default_upload_size_mb < 0:
|
||||
logger.warning(
|
||||
"DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=%d is negative; falling back to 100",
|
||||
_raw_default_upload_size_mb,
|
||||
)
|
||||
_raw_default_upload_size_mb = 100
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _raw_default_upload_size_mb
|
||||
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
|
||||
) # 1 day
|
||||
@@ -61,17 +86,6 @@ CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# Maximum upload size for a single user file (chat/projects) in MB.
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
|
||||
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
|
||||
@@ -890,8 +890,8 @@ class ConfluenceConnector(
|
||||
|
||||
def _retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
@@ -915,8 +915,8 @@ class ConfluenceConnector(
|
||||
self.confluence_client, doc_id, restrictions, ancestors
|
||||
) or space_level_access_info.get(page_space_key)
|
||||
|
||||
# Query pages
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
# Query pages (with optional time filtering for indexing_start)
|
||||
page_query = self._construct_page_cql_query(start, end)
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
@@ -950,7 +950,9 @@ class ConfluenceConnector(
|
||||
|
||||
# Query attachments for each page
|
||||
page_hierarchy_node_yielded = False
|
||||
attachment_query = self._construct_attachment_query(_get_page_id(page))
|
||||
attachment_query = self._construct_attachment_query(
|
||||
_get_page_id(page), start, end
|
||||
)
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_query,
|
||||
expand=restrictions_expand,
|
||||
|
||||
@@ -1765,7 +1765,11 @@ class SharepointConnector(
|
||||
checkpoint.current_drive_delta_next_link = None
|
||||
checkpoint.seen_document_ids.clear()
|
||||
|
||||
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
|
||||
def _fetch_slim_documents_from_sharepoint(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self._filter_excluded_sites(
|
||||
self.site_descriptors or self.fetch_sites()
|
||||
)
|
||||
@@ -1786,7 +1790,9 @@ class SharepointConnector(
|
||||
# Process site documents if flag is True
|
||||
if self.include_site_documents:
|
||||
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
|
||||
site_descriptor=site_descriptor
|
||||
site_descriptor=site_descriptor,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
if self._is_driveitem_excluded(driveitem):
|
||||
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
|
||||
@@ -1841,7 +1847,9 @@ class SharepointConnector(
|
||||
|
||||
# Process site pages if flag is True
|
||||
if self.include_site_pages:
|
||||
site_pages = self._fetch_site_pages(site_descriptor)
|
||||
site_pages = self._fetch_site_pages(
|
||||
site_descriptor, start=start, end=end
|
||||
)
|
||||
for site_page in site_pages:
|
||||
logger.debug(
|
||||
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
|
||||
@@ -2565,12 +2573,22 @@ class SharepointConnector(
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
|
||||
yield from self._fetch_slim_documents_from_sharepoint()
|
||||
start_dt = (
|
||||
datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
if start is not None
|
||||
else None
|
||||
)
|
||||
end_dt = (
|
||||
datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
|
||||
)
|
||||
yield from self._fetch_slim_documents_from_sharepoint(
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -516,6 +516,8 @@ def _get_all_doc_ids(
|
||||
] = default_msg_filter,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
workspace_url: str | None = None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
@@ -546,6 +548,8 @@ def _get_all_doc_ids(
|
||||
client=client,
|
||||
channel=channel,
|
||||
callback=callback,
|
||||
oldest=str(start) if start else None, # 0.0 -> None intentionally
|
||||
latest=str(end) if end is not None else None,
|
||||
)
|
||||
|
||||
for message_batch in channel_message_batches:
|
||||
@@ -847,8 +851,8 @@ class SlackConnector(
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.client is None:
|
||||
@@ -861,6 +865,8 @@ class SlackConnector(
|
||||
msg_filter_func=self.msg_filter_func,
|
||||
callback=callback,
|
||||
workspace_url=self._workspace_url,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _load_from_checkpoint(
|
||||
|
||||
@@ -5,6 +5,7 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
|
||||
out against a nonexistent Vespa/OpenSearch instance.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from onyx.context.search.models import IndexFilters
|
||||
@@ -66,7 +67,7 @@ class DisabledDocumentIndex(DocumentIndex):
|
||||
# ------------------------------------------------------------------
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
index_batch_params: IndexBatchParams, # noqa: ARG002
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
@@ -206,7 +207,7 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -226,8 +227,8 @@ class Indexable(abc.ABC):
|
||||
it is done automatically outside of this code.
|
||||
|
||||
Parameters:
|
||||
- chunks: Document chunks with all of the information needed for indexing to the document
|
||||
index.
|
||||
- chunks: Document chunks with all of the information needed for
|
||||
indexing to the document index.
|
||||
- tenant_id: The tenant id of the user whose chunks are being indexed
|
||||
- large_chunks_enabled: Whether large chunks are enabled
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -209,10 +210,10 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
|
||||
This is often a batch operation including chunks from multiple
|
||||
documents.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
@@ -351,7 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -647,10 +647,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata, # noqa: ARG002
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
|
||||
Groups chunks by document ID and for each document, deletes existing
|
||||
chunks and indexes the new chunks in bulk.
|
||||
@@ -673,29 +673,34 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document is newly indexed or had already existed and was just
|
||||
updated.
|
||||
"""
|
||||
# Group chunks by document ID.
|
||||
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
|
||||
list
|
||||
total_chunks = sum(
|
||||
cc.new_chunk_cnt
|
||||
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
|
||||
)
|
||||
for chunk in chunks:
|
||||
doc_id_to_chunks[chunk.source_document.id].append(chunk)
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
|
||||
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
|
||||
f"documents for index {self._index_name}."
|
||||
)
|
||||
|
||||
document_indexing_results: list[DocumentInsertionRecord] = []
|
||||
# Try to index per-document.
|
||||
for _, chunks in doc_id_to_chunks.items():
|
||||
deleted_doc_ids: set[str] = set()
|
||||
# Buffer chunks per document as they arrive from the iterable.
|
||||
# When the document ID changes flush the buffered chunks.
|
||||
current_doc_id: str | None = None
|
||||
current_chunks: list[DocMetadataAwareIndexChunk] = []
|
||||
|
||||
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
|
||||
assert len(doc_chunks) > 0, "doc_chunks is empty"
|
||||
|
||||
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
|
||||
# Do this before deleting existing chunks to reduce the amount of
|
||||
# time the document index has no content for a given document, and
|
||||
# to reduce the chance of entering a state where we delete chunks,
|
||||
# then some error happens, and never successfully index new chunks.
|
||||
# Since we are doing this in batches, an error occurring midway
|
||||
# can result in a state where chunks are deleted and not all the
|
||||
# new chunks have been indexed.
|
||||
chunk_batch: list[DocumentChunk] = [
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk)
|
||||
for chunk in doc_chunks
|
||||
]
|
||||
onyx_document: Document = chunks[0].source_document
|
||||
onyx_document: Document = doc_chunks[0].source_document
|
||||
# First delete the doc's chunks from the index. This is so that
|
||||
# there are no dangling chunks in the index, in the event that the
|
||||
# new document's content contains fewer chunks than the previous
|
||||
@@ -704,22 +709,40 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# if the chunk count has actually decreased. This assumes that
|
||||
# overlapping chunks are perfectly overwritten. If we can't
|
||||
# guarantee that then we need the code as-is.
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed.
|
||||
document_insertion_record = DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
if onyx_document.id not in deleted_doc_ids:
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
deleted_doc_ids.add(onyx_document.id)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed. We record the result before bulk_index_documents
|
||||
# runs. If indexing raises, this entire result list is discarded
|
||||
# by the caller's retry logic, so early recording is safe.
|
||||
document_indexing_results.append(
|
||||
DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
document_indexing_results.append(document_insertion_record)
|
||||
|
||||
for chunk in chunks:
|
||||
doc_id = chunk.source_document.id
|
||||
if doc_id != current_doc_id:
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
current_doc_id = doc_id
|
||||
current_chunks = [chunk]
|
||||
else:
|
||||
current_chunks.append(chunk)
|
||||
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
|
||||
return document_indexing_results
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
import time
|
||||
import urllib
|
||||
import zipfile
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@@ -461,7 +462,7 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
@@ -318,7 +320,7 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
|
||||
@@ -338,22 +340,31 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
# Vespa has restrictions on valid characters, yet document IDs come from
|
||||
# external w.r.t. this class. We need to sanitize them.
|
||||
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
|
||||
clean_chunk_id_copy(chunk) for chunk in chunks
|
||||
]
|
||||
assert len(cleaned_chunks) == len(
|
||||
chunks
|
||||
), "Bug: Cleaned chunks and input chunks have different lengths."
|
||||
#
|
||||
# Instead of materializing all cleaned chunks upfront, we stream them
|
||||
# through a generator that cleans IDs and builds the original-ID mapping
|
||||
# incrementally as chunks flow into Vespa.
|
||||
def _clean_and_track(
|
||||
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
|
||||
id_map: dict[str, str],
|
||||
seen_ids: set[str],
|
||||
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
|
||||
"""Cleans chunk IDs and builds the original-ID mapping
|
||||
incrementally as chunks flow through, avoiding a separate
|
||||
materialization pass."""
|
||||
for chunk in chunks_iter:
|
||||
original_id = chunk.source_document.id
|
||||
cleaned = clean_chunk_id_copy(chunk)
|
||||
cleaned_id = cleaned.source_document.id
|
||||
# Needed so the final DocumentInsertionRecord returned can have
|
||||
# the original document ID. cleaned_chunks might not contain IDs
|
||||
# exactly as callers supplied them.
|
||||
id_map[cleaned_id] = original_id
|
||||
seen_ids.add(cleaned_id)
|
||||
yield cleaned
|
||||
|
||||
# Needed so the final DocumentInsertionRecord returned can have the
|
||||
# original document ID. cleaned_chunks might not contain IDs exactly as
|
||||
# callers supplied them.
|
||||
new_document_id_to_original_document_id: dict[str, str] = dict()
|
||||
for i, cleaned_chunk in enumerate(cleaned_chunks):
|
||||
old_chunk = chunks[i]
|
||||
new_document_id_to_original_document_id[
|
||||
cleaned_chunk.source_document.id
|
||||
] = old_chunk.source_document.id
|
||||
new_document_id_to_original_document_id: dict[str, str] = {}
|
||||
all_cleaned_doc_ids: set[str] = set()
|
||||
|
||||
existing_docs: set[str] = set()
|
||||
|
||||
@@ -409,7 +420,13 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
# Insert new Vespa documents.
|
||||
# Insert new Vespa documents, streaming through the cleaning
|
||||
# pipeline so chunks are never fully materialized.
|
||||
cleaned_chunks = _clean_and_track(
|
||||
chunks,
|
||||
new_document_id_to_original_document_id,
|
||||
all_cleaned_doc_ids,
|
||||
)
|
||||
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
|
||||
batch_index_vespa_chunks(
|
||||
chunks=chunk_batch,
|
||||
@@ -419,10 +436,6 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
all_cleaned_doc_ids: set[str] = {
|
||||
chunk.source_document.id for chunk in cleaned_chunks
|
||||
}
|
||||
|
||||
return [
|
||||
DocumentInsertionRecord(
|
||||
document_id=new_document_id_to_original_document_id[cleaned_doc_id],
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -173,8 +174,10 @@ class UserFileIndexingAdapter:
|
||||
[chunk.content for chunk in user_file_chunks]
|
||||
)
|
||||
user_file_id_to_raw_text[str(user_file_id)] = combined_content
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
|
||||
token_count: int = (
|
||||
count_tokens(combined_content, llm_tokenizer)
|
||||
if llm_tokenizer
|
||||
else 0
|
||||
)
|
||||
user_file_id_to_token_count[str(user_file_id)] = token_count
|
||||
else:
|
||||
|
||||
@@ -175,6 +175,32 @@ def get_tokenizer(
|
||||
return _check_tokenizer_cache(provider_type, model_name)
|
||||
|
||||
|
||||
# Max characters per encode() call.
|
||||
_ENCODE_CHUNK_SIZE = 500_000
|
||||
|
||||
|
||||
def count_tokens(
|
||||
text: str,
|
||||
tokenizer: BaseTokenizer,
|
||||
token_limit: int | None = None,
|
||||
) -> int:
|
||||
"""Count tokens, chunking the input to avoid tiktoken stack overflow.
|
||||
|
||||
If token_limit is provided and the text is large enough to require
|
||||
multiple chunks (> 500k chars), stops early once the count exceeds it.
|
||||
When early-exiting, the returned value exceeds token_limit but may be
|
||||
less than the true full token count.
|
||||
"""
|
||||
if len(text) <= _ENCODE_CHUNK_SIZE:
|
||||
return len(tokenizer.encode(text))
|
||||
total = 0
|
||||
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
|
||||
total += len(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
|
||||
if token_limit is not None and total > token_limit:
|
||||
return total # Already over — skip remaining chunks
|
||||
return total
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: BaseTokenizer
|
||||
) -> str:
|
||||
|
||||
@@ -3844,9 +3844,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
|
||||
"version": "5.0.3",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
|
||||
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
|
||||
"version": "5.0.5",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.5.tgz",
|
||||
"integrity": "sha512-VZznLgtwhn+Mact9tfiwx64fA9erHH/MCXEUfB/0bX/6Fz6ny5EGTXYltMocqg4xFAQZtnO3DHWWXi8RiuN7cQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"balanced-match": "^4.0.2"
|
||||
@@ -4224,9 +4224,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": {
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
|
||||
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
|
||||
"version": "2.0.3",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz",
|
||||
"integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -5007,9 +5007,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/brace-expansion": {
|
||||
"version": "1.1.12",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
|
||||
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
|
||||
"version": "1.1.13",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz",
|
||||
"integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
||||
@@ -123,9 +123,8 @@ def _validate_endpoint(
|
||||
(not reachable — indicates the api_key is invalid).
|
||||
|
||||
Timeout handling:
|
||||
- ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
|
||||
(operator should consider increasing timeout_seconds).
|
||||
- Any httpx.TimeoutException (ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout) →
|
||||
timeout (operator should consider increasing timeout_seconds).
|
||||
- All other exceptions → cannot_connect.
|
||||
"""
|
||||
_check_ssrf_safety(endpoint_url)
|
||||
|
||||
@@ -9,20 +9,15 @@ from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import SKIP_USERFILE_THRESHOLD
|
||||
from shared_configs.configs import SKIP_USERFILE_THRESHOLD_TENANT_LIST
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -161,8 +156,8 @@ def categorize_uploaded_files(
|
||||
document formats (.pdf, .docx, …) and falls back to a text-detection
|
||||
heuristic for unknown extensions (.py, .js, .rs, …).
|
||||
- Uses default tokenizer to compute token length.
|
||||
- If token length > threshold, reject file (unless threshold skip is enabled).
|
||||
- If text cannot be extracted, reject file.
|
||||
- If token length exceeds the admin-configured threshold, reject file.
|
||||
- If extension unsupported or text cannot be extracted, reject file.
|
||||
- Otherwise marked as acceptable.
|
||||
"""
|
||||
|
||||
@@ -173,36 +168,33 @@ def categorize_uploaded_files(
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
# Check global skip flag (works for both single-tenant and multi-tenant)
|
||||
if SKIP_USERFILE_THRESHOLD:
|
||||
skip_threshold = True
|
||||
logger.info("Skipping userfile threshold check (global setting)")
|
||||
# Check tenant-specific skip list (only applicable in multi-tenant)
|
||||
elif MULTI_TENANT and SKIP_USERFILE_THRESHOLD_TENANT_LIST:
|
||||
try:
|
||||
current_tenant_id = get_current_tenant_id()
|
||||
skip_threshold = current_tenant_id in SKIP_USERFILE_THRESHOLD_TENANT_LIST
|
||||
if skip_threshold:
|
||||
logger.info(
|
||||
f"Skipping userfile threshold check for tenant: {current_tenant_id}"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to get current tenant ID: {str(e)}")
|
||||
# Derive limits from admin-configurable settings.
|
||||
# For upload size: load_settings() resolves 0/None to a positive default.
|
||||
# For token threshold: 0 means "no limit" (converted to None below).
|
||||
settings = load_settings()
|
||||
max_upload_size_mb = (
|
||||
settings.user_file_max_upload_size_mb
|
||||
) # always positive after load_settings()
|
||||
max_upload_size_bytes = (
|
||||
max_upload_size_mb * 1024 * 1024 if max_upload_size_mb else None
|
||||
)
|
||||
token_threshold_k = settings.file_token_count_threshold_k
|
||||
token_threshold = (
|
||||
token_threshold_k * 1000 if token_threshold_k else None
|
||||
) # 0 → None = no limit
|
||||
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
|
||||
# Size limit is a hard safety cap and is enforced even when token
|
||||
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
|
||||
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
|
||||
# Size limit is a hard safety cap.
|
||||
if max_upload_size_bytes is not None and is_upload_too_large(
|
||||
upload, max_upload_size_bytes
|
||||
):
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
|
||||
reason=f"Exceeds {max_upload_size_mb} MB file size limit",
|
||||
)
|
||||
)
|
||||
continue
|
||||
@@ -224,11 +216,11 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
|
||||
if token_threshold is not None and token_count > token_threshold:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
|
||||
reason=f"Exceeds {token_threshold_k}K token limit",
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -269,12 +261,14 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
token_count = len(tokenizer.encode(text_content))
|
||||
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
|
||||
token_count = count_tokens(
|
||||
text_content, tokenizer, token_limit=token_threshold
|
||||
)
|
||||
if token_threshold is not None and token_count > token_threshold:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
|
||||
reason=f"Exceeds {token_threshold_k}K token limit",
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -12,7 +12,6 @@ stale, which is fine for monitoring dashboards.
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -104,25 +103,23 @@ class _CachedCollector(Collector):
|
||||
|
||||
|
||||
class QueueDepthCollector(_CachedCollector):
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape.
|
||||
|
||||
Uses a Redis client factory (callable) rather than a stored client
|
||||
reference so the connection is always fresh from Celery's pool.
|
||||
"""
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
self._celery_app: Any | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
depth = GaugeMetricFamily(
|
||||
"onyx_queue_depth",
|
||||
@@ -404,17 +401,19 @@ class RedisHealthCollector(_CachedCollector):
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
self._celery_app: Any | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
memory_used = GaugeMetricFamily(
|
||||
"onyx_redis_memory_used_bytes",
|
||||
|
||||
@@ -3,12 +3,8 @@
|
||||
Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
from redis import Redis
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
@@ -21,7 +17,7 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
# Module-level singletons — these are lightweight objects (no connections or DB
|
||||
# state) until configure() / set_redis_factory() is called. Keeping them at
|
||||
# state) until configure() / set_celery_app() is called. Keeping them at
|
||||
# module level ensures they survive the lifetime of the worker process and are
|
||||
# only registered with the Prometheus registry once.
|
||||
_queue_collector = QueueDepthCollector()
|
||||
@@ -32,72 +28,15 @@ _worker_health_collector = WorkerHealthCollector()
|
||||
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
|
||||
|
||||
|
||||
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
|
||||
"""Create a factory that returns a cached broker Redis client.
|
||||
|
||||
Reuses a single connection across scrapes to avoid leaking connections.
|
||||
Reconnects automatically if the cached connection becomes stale.
|
||||
"""
|
||||
_cached_client: list[Redis | None] = [None]
|
||||
# Keep a reference to the Kombu Connection so we can close it on
|
||||
# reconnect (the raw Redis client outlives the Kombu wrapper).
|
||||
_cached_kombu_conn: list[Any] = [None]
|
||||
|
||||
def _close_client(client: Redis) -> None:
|
||||
"""Best-effort close of a Redis client."""
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close stale Redis client", exc_info=True)
|
||||
|
||||
def _close_kombu_conn() -> None:
|
||||
"""Best-effort close of the cached Kombu Connection."""
|
||||
conn = _cached_kombu_conn[0]
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close Kombu connection", exc_info=True)
|
||||
_cached_kombu_conn[0] = None
|
||||
|
||||
def _get_broker_redis() -> Redis:
|
||||
client = _cached_client[0]
|
||||
if client is not None:
|
||||
try:
|
||||
client.ping()
|
||||
return client
|
||||
except Exception:
|
||||
logger.debug("Cached Redis client stale, reconnecting")
|
||||
_close_client(client)
|
||||
_cached_client[0] = None
|
||||
_close_kombu_conn()
|
||||
|
||||
# Get a fresh Redis client from the broker connection.
|
||||
# We hold this client long-term (cached above) rather than using a
|
||||
# context manager, because we need it to persist across scrapes.
|
||||
# The caching logic above ensures we only ever hold one connection,
|
||||
# and we close it explicitly on reconnect.
|
||||
conn = celery_app.broker_connection()
|
||||
# kombu's Channel exposes .client at runtime (the underlying Redis
|
||||
# client) but the type stubs don't declare it.
|
||||
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
|
||||
_cached_client[0] = new_client
|
||||
_cached_kombu_conn[0] = conn
|
||||
return new_client
|
||||
|
||||
return _get_broker_redis
|
||||
|
||||
|
||||
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
"""Register all indexing pipeline collectors with the default registry.
|
||||
|
||||
Args:
|
||||
celery_app: The Celery application instance. Used to obtain a fresh
|
||||
celery_app: The Celery application instance. Used to obtain a
|
||||
broker Redis client on each scrape for queue depth metrics.
|
||||
"""
|
||||
redis_factory = _make_broker_redis_factory(celery_app)
|
||||
_queue_collector.set_redis_factory(redis_factory)
|
||||
_redis_health_collector.set_redis_factory(redis_factory)
|
||||
_queue_collector.set_celery_app(celery_app)
|
||||
_redis_health_collector.set_celery_app(celery_app)
|
||||
|
||||
# Start the heartbeat monitor daemon thread — uses a single persistent
|
||||
# connection to receive worker-heartbeat events.
|
||||
|
||||
@@ -9,7 +9,9 @@ from onyx import __version__ as onyx_version
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import is_user_admin
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
@@ -17,10 +19,16 @@ from onyx.db.models import User
|
||||
from onyx.db.notification import dismiss_all_notifications
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.db.notification import update_notification_last_shown
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.features.build.utils import is_onyx_craft_enabled
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Notification
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.models import UserSettings
|
||||
@@ -41,6 +49,15 @@ basic_router = APIRouter(prefix="/settings")
|
||||
def admin_put_settings(
|
||||
settings: Settings, _: User = Depends(current_admin_user)
|
||||
) -> None:
|
||||
if (
|
||||
settings.user_file_max_upload_size_mb is not None
|
||||
and settings.user_file_max_upload_size_mb > 0
|
||||
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"File upload size limit cannot exceed {MAX_ALLOWED_UPLOAD_SIZE_MB} MB",
|
||||
)
|
||||
store_settings(settings)
|
||||
|
||||
|
||||
@@ -83,6 +100,16 @@ def fetch_settings(
|
||||
vector_db_enabled=not DISABLE_VECTOR_DB,
|
||||
hooks_enabled=HOOKS_AVAILABLE,
|
||||
version=onyx_version,
|
||||
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
default_user_file_max_upload_size_mb=min(
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
),
|
||||
default_file_token_count_threshold_k=(
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,12 +2,19 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.models import Notification as NotificationDBModel
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB = 200
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB = 10000
|
||||
|
||||
|
||||
class PageType(str, Enum):
|
||||
CHAT = "chat"
|
||||
@@ -78,7 +85,12 @@ class Settings(BaseModel):
|
||||
|
||||
# User Knowledge settings
|
||||
user_knowledge_enabled: bool | None = True
|
||||
user_file_max_upload_size_mb: int | None = None
|
||||
user_file_max_upload_size_mb: int | None = Field(
|
||||
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
|
||||
)
|
||||
file_token_count_threshold_k: int | None = Field(
|
||||
default=None, ge=0 # thousands of tokens; None = context-aware default
|
||||
)
|
||||
|
||||
# Connector settings
|
||||
show_extra_connectors: bool | None = True
|
||||
@@ -108,3 +120,14 @@ class UserSettings(Settings):
|
||||
hooks_enabled: bool = False
|
||||
# Application version, read from the ONYX_VERSION env var at startup.
|
||||
version: str | None = None
|
||||
# Hard ceiling for user_file_max_upload_size_mb, derived from env var.
|
||||
max_allowed_upload_size_mb: int = MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
# Factory defaults so the frontend can show a "restore default" button.
|
||||
default_user_file_max_upload_size_mb: int = DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
default_file_token_count_threshold_k: int = Field(
|
||||
default_factory=lambda: (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -51,9 +57,36 @@ def load_settings() -> Settings:
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
|
||||
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
|
||||
# Resolve context-aware defaults for token threshold.
|
||||
# None = admin hasn't set a value yet → use context-aware default.
|
||||
# 0 = admin explicitly chose "no limit" → preserve as-is.
|
||||
if settings.file_token_count_threshold_k is None:
|
||||
settings.file_token_count_threshold_k = (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
|
||||
# Upload size: 0 and None are treated as "unset" (not "no limit") →
|
||||
# fall back to min(configured default, hard ceiling).
|
||||
if not settings.user_file_max_upload_size_mb:
|
||||
settings.user_file_max_upload_size_mb = min(
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
)
|
||||
|
||||
# Clamp to env ceiling so stale KV values are capped even if the
|
||||
# operator lowered MAX_ALLOWED_UPLOAD_SIZE_MB after a higher value
|
||||
# was already saved (api.py only guards new writes).
|
||||
if (
|
||||
settings.user_file_max_upload_size_mb > 0
|
||||
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
):
|
||||
settings.user_file_max_upload_size_mb = MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
|
||||
@@ -263,7 +263,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.7.1
|
||||
onyx-devtools==0.7.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
|
||||
@@ -191,25 +191,6 @@ IGNORED_SYNCING_TENANT_LIST = (
|
||||
else None
|
||||
)
|
||||
|
||||
# Global flag to skip userfile threshold for all users/tenants
|
||||
SKIP_USERFILE_THRESHOLD = (
|
||||
os.environ.get("SKIP_USERFILE_THRESHOLD", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Comma-separated list of specific tenant IDs to skip threshold (multi-tenant only)
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_IDS = os.environ.get(
|
||||
"SKIP_USERFILE_THRESHOLD_TENANT_IDS"
|
||||
)
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_LIST = (
|
||||
[
|
||||
tenant.strip()
|
||||
for tenant in SKIP_USERFILE_THRESHOLD_TENANT_IDS.split(",")
|
||||
if tenant.strip()
|
||||
]
|
||||
if SKIP_USERFILE_THRESHOLD_TENANT_IDS
|
||||
else None
|
||||
)
|
||||
|
||||
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -17,6 +19,10 @@ PRIVATE_CHANNEL_USERS = [
|
||||
"test_user_2@onyx-test.com",
|
||||
]
|
||||
|
||||
# Predates any test workspace messages, so the result set should match
|
||||
# the "no start time" case while exercising the oldest= parameter.
|
||||
OLDEST_TS_2016 = datetime(2016, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
@@ -105,15 +111,17 @@ def test_load_from_checkpoint_access__private_channel(
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.parametrize("start_ts", [None, OLDEST_TS_2016])
|
||||
def test_slim_documents_access__public_channel(
|
||||
slack_connector: SlackConnector,
|
||||
start_ts: float | None,
|
||||
) -> None:
|
||||
"""Test that retrieve_all_slim_docs_perm_sync returns correct access information for slim documents."""
|
||||
if not slack_connector.client:
|
||||
raise RuntimeError("Web client must be defined")
|
||||
|
||||
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=0.0,
|
||||
start=start_ts,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
@@ -149,7 +157,7 @@ def test_slim_documents_access__private_channel(
|
||||
raise RuntimeError("Web client must be defined")
|
||||
|
||||
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=0.0,
|
||||
start=None,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
|
||||
@@ -129,6 +129,10 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(_PATCH_QUEUE_DEPTH, return_value=0),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -88,10 +88,22 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
|
||||
Also patches ``celery_get_broker_client`` so the mock app doesn't need
|
||||
a real broker URL.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -90,8 +90,17 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ These tests assume Vespa and OpenSearch are running.
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@@ -21,6 +22,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
|
||||
from tests.external_dependency_unit.document_index.conftest import make_chunk
|
||||
@@ -201,3 +203,25 @@ class TestDocumentIndexNew:
|
||||
assert len(result_map) == 2
|
||||
assert result_map[existing_doc] is True
|
||||
assert result_map[new_doc] is False
|
||||
|
||||
def test_index_accepts_generator(
|
||||
self,
|
||||
document_indices: list[DocumentIndexNew],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""index() accepts a generator (any iterable), not just a list."""
|
||||
for document_index in document_indices:
|
||||
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
|
||||
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
|
||||
|
||||
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for i in range(3):
|
||||
yield make_chunk(doc_id, chunk_id=i)
|
||||
|
||||
results = document_index.index(
|
||||
chunks=chunk_gen(), indexing_metadata=metadata
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].document_id == doc_id
|
||||
assert results[0].already_existed is False
|
||||
|
||||
@@ -5,6 +5,7 @@ These tests assume Vespa and OpenSearch are running.
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -166,3 +167,29 @@ class TestDocumentIndexOld:
|
||||
batch_retrieval=True,
|
||||
)
|
||||
assert len(inference_chunks) == 0
|
||||
|
||||
def test_index_accepts_generator(
|
||||
self,
|
||||
document_indices: list[DocumentIndex],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""index() accepts a generator (any iterable), not just a list."""
|
||||
for document_index in document_indices:
|
||||
|
||||
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for i in range(3):
|
||||
yield make_chunk("test_doc_gen", chunk_id=i)
|
||||
|
||||
index_batch_params = IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
|
||||
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
|
||||
tenant_id=get_current_tenant_id(),
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
|
||||
results = document_index.index(chunk_gen(), index_batch_params)
|
||||
|
||||
assert len(results) == 1
|
||||
record = results.pop()
|
||||
assert record.document_id == "test_doc_gen"
|
||||
assert record.already_existed is False
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for celery_get_broker_client singleton."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.celery import celery_redis
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton() -> Iterator[None]:
|
||||
"""Reset the module-level singleton between tests."""
|
||||
celery_redis._broker_client = None
|
||||
celery_redis._broker_url = None
|
||||
yield
|
||||
celery_redis._broker_client = None
|
||||
celery_redis._broker_url = None
|
||||
|
||||
|
||||
def _make_mock_app(broker_url: str = "redis://localhost:6379/15") -> MagicMock:
|
||||
app = MagicMock()
|
||||
app.conf.broker_url = broker_url
|
||||
return app
|
||||
|
||||
|
||||
class TestCeleryGetBrokerClient:
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_creates_client_on_first_call(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_redis_cls.from_url.return_value = mock_client
|
||||
|
||||
app = _make_mock_app()
|
||||
result = celery_redis.celery_get_broker_client(app)
|
||||
|
||||
assert result is mock_client
|
||||
call_args = mock_redis_cls.from_url.call_args
|
||||
assert call_args[0][0] == "redis://localhost:6379/15"
|
||||
assert call_args[1]["decode_responses"] is False
|
||||
assert call_args[1]["socket_keepalive"] is True
|
||||
assert call_args[1]["retry_on_timeout"] is True
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_reuses_cached_client(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_cls.from_url.return_value = mock_client
|
||||
|
||||
app = _make_mock_app()
|
||||
client1 = celery_redis.celery_get_broker_client(app)
|
||||
client2 = celery_redis.celery_get_broker_client(app)
|
||||
|
||||
assert client1 is client2
|
||||
# from_url called only once
|
||||
assert mock_redis_cls.from_url.call_count == 1
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_reconnects_on_ping_failure(self, mock_redis_cls: MagicMock) -> None:
|
||||
stale_client = MagicMock()
|
||||
stale_client.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
fresh_client = MagicMock()
|
||||
fresh_client.ping.return_value = True
|
||||
|
||||
mock_redis_cls.from_url.side_effect = [stale_client, fresh_client]
|
||||
|
||||
app = _make_mock_app()
|
||||
|
||||
# First call creates stale_client
|
||||
client1 = celery_redis.celery_get_broker_client(app)
|
||||
assert client1 is stale_client
|
||||
|
||||
# Second call: ping fails, creates fresh_client
|
||||
client2 = celery_redis.celery_get_broker_client(app)
|
||||
assert client2 is fresh_client
|
||||
assert mock_redis_cls.from_url.call_count == 2
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_uses_broker_url_from_app_config(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_redis_cls.from_url.return_value = MagicMock()
|
||||
|
||||
app = _make_mock_app("redis://custom-host:6380/3")
|
||||
celery_redis.celery_get_broker_client(app)
|
||||
|
||||
call_args = mock_redis_cls.from_url.call_args
|
||||
assert call_args[0][0] == "redis://custom-host:6380/3"
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -31,6 +33,7 @@ def mock_jira_cc_pair(
|
||||
"jira_base_url": jira_base_url,
|
||||
"project_key": project_key,
|
||||
}
|
||||
mock_cc_pair.connector.indexing_start = None
|
||||
|
||||
return mock_cc_pair
|
||||
|
||||
@@ -65,3 +68,75 @@ def test_jira_permission_sync(
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
):
|
||||
print(doc)
|
||||
|
||||
|
||||
def test_jira_doc_sync_passes_indexing_start(
|
||||
jira_connector: JiraConnector,
|
||||
mock_jira_cc_pair: MagicMock,
|
||||
mock_fetch_all_existing_docs_fn: MagicMock,
|
||||
mock_fetch_all_existing_docs_ids_fn: MagicMock,
|
||||
) -> None:
|
||||
"""Verify that generic_doc_sync derives indexing_start from cc_pair
|
||||
and forwards it to retrieve_all_slim_docs_perm_sync."""
|
||||
indexing_start_dt = datetime(2025, 6, 1, tzinfo=timezone.utc)
|
||||
mock_jira_cc_pair.connector.indexing_start = indexing_start_dt
|
||||
|
||||
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
|
||||
mock_build_client.return_value = jira_connector._jira_client
|
||||
assert jira_connector._jira_client is not None
|
||||
jira_connector._jira_client._options = MagicMock()
|
||||
jira_connector._jira_client._options.return_value = {
|
||||
"rest_api_version": JIRA_SERVER_API_VERSION
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
type(jira_connector),
|
||||
"retrieve_all_slim_docs_perm_sync",
|
||||
return_value=iter([]),
|
||||
) as mock_retrieve:
|
||||
list(
|
||||
jira_doc_sync(
|
||||
cc_pair=mock_jira_cc_pair,
|
||||
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
)
|
||||
)
|
||||
|
||||
mock_retrieve.assert_called_once()
|
||||
call_kwargs = mock_retrieve.call_args
|
||||
assert call_kwargs.kwargs["start"] == indexing_start_dt.timestamp()
|
||||
|
||||
|
||||
def test_jira_doc_sync_passes_none_when_no_indexing_start(
|
||||
jira_connector: JiraConnector,
|
||||
mock_jira_cc_pair: MagicMock,
|
||||
mock_fetch_all_existing_docs_fn: MagicMock,
|
||||
mock_fetch_all_existing_docs_ids_fn: MagicMock,
|
||||
) -> None:
|
||||
"""Verify that indexing_start is None when the connector has no indexing_start set."""
|
||||
mock_jira_cc_pair.connector.indexing_start = None
|
||||
|
||||
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
|
||||
mock_build_client.return_value = jira_connector._jira_client
|
||||
assert jira_connector._jira_client is not None
|
||||
jira_connector._jira_client._options = MagicMock()
|
||||
jira_connector._jira_client._options.return_value = {
|
||||
"rest_api_version": JIRA_SERVER_API_VERSION
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
type(jira_connector),
|
||||
"retrieve_all_slim_docs_perm_sync",
|
||||
return_value=iter([]),
|
||||
) as mock_retrieve:
|
||||
list(
|
||||
jira_doc_sync(
|
||||
cc_pair=mock_jira_cc_pair,
|
||||
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
)
|
||||
)
|
||||
|
||||
mock_retrieve.assert_called_once()
|
||||
call_kwargs = mock_retrieve.call_args
|
||||
assert call_kwargs.kwargs["start"] is None
|
||||
|
||||
@@ -4,13 +4,23 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
|
||||
from onyx.natural_language_processing import utils as nlp_utils
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.server.features.projects import projects_file_utils as utils
|
||||
from onyx.server.settings.models import Settings
|
||||
|
||||
|
||||
class _Tokenizer:
|
||||
class _Tokenizer(BaseTokenizer):
|
||||
def encode(self, text: str) -> list[int]:
|
||||
return [1] * len(text)
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def decode(self, _tokens: list[int]) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
class _NonSeekableFile(BytesIO):
|
||||
def tell(self) -> int:
|
||||
@@ -29,10 +39,26 @@ def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
|
||||
return UploadFile(filename=filename, file=BytesIO(content), size=None)
|
||||
|
||||
|
||||
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _make_settings(upload_size_mb: int = 1, token_threshold_k: int = 100) -> Settings:
|
||||
return Settings(
|
||||
user_file_max_upload_size_mb=upload_size_mb,
|
||||
file_token_count_threshold_k=token_threshold_k,
|
||||
)
|
||||
|
||||
|
||||
def _patch_common_dependencies(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
upload_size_mb: int = 1,
|
||||
token_threshold_k: int = 100,
|
||||
) -> None:
|
||||
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
|
||||
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
|
||||
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
|
||||
monkeypatch.setattr(
|
||||
utils,
|
||||
"load_settings",
|
||||
lambda: _make_settings(upload_size_mb, token_threshold_k),
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
|
||||
@@ -76,9 +102,8 @@ def test_is_upload_too_large_logs_warning_when_size_unknown(
|
||||
def test_categorize_uploaded_files_accepts_size_under_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
# upload_size_mb=1 → max_bytes = 1*1024*1024; file size 99 is well under
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("small.png", size=99)
|
||||
@@ -91,9 +116,7 @@ def test_categorize_uploaded_files_accepts_size_under_limit(
|
||||
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload_no_size("small.png", content=b"x" * 99)
|
||||
@@ -106,12 +129,11 @@ def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
|
||||
def test_categorize_uploaded_files_accepts_size_at_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("edge.png", size=100)
|
||||
# 1 MB = 1048576 bytes; file at exactly that boundary should be accepted
|
||||
upload = _make_upload("edge.png", size=1048576)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
@@ -121,12 +143,10 @@ def test_categorize_uploaded_files_accepts_size_at_limit(
|
||||
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("large.png", size=101)
|
||||
upload = _make_upload("large.png", size=1048577) # 1 byte over 1 MB
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
@@ -137,13 +157,11 @@ def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
|
||||
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
small = _make_upload("small.png", size=50)
|
||||
large = _make_upload("large.png", size=101)
|
||||
large = _make_upload("large.png", size=1048577)
|
||||
|
||||
result = utils.categorize_uploaded_files([small, large], MagicMock())
|
||||
|
||||
@@ -153,15 +171,12 @@ def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
|
||||
def test_categorize_uploaded_files_enforces_size_limit_always(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
|
||||
upload = _make_upload("oversized.pdf", size=101)
|
||||
upload = _make_upload("oversized.pdf", size=1048577)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
@@ -172,14 +187,12 @@ def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_sk
|
||||
def test_categorize_uploaded_files_checks_size_before_text_extraction(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
|
||||
extract_mock = MagicMock(return_value="this should not run")
|
||||
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
|
||||
|
||||
oversized_doc = _make_upload("oversized.pdf", size=101)
|
||||
oversized_doc = _make_upload("oversized.pdf", size=1048577)
|
||||
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
|
||||
|
||||
extract_mock.assert_not_called()
|
||||
@@ -188,40 +201,219 @@ def test_categorize_uploaded_files_checks_size_before_text_extraction(
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_accepts_python_file(
|
||||
def test_categorize_enforces_size_limit_when_upload_size_mb_is_positive(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
"""A positive upload_size_mb is always enforced."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
py_source = b'def hello():\n print("world")\n'
|
||||
monkeypatch.setattr(
|
||||
utils, "extract_file_text", lambda **_kwargs: py_source.decode()
|
||||
)
|
||||
|
||||
upload = _make_upload("script.py", size=len(py_source), content=py_source)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert result.acceptable[0].filename == "script.py"
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_rejects_binary_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: "")
|
||||
|
||||
binary_content = bytes(range(256)) * 4
|
||||
upload = _make_upload("data.bin", size=len(binary_content), content=binary_content)
|
||||
upload = _make_upload("huge.png", size=1048577, content=b"x")
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].filename == "data.bin"
|
||||
assert "Unsupported file type" in result.rejected[0].reason
|
||||
|
||||
|
||||
def test_categorize_enforces_token_limit_when_threshold_k_is_positive(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A positive token_threshold_k is always enforced."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=5)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
|
||||
|
||||
upload = _make_upload("big_image.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
|
||||
|
||||
def test_categorize_no_token_limit_when_threshold_k_is_zero(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""token_threshold_k=0 means no token limit; high-token files are accepted."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=0)
|
||||
monkeypatch.setattr(
|
||||
utils, "estimate_image_tokens_for_upload", lambda _upload: 999_999
|
||||
)
|
||||
|
||||
upload = _make_upload("huge_image.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.rejected) == 0
|
||||
assert len(result.acceptable) == 1
|
||||
|
||||
|
||||
def test_categorize_both_limits_enforced(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Both positive limits are enforced; file exceeding token limit is rejected."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=10, token_threshold_k=5)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
|
||||
|
||||
upload = _make_upload("over_tokens.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 5K token limit"
|
||||
|
||||
|
||||
def test_categorize_rejection_reason_contains_dynamic_values(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Rejection reasons reflect the admin-configured limits, not hardcoded values."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 8000)
|
||||
|
||||
# File within size limit but over token limit
|
||||
upload = _make_upload("tokens.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert result.rejected[0].reason == "Exceeds 7K token limit"
|
||||
|
||||
# File over size limit
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
|
||||
oversized = _make_upload("big.png", size=42 * 1024 * 1024 + 1)
|
||||
result2 = utils.categorize_uploaded_files([oversized], MagicMock())
|
||||
|
||||
assert result2.rejected[0].reason == "Exceeds 42 MB file size limit"
|
||||
|
||||
|
||||
# --- count_tokens tests ---
|
||||
|
||||
|
||||
def test_count_tokens_small_text() -> None:
|
||||
"""Small text should be encoded in a single call and return correct count."""
|
||||
tokenizer = _Tokenizer()
|
||||
text = "hello world"
|
||||
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
|
||||
|
||||
|
||||
def test_count_tokens_chunked_matches_single_call() -> None:
|
||||
"""Chunked encoding should produce the same result as single-call for small text."""
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 1000
|
||||
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
|
||||
|
||||
|
||||
def test_count_tokens_large_text_is_chunked(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Text exceeding _ENCODE_CHUNK_SIZE should be split into multiple encode calls."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 250
|
||||
# _Tokenizer returns 1 token per char, so total should be 250
|
||||
assert count_tokens(text, tokenizer) == 250
|
||||
|
||||
|
||||
def test_count_tokens_with_token_limit_exits_early(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When token_limit is set and exceeded, count_tokens should stop early."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
|
||||
encode_call_count = 0
|
||||
original_tokenizer = _Tokenizer()
|
||||
|
||||
class _CountingTokenizer(BaseTokenizer):
|
||||
def encode(self, text: str) -> list[int]:
|
||||
nonlocal encode_call_count
|
||||
encode_call_count += 1
|
||||
return original_tokenizer.encode(text)
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def decode(self, _tokens: list[int]) -> str:
|
||||
return ""
|
||||
|
||||
tokenizer = _CountingTokenizer()
|
||||
# 500 chars → 5 chunks of 100; limit=150 → should stop after 2 chunks
|
||||
text = "a" * 500
|
||||
result = count_tokens(text, tokenizer, token_limit=150)
|
||||
|
||||
assert result == 200 # 2 chunks × 100 tokens each
|
||||
assert encode_call_count == 2, "Should have stopped after 2 chunks"
|
||||
|
||||
|
||||
def test_count_tokens_with_token_limit_not_exceeded(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When token_limit is set but not exceeded, all chunks are encoded."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 250
|
||||
result = count_tokens(text, tokenizer, token_limit=1000)
|
||||
assert result == 250
|
||||
|
||||
|
||||
def test_count_tokens_no_limit_encodes_all_chunks(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Without token_limit, all chunks are encoded regardless of count."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 500
|
||||
result = count_tokens(text, tokenizer)
|
||||
assert result == 500
|
||||
|
||||
|
||||
# --- early exit via token_limit in categorize tests ---
|
||||
|
||||
|
||||
def test_categorize_early_exits_tokenization_for_large_text(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Large text files should be rejected via early-exit tokenization
|
||||
without encoding all chunks."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
|
||||
# token_threshold = 1000; _ENCODE_CHUNK_SIZE = 100 → text of 500 chars = 5 chunks
|
||||
# Should stop after 2nd chunk (200 tokens > 1000? No... need 1 token per char)
|
||||
# With _Tokenizer: 1 token per char. threshold=1000, chunk=100 → need 11 chunks
|
||||
# Let's use a bigger text
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
large_text = "x" * 5000 # 5000 tokens, threshold 1000
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: large_text)
|
||||
|
||||
encode_call_count = 0
|
||||
original_tokenizer = _Tokenizer()
|
||||
|
||||
class _CountingTokenizer(BaseTokenizer):
|
||||
def encode(self, text: str) -> list[int]:
|
||||
nonlocal encode_call_count
|
||||
encode_call_count += 1
|
||||
return original_tokenizer.encode(text)
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def decode(self, _tokens: list[int]) -> str:
|
||||
return ""
|
||||
|
||||
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _CountingTokenizer())
|
||||
|
||||
upload = _make_upload("big.txt", size=5000, content=large_text.encode())
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.rejected) == 1
|
||||
assert "token limit" in result.rejected[0].reason
|
||||
# 5000 chars / 100 chunk_size = 50 chunks total; should stop well before all 50
|
||||
assert (
|
||||
encode_call_count < 50
|
||||
), f"Expected early exit but encoded {encode_call_count} chunks out of 50"
|
||||
|
||||
|
||||
def test_categorize_text_under_token_limit_accepted(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Text files under the token threshold should be accepted with exact count."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
|
||||
small_text = "x" * 500 # 500 tokens < 1000 threshold
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: small_text)
|
||||
|
||||
upload = _make_upload("ok.txt", size=500, content=small_text.encode())
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert result.acceptable_file_to_token_count["ok.txt"] == 500
|
||||
|
||||
@@ -1,12 +1,23 @@
|
||||
import pytest
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.settings import store as settings_store
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Settings
|
||||
|
||||
|
||||
class _FakeKvStore:
|
||||
def __init__(self, data: dict | None = None) -> None:
|
||||
self._data = data
|
||||
|
||||
def load(self, _key: str) -> dict:
|
||||
raise KvKeyNotFoundError()
|
||||
if self._data is None:
|
||||
raise KvKeyNotFoundError()
|
||||
return self._data
|
||||
|
||||
|
||||
class _FakeCache:
|
||||
@@ -20,13 +31,140 @@ class _FakeCache:
|
||||
self._vals[key] = value.encode("utf-8")
|
||||
|
||||
|
||||
def test_load_settings_includes_user_file_max_upload_size_mb(
|
||||
def test_load_settings_uses_model_defaults_when_no_stored_value(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When no settings are stored (vector DB enabled), load_settings() should
|
||||
resolve the default token threshold to 200."""
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", False)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 77
|
||||
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
assert (
|
||||
settings.file_token_count_threshold_k
|
||||
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
|
||||
|
||||
def test_load_settings_uses_high_token_default_when_vector_db_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When vector DB is disabled and no settings are stored, the token
|
||||
threshold should default to 10000 (10M tokens)."""
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
assert (
|
||||
settings.file_token_count_threshold_k
|
||||
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
)
|
||||
|
||||
|
||||
def test_load_settings_preserves_explicit_value_when_vector_db_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When vector DB is disabled but admin explicitly set a token threshold,
|
||||
that value should be preserved (not overridden by the 10000 default)."""
|
||||
stored = Settings(file_token_count_threshold_k=500).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.file_token_count_threshold_k == 500
|
||||
|
||||
|
||||
def test_load_settings_preserves_zero_token_threshold(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A value of 0 means 'no limit' and should be preserved."""
|
||||
stored = Settings(file_token_count_threshold_k=0).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.file_token_count_threshold_k == 0
|
||||
|
||||
|
||||
def test_load_settings_resolves_zero_upload_size_to_default(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A value of 0 should be treated as unset and resolved to the default."""
|
||||
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
|
||||
|
||||
def test_load_settings_clamps_upload_size_to_env_max(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When the stored upload size exceeds MAX_ALLOWED_UPLOAD_SIZE_MB, it should
|
||||
be clamped to the env-configured maximum."""
|
||||
stored = Settings(user_file_max_upload_size_mb=500).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 250
|
||||
|
||||
|
||||
def test_load_settings_preserves_upload_size_within_max(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When the stored upload size is within MAX_ALLOWED_UPLOAD_SIZE_MB, it should
|
||||
be preserved unchanged."""
|
||||
stored = Settings(user_file_max_upload_size_mb=150).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 150
|
||||
|
||||
|
||||
def test_load_settings_zero_upload_size_resolves_to_default(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A value of 0 should be treated as unset and resolved to the default,
|
||||
clamped to MAX_ALLOWED_UPLOAD_SIZE_MB."""
|
||||
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 100)
|
||||
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 100
|
||||
|
||||
|
||||
def test_load_settings_default_clamped_to_max(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB exceeds MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
the effective default should be min(DEFAULT, MAX)."""
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 50)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 50
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for indexing pipeline Prometheus collectors."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -13,6 +14,16 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_broker_client() -> Iterator[None]:
|
||||
"""Patch celery_get_broker_client for all collector tests."""
|
||||
with patch(
|
||||
"onyx.background.celery.celery_redis.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestQueueDepthCollector:
|
||||
def test_returns_empty_when_factory_not_set(self) -> None:
|
||||
collector = QueueDepthCollector()
|
||||
@@ -24,8 +35,7 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_collects_queue_depths(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -60,8 +70,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_handles_redis_error_gracefully(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
@@ -74,8 +84,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_caching_returns_stale_within_ttl(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=60)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -98,31 +108,10 @@ class TestQueueDepthCollector:
|
||||
|
||||
assert first is second # Same object, from cache
|
||||
|
||||
def test_factory_called_each_scrape(self) -> None:
|
||||
"""Verify the Redis factory is called on each fresh collect, not cached."""
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
factory = MagicMock(return_value=MagicMock())
|
||||
collector.set_redis_factory(factory)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=0,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
collector.collect()
|
||||
collector.collect()
|
||||
|
||||
assert factory.call_count == 2
|
||||
|
||||
def test_error_returns_stale_cache(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
# First call succeeds
|
||||
with (
|
||||
|
||||
@@ -1,96 +1,22 @@
|
||||
"""Tests for indexing pipeline setup (Redis factory caching)."""
|
||||
"""Tests for indexing pipeline setup."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
|
||||
|
||||
def _make_mock_app(client: MagicMock) -> MagicMock:
|
||||
"""Create a mock Celery app whose broker_connection().channel().client
|
||||
returns the given client."""
|
||||
mock_app = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.channel.return_value.client = client
|
||||
class TestCollectorCeleryAppSetup:
|
||||
def test_queue_depth_collector_uses_celery_app(self) -> None:
|
||||
"""QueueDepthCollector.set_celery_app stores the app for broker access."""
|
||||
collector = QueueDepthCollector()
|
||||
mock_app = MagicMock()
|
||||
collector.set_celery_app(mock_app)
|
||||
assert collector._celery_app is mock_app
|
||||
|
||||
mock_app.broker_connection.return_value = mock_conn
|
||||
|
||||
return mock_app
|
||||
|
||||
|
||||
class TestMakeBrokerRedisFactory:
|
||||
def test_caches_redis_client_across_calls(self) -> None:
|
||||
"""Factory should reuse the same client on subsequent calls."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
client1 = factory()
|
||||
client2 = factory()
|
||||
|
||||
assert client1 is client2
|
||||
# broker_connection should only be called once
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
def test_reconnects_when_ping_fails(self) -> None:
|
||||
"""Factory should create a new client if ping fails (stale connection)."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
client1 = factory()
|
||||
assert client1 is mock_client_stale
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
# Switch to fresh client for next connection
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails on stale, reconnects
|
||||
client2 = factory()
|
||||
assert client2 is mock_client_fresh
|
||||
assert mock_app.broker_connection.call_count == 2
|
||||
|
||||
def test_reconnect_closes_stale_client(self) -> None:
|
||||
"""When ping fails, the old client should be closed before reconnecting."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
factory()
|
||||
|
||||
# Switch to fresh client
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails, should close stale client
|
||||
factory()
|
||||
mock_client_stale.close.assert_called_once()
|
||||
|
||||
def test_first_call_creates_connection(self) -> None:
|
||||
"""First call should always create a new connection."""
|
||||
mock_client = MagicMock()
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
client = factory()
|
||||
|
||||
assert client is mock_client
|
||||
mock_app.broker_connection.assert_called_once()
|
||||
def test_redis_health_collector_uses_celery_app(self) -> None:
|
||||
"""RedisHealthCollector.set_celery_app stores the app for broker access."""
|
||||
collector = RedisHealthCollector()
|
||||
mock_app = MagicMock()
|
||||
collector.set_celery_app(mock_app)
|
||||
assert collector._celery_app is mock_app
|
||||
|
||||
@@ -39,6 +39,22 @@ server {
|
||||
# Conditionally include MCP location configuration
|
||||
include /etc/nginx/conf.d/mcp.conf.inc;
|
||||
|
||||
location ~ ^/scim(/.*)?$ {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT}s;
|
||||
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}s;
|
||||
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}s;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
# Match both /api/* and /openapi.json in a single rule
|
||||
location ~ ^/(api|openapi.json)(/.*)?$ {
|
||||
# Rewrite /api prefixed matched paths
|
||||
|
||||
@@ -39,6 +39,20 @@ server {
|
||||
# Conditionally include MCP location configuration
|
||||
include /etc/nginx/conf.d/mcp.conf.inc;
|
||||
|
||||
location ~ ^/scim(/.*)?$ {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
# Match both /api/* and /openapi.json in a single rule
|
||||
location ~ ^/(api|openapi.json)(/.*)?$ {
|
||||
# Rewrite /api prefixed matched paths
|
||||
|
||||
@@ -39,6 +39,23 @@ server {
|
||||
# Conditionally include MCP location configuration
|
||||
include /etc/nginx/conf.d/mcp.conf.inc;
|
||||
|
||||
location ~ ^/scim(/.*)?$ {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT}s;
|
||||
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}s;
|
||||
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}s;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
# Match both /api/* and /openapi.json in a single rule
|
||||
location ~ ^/(api|openapi.json)(/.*)?$ {
|
||||
# Rewrite /api prefixed matched paths
|
||||
|
||||
@@ -66,10 +66,3 @@ DB_READONLY_PASSWORD=password
|
||||
# Show extra/uncommon connectors
|
||||
# See https://docs.onyx.app/admins/connectors/overview for a full list of connectors
|
||||
SHOW_EXTRA_CONNECTORS=False
|
||||
|
||||
# User File Upload Configuration
|
||||
# Skip the token count threshold check (100,000 tokens) for uploaded files
|
||||
# For self-hosted: set to true to skip for all users
|
||||
#SKIP_USERFILE_THRESHOLD=false
|
||||
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
|
||||
#SKIP_USERFILE_THRESHOLD_TENANT_IDS=
|
||||
|
||||
@@ -35,6 +35,10 @@ USER_AUTH_SECRET=""
|
||||
|
||||
## Chat Configuration
|
||||
# HARD_DELETE_CHATS=
|
||||
# MAX_ALLOWED_UPLOAD_SIZE_MB=250
|
||||
# Default per-user upload size limit (MB) when no admin value is set.
|
||||
# Automatically clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at runtime.
|
||||
# DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=100
|
||||
|
||||
## Base URL for redirects
|
||||
# WEB_DOMAIN=
|
||||
@@ -42,13 +46,6 @@ USER_AUTH_SECRET=""
|
||||
## Enterprise Features, requires a paid plan and licenses
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=false
|
||||
|
||||
## User File Upload Configuration
|
||||
# Skip the token count threshold check (100,000 tokens) for uploaded files
|
||||
# For self-hosted: set to true to skip for all users
|
||||
# SKIP_USERFILE_THRESHOLD=false
|
||||
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
|
||||
# SKIP_USERFILE_THRESHOLD_TENANT_IDS=
|
||||
|
||||
|
||||
################################################################################
|
||||
## SERVICES CONFIGURATIONS
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.4.36
|
||||
version: 0.4.38
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
|
||||
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_docfetching.replicaCount) 0) }}
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-docfetching-metrics
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_docfetching.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_docfetching.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
metrics: "true"
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: 9092
|
||||
targetPort: metrics
|
||||
protocol: TCP
|
||||
name: metrics
|
||||
selector:
|
||||
{{- include "onyx.selectorLabels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_docfetching.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_docfetching.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -73,6 +73,10 @@ spec:
|
||||
"-Q",
|
||||
"connector_doc_fetching",
|
||||
]
|
||||
ports:
|
||||
- name: metrics
|
||||
containerPort: 9092
|
||||
protocol: TCP
|
||||
resources:
|
||||
{{- toYaml .Values.celery_worker_docfetching.resources | nindent 12 }}
|
||||
envFrom:
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
|
||||
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_docprocessing.replicaCount) 0) }}
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-docprocessing-metrics
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_docprocessing.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_docprocessing.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
metrics: "true"
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: 9093
|
||||
targetPort: metrics
|
||||
protocol: TCP
|
||||
name: metrics
|
||||
selector:
|
||||
{{- include "onyx.selectorLabels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_docprocessing.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_docprocessing.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -73,6 +73,10 @@ spec:
|
||||
"-Q",
|
||||
"docprocessing",
|
||||
]
|
||||
ports:
|
||||
- name: metrics
|
||||
containerPort: 9093
|
||||
protocol: TCP
|
||||
resources:
|
||||
{{- toYaml .Values.celery_worker_docprocessing.resources | nindent 12 }}
|
||||
envFrom:
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
|
||||
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_monitoring.replicaCount) 0) }}
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-monitoring-metrics
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_monitoring.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_monitoring.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
metrics: "true"
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: 9096
|
||||
targetPort: metrics
|
||||
protocol: TCP
|
||||
name: metrics
|
||||
selector:
|
||||
{{- include "onyx.selectorLabels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_monitoring.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_monitoring.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -70,6 +70,10 @@ spec:
|
||||
"-Q",
|
||||
"monitoring",
|
||||
]
|
||||
ports:
|
||||
- name: metrics
|
||||
containerPort: 9096
|
||||
protocol: TCP
|
||||
resources:
|
||||
{{- toYaml .Values.celery_worker_monitoring.resources | nindent 12 }}
|
||||
envFrom:
|
||||
|
||||
@@ -63,6 +63,22 @@ data:
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
location ~ ^/scim(/.*)?$ {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
# timeout settings
|
||||
proxy_connect_timeout {{ .Values.nginx.timeouts.connect }}s;
|
||||
proxy_send_timeout {{ .Values.nginx.timeouts.send }}s;
|
||||
proxy_read_timeout {{ .Values.nginx.timeouts.read }}s;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
location ~ ^/(api|openapi\.json)(/.*)?$ {
|
||||
rewrite ^/api(/.*)$ $1 break;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
|
||||
@@ -282,7 +282,7 @@ nginx:
|
||||
# The ingress-nginx subchart doesn't auto-detect our custom ConfigMap changes.
|
||||
# Workaround: Helm upgrade will restart if the following annotation value changes.
|
||||
podAnnotations:
|
||||
onyx.app/nginx-config-version: "2"
|
||||
onyx.app/nginx-config-version: "3"
|
||||
|
||||
# Propagate DOMAIN into nginx so server_name continues to use the same env var
|
||||
extraEnvs:
|
||||
@@ -1285,11 +1285,5 @@ configMap:
|
||||
DOMAIN: "localhost"
|
||||
# Chat Configs
|
||||
HARD_DELETE_CHATS: ""
|
||||
# User File Upload Configuration
|
||||
# Skip the token count threshold check (100,000 tokens) for uploaded files
|
||||
# For self-hosted: set to true to skip for all users
|
||||
SKIP_USERFILE_THRESHOLD: ""
|
||||
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_IDS: ""
|
||||
# Maximum user upload file size in MB for chat/projects uploads
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB: ""
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB: ""
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB: ""
|
||||
|
||||
@@ -144,7 +144,7 @@ dev = [
|
||||
"matplotlib==3.10.8",
|
||||
"mypy-extensions==1.0.0",
|
||||
"mypy==1.13.0",
|
||||
"onyx-devtools==0.7.1",
|
||||
"onyx-devtools==0.7.2",
|
||||
"openapi-generator-cli==7.17.0",
|
||||
"pandas-stubs~=2.3.3",
|
||||
"pre-commit==3.2.2",
|
||||
|
||||
@@ -28,7 +28,7 @@ Some commands require external tools to be installed and configured:
|
||||
- **uv** - Required for `backend` commands
|
||||
- Install from [docs.astral.sh/uv](https://docs.astral.sh/uv/)
|
||||
|
||||
- **GitHub CLI** (`gh`) - Required for `run-ci` and `cherry-pick` commands
|
||||
- **GitHub CLI** (`gh`) - Required for `run-ci`, `cherry-pick`, and `trace` commands
|
||||
- Install from [cli.github.com](https://cli.github.com/)
|
||||
- Authenticate with `gh auth login`
|
||||
|
||||
@@ -412,6 +412,62 @@ The `compare` subcommand writes a `summary.json` alongside the report with aggre
|
||||
counts (changed, added, removed, unchanged). The HTML report is only generated when
|
||||
visual differences are detected.
|
||||
|
||||
### `trace` - View Playwright Traces from CI
|
||||
|
||||
Download Playwright trace artifacts from a GitHub Actions run and open them
|
||||
with `playwright show-trace`. Traces are only generated for failing tests
|
||||
(`retain-on-failure`).
|
||||
|
||||
```shell
|
||||
ods trace [run-id-or-url]
|
||||
```
|
||||
|
||||
The run can be specified as a numeric run ID, a full GitHub Actions URL, or
|
||||
omitted to find the latest Playwright run for the current branch.
|
||||
|
||||
**Flags:**
|
||||
|
||||
| Flag | Default | Description |
|
||||
|------|---------|-------------|
|
||||
| `--branch`, `-b` | | Find latest run for this branch |
|
||||
| `--pr` | | Find latest run for this PR number |
|
||||
| `--project`, `-p` | | Filter to a specific project (`admin`, `exclusive`, `lite`) |
|
||||
| `--list`, `-l` | `false` | List available traces without opening |
|
||||
| `--no-open` | `false` | Download traces but don't open them |
|
||||
|
||||
When multiple traces are found, an interactive picker lets you select which
|
||||
traces to open. Use arrow keys or `j`/`k` to navigate, `space` to toggle,
|
||||
`a` to select all, `n` to deselect all, and `enter` to open. Falls back to a
|
||||
plain-text prompt when no TTY is available.
|
||||
|
||||
Downloaded artifacts are cached in `/tmp/ods-traces/<run-id>/` so repeated
|
||||
invocations for the same run are instant.
|
||||
|
||||
**Examples:**
|
||||
|
||||
```shell
|
||||
# Latest run for the current branch
|
||||
ods trace
|
||||
|
||||
# Specific run ID
|
||||
ods trace 12345678
|
||||
|
||||
# Full GitHub Actions URL
|
||||
ods trace https://github.com/onyx-dot-app/onyx/actions/runs/12345678
|
||||
|
||||
# Latest run for a PR
|
||||
ods trace --pr 9500
|
||||
|
||||
# Latest run for a specific branch
|
||||
ods trace --branch main
|
||||
|
||||
# Only download admin project traces
|
||||
ods trace --project admin
|
||||
|
||||
# List traces without opening
|
||||
ods trace --list
|
||||
```
|
||||
|
||||
### Testing Changes Locally (Dry Run)
|
||||
|
||||
Both `run-ci` and `cherry-pick` support `--dry-run` to test without making remote changes:
|
||||
|
||||
@@ -55,6 +55,7 @@ func NewRootCommand() *cobra.Command {
|
||||
cmd.AddCommand(NewWebCommand())
|
||||
cmd.AddCommand(NewLatestStableTagCommand())
|
||||
cmd.AddCommand(NewWhoisCommand())
|
||||
cmd.AddCommand(NewTraceCommand())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
556
tools/ods/cmd/trace.go
Normal file
556
tools/ods/cmd/trace.go
Normal file
@@ -0,0 +1,556 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/git"
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/paths"
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/tui"
|
||||
)
|
||||
|
||||
const playwrightWorkflow = "Run Playwright Tests"
|
||||
|
||||
// TraceOptions holds options for the trace command
|
||||
type TraceOptions struct {
|
||||
Branch string
|
||||
PR string
|
||||
Project string
|
||||
List bool
|
||||
NoOpen bool
|
||||
}
|
||||
|
||||
// traceInfo describes a single trace.zip found in the downloaded artifacts.
|
||||
type traceInfo struct {
|
||||
Path string // absolute path to trace.zip
|
||||
Project string // project group extracted from artifact dir (e.g. "admin", "admin-shard-1")
|
||||
TestDir string // test directory name (human-readable-ish)
|
||||
}
|
||||
|
||||
// NewTraceCommand creates a new trace command
|
||||
func NewTraceCommand() *cobra.Command {
|
||||
opts := &TraceOptions{}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "trace [run-id-or-url]",
|
||||
Short: "Download and view Playwright traces from GitHub Actions",
|
||||
Long: `Download Playwright trace artifacts from a GitHub Actions run and open them
|
||||
with 'playwright show-trace'.
|
||||
|
||||
The run can be specified as:
|
||||
- A GitHub Actions run ID (numeric)
|
||||
- A full GitHub Actions run URL
|
||||
- Omitted, to find the latest Playwright run for the current branch
|
||||
|
||||
You can also look up the latest run by branch name or PR number.
|
||||
|
||||
Examples:
|
||||
ods trace # latest run for current branch
|
||||
ods trace 12345678 # specific run ID
|
||||
ods trace https://github.com/onyx-dot-app/onyx/actions/runs/12345678
|
||||
ods trace --pr 9500 # latest run for PR #9500
|
||||
ods trace --branch main # latest run for main branch
|
||||
ods trace --project admin # only download admin project traces
|
||||
ods trace --list # list available traces without opening`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runTrace(args, opts)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVarP(&opts.Branch, "branch", "b", "", "Find latest run for this branch")
|
||||
cmd.Flags().StringVar(&opts.PR, "pr", "", "Find latest run for this PR number")
|
||||
cmd.Flags().StringVarP(&opts.Project, "project", "p", "", "Filter to a specific project (admin, exclusive, lite)")
|
||||
cmd.Flags().BoolVarP(&opts.List, "list", "l", false, "List available traces without opening")
|
||||
cmd.Flags().BoolVar(&opts.NoOpen, "no-open", false, "Download traces but don't open them")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// ghRun represents a GitHub Actions workflow run from `gh run list`
|
||||
type ghRun struct {
|
||||
DatabaseID int64 `json:"databaseId"`
|
||||
Status string `json:"status"`
|
||||
Conclusion string `json:"conclusion"`
|
||||
HeadBranch string `json:"headBranch"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
func runTrace(args []string, opts *TraceOptions) {
|
||||
git.CheckGitHubCLI()
|
||||
|
||||
runID, err := resolveRunID(args, opts)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to resolve run: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("Using run ID: %s", runID)
|
||||
|
||||
destDir, err := downloadTraceArtifacts(runID, opts.Project)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to download artifacts: %v", err)
|
||||
}
|
||||
|
||||
traces, err := findTraceInfos(destDir, runID)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find traces: %v", err)
|
||||
}
|
||||
|
||||
if len(traces) == 0 {
|
||||
log.Info("No trace files found in the downloaded artifacts.")
|
||||
log.Info("Traces are only generated for failing tests (retain-on-failure).")
|
||||
return
|
||||
}
|
||||
|
||||
projects := groupByProject(traces)
|
||||
|
||||
if opts.List || opts.NoOpen {
|
||||
printTraceList(traces, projects)
|
||||
fmt.Printf("\nTraces downloaded to: %s\n", destDir)
|
||||
return
|
||||
}
|
||||
|
||||
if len(traces) == 1 {
|
||||
openTraces(traces)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
selected := selectTraces(traces, projects)
|
||||
if len(selected) == 0 {
|
||||
return
|
||||
}
|
||||
openTraces(selected)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveRunID determines the run ID from the provided arguments and options.
|
||||
func resolveRunID(args []string, opts *TraceOptions) (string, error) {
|
||||
if len(args) == 1 {
|
||||
return parseRunIDFromArg(args[0])
|
||||
}
|
||||
|
||||
if opts.PR != "" {
|
||||
return findLatestRunForPR(opts.PR)
|
||||
}
|
||||
|
||||
branch := opts.Branch
|
||||
if branch == "" {
|
||||
var err error
|
||||
branch, err = git.GetCurrentBranch()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get current branch: %w", err)
|
||||
}
|
||||
if branch == "" {
|
||||
return "", fmt.Errorf("detached HEAD; specify a --branch, --pr, or run ID")
|
||||
}
|
||||
log.Infof("Using current branch: %s", branch)
|
||||
}
|
||||
|
||||
return findLatestRunForBranch(branch)
|
||||
}
|
||||
|
||||
var runURLPattern = regexp.MustCompile(`/actions/runs/(\d+)`)
|
||||
|
||||
// parseRunIDFromArg extracts a run ID from either a numeric string or a full URL.
|
||||
func parseRunIDFromArg(arg string) (string, error) {
|
||||
if matched, _ := regexp.MatchString(`^\d+$`, arg); matched {
|
||||
return arg, nil
|
||||
}
|
||||
|
||||
matches := runURLPattern.FindStringSubmatch(arg)
|
||||
if matches != nil {
|
||||
return matches[1], nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("could not parse run ID from %q; expected a numeric ID or GitHub Actions URL", arg)
|
||||
}
|
||||
|
||||
// findLatestRunForBranch finds the most recent Playwright workflow run for a branch.
|
||||
func findLatestRunForBranch(branch string) (string, error) {
|
||||
log.Infof("Looking up latest Playwright run for branch: %s", branch)
|
||||
|
||||
cmd := exec.Command("gh", "run", "list",
|
||||
"--workflow", playwrightWorkflow,
|
||||
"--branch", branch,
|
||||
"--limit", "1",
|
||||
"--json", "databaseId,status,conclusion,headBranch,url",
|
||||
)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", ghError(err, "gh run list failed")
|
||||
}
|
||||
|
||||
var runs []ghRun
|
||||
if err := json.Unmarshal(output, &runs); err != nil {
|
||||
return "", fmt.Errorf("failed to parse run list: %w", err)
|
||||
}
|
||||
|
||||
if len(runs) == 0 {
|
||||
return "", fmt.Errorf("no Playwright runs found for branch %q", branch)
|
||||
}
|
||||
|
||||
run := runs[0]
|
||||
log.Infof("Found run: %s (status: %s, conclusion: %s)", run.URL, run.Status, run.Conclusion)
|
||||
return fmt.Sprintf("%d", run.DatabaseID), nil
|
||||
}
|
||||
|
||||
// findLatestRunForPR finds the most recent Playwright workflow run for a PR.
|
||||
func findLatestRunForPR(prNumber string) (string, error) {
|
||||
log.Infof("Looking up branch for PR #%s", prNumber)
|
||||
|
||||
cmd := exec.Command("gh", "pr", "view", prNumber,
|
||||
"--json", "headRefName",
|
||||
"--jq", ".headRefName",
|
||||
)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", ghError(err, "gh pr view failed")
|
||||
}
|
||||
|
||||
branch := strings.TrimSpace(string(output))
|
||||
if branch == "" {
|
||||
return "", fmt.Errorf("could not determine branch for PR #%s", prNumber)
|
||||
}
|
||||
|
||||
log.Infof("PR #%s is on branch: %s", prNumber, branch)
|
||||
return findLatestRunForBranch(branch)
|
||||
}
|
||||
|
||||
// downloadTraceArtifacts downloads playwright trace artifacts for a run.
|
||||
// Returns the path to the download directory.
|
||||
func downloadTraceArtifacts(runID string, project string) (string, error) {
|
||||
cacheKey := runID
|
||||
if project != "" {
|
||||
cacheKey = runID + "-" + project
|
||||
}
|
||||
destDir := filepath.Join(os.TempDir(), "ods-traces", cacheKey)
|
||||
|
||||
// Reuse a previous download if traces exist
|
||||
if info, err := os.Stat(destDir); err == nil && info.IsDir() {
|
||||
traces, _ := findTraces(destDir)
|
||||
if len(traces) > 0 {
|
||||
log.Infof("Using cached download at %s", destDir)
|
||||
return destDir, nil
|
||||
}
|
||||
_ = os.RemoveAll(destDir)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(destDir, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create directory %s: %w", destDir, err)
|
||||
}
|
||||
|
||||
ghArgs := []string{"run", "download", runID, "--dir", destDir}
|
||||
|
||||
if project != "" {
|
||||
ghArgs = append(ghArgs, "--pattern", fmt.Sprintf("playwright-test-results-%s-*", project))
|
||||
} else {
|
||||
ghArgs = append(ghArgs, "--pattern", "playwright-test-results-*")
|
||||
}
|
||||
|
||||
log.Infof("Downloading trace artifacts...")
|
||||
log.Debugf("Running: gh %s", strings.Join(ghArgs, " "))
|
||||
|
||||
cmd := exec.Command("gh", ghArgs...)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
_ = os.RemoveAll(destDir)
|
||||
return "", fmt.Errorf("gh run download failed: %w\nMake sure the run ID is correct and the artifacts haven't expired (30 day retention)", err)
|
||||
}
|
||||
|
||||
return destDir, nil
|
||||
}
|
||||
|
||||
// findTraces recursively finds all trace.zip files under a directory.
|
||||
func findTraces(root string) ([]string, error) {
|
||||
var traces []string
|
||||
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() && info.Name() == "trace.zip" {
|
||||
traces = append(traces, path)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return traces, err
|
||||
}
|
||||
|
||||
// findTraceInfos walks the download directory and returns structured trace info.
|
||||
// Expects: destDir/{artifact-dir}/{test-dir}/trace.zip
|
||||
func findTraceInfos(destDir, runID string) ([]traceInfo, error) {
|
||||
var traces []traceInfo
|
||||
err := filepath.Walk(destDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.IsDir() || info.Name() != "trace.zip" {
|
||||
return nil
|
||||
}
|
||||
|
||||
rel, _ := filepath.Rel(destDir, path)
|
||||
parts := strings.SplitN(rel, string(filepath.Separator), 3)
|
||||
|
||||
artifactDir := ""
|
||||
testDir := filepath.Base(filepath.Dir(path))
|
||||
if len(parts) >= 2 {
|
||||
artifactDir = parts[0]
|
||||
testDir = parts[1]
|
||||
}
|
||||
|
||||
traces = append(traces, traceInfo{
|
||||
Path: path,
|
||||
Project: extractProject(artifactDir, runID),
|
||||
TestDir: testDir,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
sort.Slice(traces, func(i, j int) bool {
|
||||
pi, pj := projectSortKey(traces[i].Project), projectSortKey(traces[j].Project)
|
||||
if pi != pj {
|
||||
return pi < pj
|
||||
}
|
||||
return traces[i].TestDir < traces[j].TestDir
|
||||
})
|
||||
|
||||
return traces, err
|
||||
}
|
||||
|
||||
// extractProject derives a project group from an artifact directory name.
|
||||
// e.g. "playwright-test-results-admin-12345" -> "admin"
|
||||
//
|
||||
// "playwright-test-results-admin-shard-1-12345" -> "admin-shard-1"
|
||||
func extractProject(artifactDir, runID string) string {
|
||||
name := strings.TrimPrefix(artifactDir, "playwright-test-results-")
|
||||
name = strings.TrimSuffix(name, "-"+runID)
|
||||
if name == "" {
|
||||
return artifactDir
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// projectSortKey returns a sort-friendly key that orders admin < exclusive < lite.
|
||||
func projectSortKey(project string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(project, "admin"):
|
||||
return "0-" + project
|
||||
case strings.HasPrefix(project, "exclusive"):
|
||||
return "1-" + project
|
||||
case strings.HasPrefix(project, "lite"):
|
||||
return "2-" + project
|
||||
default:
|
||||
return "3-" + project
|
||||
}
|
||||
}
|
||||
|
||||
// groupByProject returns an ordered list of unique project names found in traces.
|
||||
func groupByProject(traces []traceInfo) []string {
|
||||
seen := map[string]bool{}
|
||||
var projects []string
|
||||
for _, t := range traces {
|
||||
if !seen[t.Project] {
|
||||
seen[t.Project] = true
|
||||
projects = append(projects, t.Project)
|
||||
}
|
||||
}
|
||||
sort.Slice(projects, func(i, j int) bool {
|
||||
return projectSortKey(projects[i]) < projectSortKey(projects[j])
|
||||
})
|
||||
return projects
|
||||
}
|
||||
|
||||
// printTraceList displays traces grouped by project.
|
||||
func printTraceList(traces []traceInfo, projects []string) {
|
||||
fmt.Printf("\nFound %d trace(s) across %d project(s):\n", len(traces), len(projects))
|
||||
|
||||
idx := 1
|
||||
for _, proj := range projects {
|
||||
count := 0
|
||||
for _, t := range traces {
|
||||
if t.Project == proj {
|
||||
count++
|
||||
}
|
||||
}
|
||||
fmt.Printf("\n %s (%d):\n", proj, count)
|
||||
for _, t := range traces {
|
||||
if t.Project == proj {
|
||||
fmt.Printf(" [%2d] %s\n", idx, t.TestDir)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// selectTraces tries the TUI picker first, falling back to a plain-text
|
||||
// prompt when the terminal cannot be initialised (e.g. piped output).
|
||||
func selectTraces(traces []traceInfo, projects []string) []traceInfo {
|
||||
// Build picker groups in the same order as the sorted traces slice.
|
||||
var groups []tui.PickerGroup
|
||||
for _, proj := range projects {
|
||||
var items []string
|
||||
for _, t := range traces {
|
||||
if t.Project == proj {
|
||||
items = append(items, t.TestDir)
|
||||
}
|
||||
}
|
||||
groups = append(groups, tui.PickerGroup{Label: proj, Items: items})
|
||||
}
|
||||
|
||||
indices, err := tui.Pick(groups)
|
||||
if err != nil {
|
||||
// Terminal not available — fall back to text prompt
|
||||
log.Debugf("TUI picker unavailable: %v", err)
|
||||
printTraceList(traces, projects)
|
||||
return promptTraceSelection(traces, projects)
|
||||
}
|
||||
if indices == nil {
|
||||
return nil // user cancelled
|
||||
}
|
||||
|
||||
selected := make([]traceInfo, len(indices))
|
||||
for i, idx := range indices {
|
||||
selected[i] = traces[idx]
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
// promptTraceSelection asks the user which traces to open via plain text.
|
||||
// Accepts numbers (1,3,5), ranges (1-5), "all", or a project name.
|
||||
func promptTraceSelection(traces []traceInfo, projects []string) []traceInfo {
|
||||
fmt.Printf("\nOpen which traces? (e.g. 1,3,5 | 1-5 | all | %s): ", strings.Join(projects, " | "))
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to read input: %v", err)
|
||||
}
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
if input == "" || strings.EqualFold(input, "all") {
|
||||
return traces
|
||||
}
|
||||
|
||||
// Check if input matches a project name
|
||||
for _, proj := range projects {
|
||||
if strings.EqualFold(input, proj) {
|
||||
var selected []traceInfo
|
||||
for _, t := range traces {
|
||||
if t.Project == proj {
|
||||
selected = append(selected, t)
|
||||
}
|
||||
}
|
||||
return selected
|
||||
}
|
||||
}
|
||||
|
||||
// Parse as number/range selection
|
||||
indices := parseTraceSelection(input, len(traces))
|
||||
if len(indices) == 0 {
|
||||
log.Warn("No valid selection; opening all traces")
|
||||
return traces
|
||||
}
|
||||
|
||||
selected := make([]traceInfo, len(indices))
|
||||
for i, idx := range indices {
|
||||
selected[i] = traces[idx]
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
// parseTraceSelection parses a comma-separated list of numbers and ranges into
|
||||
// 0-based indices. Input is 1-indexed (matches display). Out-of-range values
|
||||
// are silently ignored.
|
||||
func parseTraceSelection(input string, max int) []int {
|
||||
var result []int
|
||||
seen := map[int]bool{}
|
||||
|
||||
for _, part := range strings.Split(input, ",") {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if idx := strings.Index(part, "-"); idx > 0 {
|
||||
lo, err1 := strconv.Atoi(strings.TrimSpace(part[:idx]))
|
||||
hi, err2 := strconv.Atoi(strings.TrimSpace(part[idx+1:]))
|
||||
if err1 != nil || err2 != nil {
|
||||
continue
|
||||
}
|
||||
for i := lo; i <= hi; i++ {
|
||||
zi := i - 1
|
||||
if zi >= 0 && zi < max && !seen[zi] {
|
||||
result = append(result, zi)
|
||||
seen[zi] = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
n, err := strconv.Atoi(part)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
zi := n - 1
|
||||
if zi >= 0 && zi < max && !seen[zi] {
|
||||
result = append(result, zi)
|
||||
seen[zi] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// openTraces opens the selected traces with playwright show-trace,
|
||||
// running npx from the web/ directory to use the project's Playwright version.
|
||||
func openTraces(traces []traceInfo) {
|
||||
tracePaths := make([]string, len(traces))
|
||||
for i, t := range traces {
|
||||
tracePaths[i] = t.Path
|
||||
}
|
||||
|
||||
args := append([]string{"playwright", "show-trace"}, tracePaths...)
|
||||
|
||||
log.Infof("Opening %d trace(s) with playwright show-trace...", len(traces))
|
||||
cmd := exec.Command("npx", args...)
|
||||
|
||||
// Run from web/ to pick up the locally-installed Playwright version
|
||||
if root, err := paths.GitRoot(); err == nil {
|
||||
cmd.Dir = filepath.Join(root, "web")
|
||||
}
|
||||
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
var exitErr *exec.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
// Normal exit (e.g. user closed the window) — just log and return
|
||||
// so the picker loop can continue.
|
||||
log.Debugf("playwright exited with code %d", exitErr.ExitCode())
|
||||
return
|
||||
}
|
||||
log.Errorf("playwright show-trace failed: %v\nMake sure Playwright is installed (npx playwright install)", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ghError wraps a gh CLI error with stderr output.
|
||||
func ghError(err error, msg string) error {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return fmt.Errorf("%s: %w: %s", msg, err, string(exitErr.Stderr))
|
||||
}
|
||||
return fmt.Errorf("%s: %w", msg, err)
|
||||
}
|
||||
@@ -3,13 +3,19 @@ module github.com/onyx-dot-app/onyx/tools/ods
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
github.com/gdamore/tcell/v2 v2.13.8
|
||||
github.com/jmelahman/tag v0.5.2
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/sirupsen/logrus v1.9.4
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/pflag v1.0.10
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/gdamore/encoding v1.0.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/term v0.41.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
)
|
||||
|
||||
@@ -1,30 +1,68 @@
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw=
|
||||
github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
|
||||
github.com/gdamore/tcell/v2 v2.13.8 h1:Mys/Kl5wfC/GcC5Cx4C2BIQH9dbnhnkPgS9/wF3RlfU=
|
||||
github.com/gdamore/tcell/v2 v2.13.8/go.mod h1:+Wfe208WDdB7INEtCsNrAN6O2m+wsTPk1RAovjaILlo=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jmelahman/tag v0.5.2 h1:g6A/aHehu5tkA31mPoDsXBNr1FigZ9A82Y8WVgb/WsM=
|
||||
github.com/jmelahman/tag v0.5.2/go.mod h1:qmuqk19B1BKkpcg3kn7l/Eey+UqucLxgOWkteUGiG4Q=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
|
||||
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
|
||||
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
419
tools/ods/internal/tui/picker.go
Normal file
419
tools/ods/internal/tui/picker.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
)
|
||||
|
||||
// PickerGroup represents a labelled group of selectable items.
|
||||
type PickerGroup struct {
|
||||
Label string
|
||||
Items []string
|
||||
}
|
||||
|
||||
// entry is a single row in the picker (either a group header or an item).
|
||||
type entry struct {
|
||||
label string
|
||||
isHeader bool
|
||||
selected bool
|
||||
groupIdx int
|
||||
flatIdx int // index across all items (ignoring headers), -1 for headers
|
||||
}
|
||||
|
||||
// Pick shows a full-screen grouped multi-select picker.
|
||||
// All items start deselected. Returns the flat indices of selected items
|
||||
// (0-based, spanning all groups in order). Returns nil if cancelled.
|
||||
// Returns a non-nil error if the terminal cannot be initialised, in which
|
||||
// case the caller should fall back to a simpler prompt.
|
||||
func Pick(groups []PickerGroup) ([]int, error) {
|
||||
screen, err := tcell.NewScreen()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := screen.Init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer screen.Fini()
|
||||
|
||||
entries := buildEntries(groups)
|
||||
totalItems := countItems(entries)
|
||||
cursor := firstSelectableIndex(entries)
|
||||
offset := 0
|
||||
|
||||
for {
|
||||
w, h := screen.Size()
|
||||
selectedCount := countSelected(entries)
|
||||
|
||||
drawPicker(screen, entries, groups, cursor, offset, w, h, selectedCount, totalItems)
|
||||
screen.Show()
|
||||
|
||||
ev := screen.PollEvent()
|
||||
switch ev := ev.(type) {
|
||||
case *tcell.EventResize:
|
||||
screen.Sync()
|
||||
case *tcell.EventKey:
|
||||
switch action := keyAction(ev); action {
|
||||
case actionQuit:
|
||||
return nil, nil
|
||||
case actionConfirm:
|
||||
if countSelected(entries) > 0 {
|
||||
return collectSelected(entries), nil
|
||||
}
|
||||
case actionUp:
|
||||
if cursor > 0 {
|
||||
cursor--
|
||||
}
|
||||
case actionDown:
|
||||
if cursor < len(entries)-1 {
|
||||
cursor++
|
||||
}
|
||||
case actionTop:
|
||||
cursor = 0
|
||||
case actionBottom:
|
||||
if len(entries) == 0 {
|
||||
cursor = 0
|
||||
} else {
|
||||
cursor = len(entries) - 1
|
||||
}
|
||||
case actionPageUp:
|
||||
listHeight := h - headerLines - footerLines
|
||||
cursor -= listHeight
|
||||
if cursor < 0 {
|
||||
cursor = 0
|
||||
}
|
||||
case actionPageDown:
|
||||
listHeight := h - headerLines - footerLines
|
||||
cursor += listHeight
|
||||
if cursor >= len(entries) {
|
||||
cursor = len(entries) - 1
|
||||
}
|
||||
case actionToggle:
|
||||
toggleAtCursor(entries, cursor)
|
||||
case actionAll:
|
||||
setAll(entries, true)
|
||||
case actionNone:
|
||||
setAll(entries, false)
|
||||
}
|
||||
|
||||
// Keep the cursor visible
|
||||
listHeight := h - headerLines - footerLines
|
||||
if listHeight < 1 {
|
||||
listHeight = 1
|
||||
}
|
||||
if cursor < offset {
|
||||
offset = cursor
|
||||
}
|
||||
if cursor >= offset+listHeight {
|
||||
offset = cursor - listHeight + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- actions ----------------------------------------------------------------
|
||||
|
||||
type action int
|
||||
|
||||
const (
|
||||
actionNoop action = iota
|
||||
actionQuit
|
||||
actionConfirm
|
||||
actionUp
|
||||
actionDown
|
||||
actionTop
|
||||
actionBottom
|
||||
actionPageUp
|
||||
actionPageDown
|
||||
actionToggle
|
||||
actionAll
|
||||
actionNone
|
||||
)
|
||||
|
||||
func keyAction(ev *tcell.EventKey) action {
|
||||
switch ev.Key() {
|
||||
case tcell.KeyEscape, tcell.KeyCtrlC:
|
||||
return actionQuit
|
||||
case tcell.KeyEnter:
|
||||
return actionConfirm
|
||||
case tcell.KeyUp:
|
||||
return actionUp
|
||||
case tcell.KeyDown:
|
||||
return actionDown
|
||||
case tcell.KeyHome:
|
||||
return actionTop
|
||||
case tcell.KeyEnd:
|
||||
return actionBottom
|
||||
case tcell.KeyPgUp:
|
||||
return actionPageUp
|
||||
case tcell.KeyPgDn:
|
||||
return actionPageDown
|
||||
case tcell.KeyRune:
|
||||
switch ev.Rune() {
|
||||
case 'q':
|
||||
return actionQuit
|
||||
case ' ':
|
||||
return actionToggle
|
||||
case 'j':
|
||||
return actionDown
|
||||
case 'k':
|
||||
return actionUp
|
||||
case 'g':
|
||||
return actionTop
|
||||
case 'G':
|
||||
return actionBottom
|
||||
case 'a':
|
||||
return actionAll
|
||||
case 'n':
|
||||
return actionNone
|
||||
}
|
||||
}
|
||||
return actionNoop
|
||||
}
|
||||
|
||||
// --- data helpers ------------------------------------------------------------
|
||||
|
||||
func buildEntries(groups []PickerGroup) []entry {
|
||||
var entries []entry
|
||||
flat := 0
|
||||
for gi, g := range groups {
|
||||
entries = append(entries, entry{
|
||||
label: g.Label,
|
||||
isHeader: true,
|
||||
groupIdx: gi,
|
||||
flatIdx: -1,
|
||||
})
|
||||
for _, item := range g.Items {
|
||||
entries = append(entries, entry{
|
||||
label: item,
|
||||
isHeader: false,
|
||||
selected: false,
|
||||
groupIdx: gi,
|
||||
flatIdx: flat,
|
||||
})
|
||||
flat++
|
||||
}
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func firstSelectableIndex(entries []entry) int {
|
||||
for i, e := range entries {
|
||||
if !e.isHeader {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func countItems(entries []entry) int {
|
||||
n := 0
|
||||
for _, e := range entries {
|
||||
if !e.isHeader {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func countSelected(entries []entry) int {
|
||||
n := 0
|
||||
for _, e := range entries {
|
||||
if !e.isHeader && e.selected {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func collectSelected(entries []entry) []int {
|
||||
var result []int
|
||||
for _, e := range entries {
|
||||
if !e.isHeader && e.selected {
|
||||
result = append(result, e.flatIdx)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func toggleAtCursor(entries []entry, cursor int) {
|
||||
if cursor < 0 || cursor >= len(entries) {
|
||||
return
|
||||
}
|
||||
e := entries[cursor]
|
||||
if e.isHeader {
|
||||
// Toggle entire group: if all selected -> deselect all, else select all
|
||||
allSelected := true
|
||||
for _, e2 := range entries {
|
||||
if !e2.isHeader && e2.groupIdx == e.groupIdx && !e2.selected {
|
||||
allSelected = false
|
||||
break
|
||||
}
|
||||
}
|
||||
for i := range entries {
|
||||
if !entries[i].isHeader && entries[i].groupIdx == e.groupIdx {
|
||||
entries[i].selected = !allSelected
|
||||
}
|
||||
}
|
||||
} else {
|
||||
entries[cursor].selected = !entries[cursor].selected
|
||||
}
|
||||
}
|
||||
|
||||
func setAll(entries []entry, selected bool) {
|
||||
for i := range entries {
|
||||
if !entries[i].isHeader {
|
||||
entries[i].selected = selected
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- drawing ----------------------------------------------------------------
|
||||
|
||||
const (
|
||||
headerLines = 2 // title + blank line
|
||||
footerLines = 2 // blank line + keybinds
|
||||
)
|
||||
|
||||
var (
|
||||
styleDefault = tcell.StyleDefault
|
||||
styleTitle = tcell.StyleDefault.Bold(true)
|
||||
styleGroup = tcell.StyleDefault.Bold(true).Foreground(tcell.ColorTeal)
|
||||
styleGroupCur = tcell.StyleDefault.Bold(true).Foreground(tcell.ColorTeal).Reverse(true)
|
||||
styleCheck = tcell.StyleDefault.Foreground(tcell.ColorGreen).Bold(true)
|
||||
styleUncheck = tcell.StyleDefault.Dim(true)
|
||||
styleItem = tcell.StyleDefault
|
||||
styleItemCur = tcell.StyleDefault.Bold(true).Underline(true)
|
||||
styleCheckCur = tcell.StyleDefault.Foreground(tcell.ColorGreen).Bold(true).Underline(true)
|
||||
styleUncheckCur = tcell.StyleDefault.Dim(true).Underline(true)
|
||||
styleFooter = tcell.StyleDefault.Dim(true)
|
||||
)
|
||||
|
||||
func drawPicker(
|
||||
screen tcell.Screen,
|
||||
entries []entry,
|
||||
groups []PickerGroup,
|
||||
cursor, offset, w, h, selectedCount, totalItems int,
|
||||
) {
|
||||
screen.Clear()
|
||||
|
||||
// Title
|
||||
title := fmt.Sprintf(" Select traces to open (%d/%d selected)", selectedCount, totalItems)
|
||||
drawLine(screen, 0, 0, w, title, styleTitle)
|
||||
|
||||
// List area
|
||||
listHeight := h - headerLines - footerLines
|
||||
if listHeight < 1 {
|
||||
listHeight = 1
|
||||
}
|
||||
|
||||
for i := 0; i < listHeight; i++ {
|
||||
ei := offset + i
|
||||
if ei >= len(entries) {
|
||||
break
|
||||
}
|
||||
y := headerLines + i
|
||||
renderEntry(screen, entries, groups, ei, cursor, w, y)
|
||||
}
|
||||
|
||||
// Scrollbar hint
|
||||
if len(entries) > listHeight {
|
||||
drawScrollbar(screen, w-1, headerLines, listHeight, offset, len(entries))
|
||||
}
|
||||
|
||||
// Footer
|
||||
footerY := h - 1
|
||||
footer := " ↑/↓ move space toggle a all n none enter open q/esc quit"
|
||||
drawLine(screen, 0, footerY, w, footer, styleFooter)
|
||||
}
|
||||
|
||||
func renderEntry(screen tcell.Screen, entries []entry, groups []PickerGroup, ei, cursor, w, y int) {
|
||||
e := entries[ei]
|
||||
isCursor := ei == cursor
|
||||
|
||||
if e.isHeader {
|
||||
groupSelected := 0
|
||||
groupTotal := 0
|
||||
for _, e2 := range entries {
|
||||
if !e2.isHeader && e2.groupIdx == e.groupIdx {
|
||||
groupTotal++
|
||||
if e2.selected {
|
||||
groupSelected++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
label := fmt.Sprintf(" %s (%d/%d)", e.label, groupSelected, groupTotal)
|
||||
style := styleGroup
|
||||
if isCursor {
|
||||
style = styleGroupCur
|
||||
}
|
||||
drawLine(screen, 0, y, w, label, style)
|
||||
return
|
||||
}
|
||||
|
||||
// Item row: " [x] label" or " > [x] label"
|
||||
prefix := " "
|
||||
if isCursor {
|
||||
prefix = " > "
|
||||
}
|
||||
|
||||
check := "[ ]"
|
||||
cStyle := styleUncheck
|
||||
iStyle := styleItem
|
||||
if isCursor {
|
||||
cStyle = styleUncheckCur
|
||||
iStyle = styleItemCur
|
||||
}
|
||||
if e.selected {
|
||||
check = "[x]"
|
||||
cStyle = styleCheck
|
||||
if isCursor {
|
||||
cStyle = styleCheckCur
|
||||
}
|
||||
}
|
||||
|
||||
x := drawStr(screen, 0, y, w, prefix, iStyle)
|
||||
x = drawStr(screen, x, y, w, check, cStyle)
|
||||
drawStr(screen, x, y, w, " "+e.label, iStyle)
|
||||
}
|
||||
|
||||
func drawScrollbar(screen tcell.Screen, x, top, height, offset, total int) {
|
||||
if total <= height || height < 1 {
|
||||
return
|
||||
}
|
||||
|
||||
thumbSize := max(1, height*height/total)
|
||||
thumbPos := top + offset*height/total
|
||||
|
||||
for y := top; y < top+height; y++ {
|
||||
ch := '│'
|
||||
style := styleDefault.Dim(true)
|
||||
if y >= thumbPos && y < thumbPos+thumbSize {
|
||||
ch = '┃'
|
||||
style = styleDefault
|
||||
}
|
||||
screen.SetContent(x, y, ch, nil, style)
|
||||
}
|
||||
}
|
||||
|
||||
// drawLine fills an entire row starting at x=startX, padding to width w.
|
||||
func drawLine(screen tcell.Screen, startX, y, w int, s string, style tcell.Style) {
|
||||
x := drawStr(screen, startX, y, w, s, style)
|
||||
// Clear the rest of the line
|
||||
for ; x < w; x++ {
|
||||
screen.SetContent(x, y, ' ', nil, style)
|
||||
}
|
||||
}
|
||||
|
||||
// drawStr writes a string at (x, y) up to maxX and returns the next x position.
|
||||
func drawStr(screen tcell.Screen, x, y, maxX int, s string, style tcell.Style) int {
|
||||
for _, ch := range s {
|
||||
if x >= maxX {
|
||||
break
|
||||
}
|
||||
screen.SetContent(x, y, ch, nil, style)
|
||||
x++
|
||||
}
|
||||
return x
|
||||
}
|
||||
16
uv.lock
generated
16
uv.lock
generated
@@ -4458,7 +4458,7 @@ requires-dist = [
|
||||
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
|
||||
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
|
||||
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.6.2" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.7.1" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.7.2" },
|
||||
{ name = "openai", specifier = "==2.14.0" },
|
||||
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
|
||||
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
|
||||
@@ -4563,19 +4563,19 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
|
||||
|
||||
[[package]]
|
||||
name = "onyx-devtools"
|
||||
version = "0.7.1"
|
||||
version = "0.7.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "openapi-generator-cli" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/9d/74bcd02583706bdf90c8ac9084eb60bd71d0671392152410ab21b7b68ea1/onyx_devtools-0.7.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:178385dce0b413fd2a1f761055a99f556ec536ef5c32963fc273e751813621eb", size = 4007974, upload-time = "2026-03-17T21:10:39.267Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/f8/d8ddb32120428c083c60eb07244479da6e07eaebd31847658a049ab33815/onyx_devtools-0.7.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7960ae6e440ebf1584e02d9e1d0c9ef543b1d54c2584cdcace15695aec3121b2", size = 3696924, upload-time = "2026-03-17T21:10:50.716Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/87/21/1e427280066db42ff9dd5f34c70b9dca5d9781f96d0d9a88aaa454fdb432/onyx_devtools-0.7.1-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:6785dda88ca0a3d8464a9bfab76a253ed90da89d53a9c4a67227980f37df1ccf", size = 3568300, upload-time = "2026-03-17T21:10:41.997Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/0e/afbbe1164b3d016ddb5352353cb2541eef5a8b2c04e8f02d5d1319cb8b8c/onyx_devtools-0.7.1-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:9e77f2b725c0c00061a3dda5eba199404b51638cec0bf54fc7611fee1f26db34", size = 3974668, upload-time = "2026-03-17T21:10:43.879Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8a/a5/22840643289ef4ca83931b7a79fba8f1db7e626b4b870d4b4f8206c4ff5f/onyx_devtools-0.7.1-py3-none-win_amd64.whl", hash = "sha256:de37daa0e4db9b5dccf94408a3422be4f821e380ab70081bd1032cec1e3c91e6", size = 4078640, upload-time = "2026-03-17T21:10:40.275Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/c1/a0295506a521d9942b0f55523781a113e4555420d800a386d5a2eb46a7ad/onyx_devtools-0.7.1-py3-none-win_arm64.whl", hash = "sha256:ab88c53ebda6dff27350316b4ac9bd5f258cd586c2109971a9d976411e1e22ea", size = 3636787, upload-time = "2026-03-17T21:10:37.492Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/22/b0/765ed49157470e8ccc8ab89e6a896ade50cde3aa2a494662ad4db92a48c4/onyx_devtools-0.7.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:553a2b5e61b29b7913c991c8d5aed78f930f0f81a0f42229c6a8de2b1e8ff57e", size = 4203859, upload-time = "2026-03-27T15:09:49.63Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/9d/bba0a44a16d2fc27e5441aaf10727e10514e7a49bce70eca02bced566eb9/onyx_devtools-0.7.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5cf0782dca8b3d861de9e18e65e990cfce5161cd559df44d8fabd3fefd54fdcd", size = 3879750, upload-time = "2026-03-27T15:09:42.413Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/d8/c5725e8af14c74fe0aeed29e4746400bb3c0a078fd1240df729dc6432b84/onyx_devtools-0.7.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:9a0d67373e16b4fbb38a5290c0d9dfd4cfa837e5da0c165b32841b9d37f7455b", size = 3743529, upload-time = "2026-03-27T15:09:44.546Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1a/82/b7c398a21dbc3e14fd7a29e49caa86b1bc0f8d7c75c051514785441ab779/onyx_devtools-0.7.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:794af14b2de575d0ae41b94551399eca8f8ba9b950c5db7acb7612767fd228f9", size = 4166562, upload-time = "2026-03-27T15:09:49.471Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/26/76/be129e2baafc91fe792d919b1f4d73fc943ba9c2b728a60f1fb98e0c115a/onyx_devtools-0.7.2-py3-none-win_amd64.whl", hash = "sha256:83b3eb84df58d865e4f714222a5fab3ea464836e2c8690569454a940bbb651ff", size = 4282270, upload-time = "2026-03-27T15:09:44.676Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/72/29b8c8dbcf069c56475f00511f04c4aaa5ba3faba1dfc8276107d4b3ef7f/onyx_devtools-0.7.2-py3-none-win_arm64.whl", hash = "sha256:62f0836624ee6a5b31e64fd93162e7fce142ac8a4f959607e411824bc2b88174", size = 3823053, upload-time = "2026-03-27T15:09:43.546Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -73,6 +73,15 @@ interface ContentMdProps {
|
||||
/** When `true`, the title color hooks into `Interactive`'s `--interactive-foreground` variable. */
|
||||
withInteractive?: boolean;
|
||||
|
||||
/** Optional class name applied to the title element. */
|
||||
titleClassName?: string;
|
||||
|
||||
/** Optional class name applied to the icon element. */
|
||||
iconClassName?: string;
|
||||
|
||||
/** Content rendered below the description, indented to align with it. */
|
||||
bottomChildren?: React.ReactNode;
|
||||
|
||||
/** Ref forwarded to the root `<div>`. */
|
||||
ref?: React.Ref<HTMLDivElement>;
|
||||
}
|
||||
@@ -146,6 +155,9 @@ function ContentMd({
|
||||
tag,
|
||||
sizePreset = "main-ui",
|
||||
withInteractive,
|
||||
titleClassName,
|
||||
iconClassName,
|
||||
bottomChildren,
|
||||
ref,
|
||||
}: ContentMdProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
@@ -184,7 +196,11 @@ function ContentMd({
|
||||
style={{ minHeight: config.lineHeight }}
|
||||
>
|
||||
<Icon
|
||||
className={cn("opal-content-md-icon", config.iconColorClass)}
|
||||
className={cn(
|
||||
"opal-content-md-icon",
|
||||
config.iconColorClass,
|
||||
iconClassName
|
||||
)}
|
||||
style={{ width: config.iconSize, height: config.iconSize }}
|
||||
/>
|
||||
</div>
|
||||
@@ -227,7 +243,8 @@ function ContentMd({
|
||||
"opal-content-md-title",
|
||||
config.titleFont,
|
||||
"text-text-04",
|
||||
editable && "cursor-pointer"
|
||||
editable && "cursor-pointer",
|
||||
titleClassName
|
||||
)}
|
||||
title={toPlainString(title)}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
@@ -295,6 +312,13 @@ function ContentMd({
|
||||
{resolveStr(description)}
|
||||
</div>
|
||||
)}
|
||||
{bottomChildren && (
|
||||
<div
|
||||
style={Icon ? { paddingLeft: config.descriptionIndent } : undefined}
|
||||
>
|
||||
{bottomChildren}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -138,6 +138,12 @@ type MdContentProps = ContentBaseProps & {
|
||||
auxIcon?: "info-gray" | "info-blue" | "warning" | "error";
|
||||
/** Tag rendered beside the title. */
|
||||
tag?: TagProps;
|
||||
/** Optional class name applied to the title element. */
|
||||
titleClassName?: string;
|
||||
/** Optional class name applied to the icon element. */
|
||||
iconClassName?: string;
|
||||
/** Content rendered below the description, indented to align with it. */
|
||||
bottomChildren?: React.ReactNode;
|
||||
};
|
||||
|
||||
/** ContentSm does not support descriptions or inline editing. */
|
||||
|
||||
24
web/package-lock.json
generated
24
web/package-lock.json
generated
@@ -7901,7 +7901,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/anymatch/node_modules/picomatch": {
|
||||
"version": "2.3.1",
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=8.6"
|
||||
@@ -10701,7 +10703,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/handlebars": {
|
||||
"version": "4.7.8",
|
||||
"version": "4.7.9",
|
||||
"resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.9.tgz",
|
||||
"integrity": "sha512-4E71E0rpOaQuJR2A3xDZ+GM1HyWYv1clR58tC8emQNeQe3RH7MAzSbat+V0wG78LQBo6m6bzSG/L4pBuCsgnUQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -12555,7 +12559,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/jest-util/node_modules/picomatch": {
|
||||
"version": "2.3.1",
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
@@ -13881,7 +13887,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/micromatch/node_modules/picomatch": {
|
||||
"version": "2.3.1",
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=8.6"
|
||||
@@ -15001,7 +15009,9 @@
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "4.0.3",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
@@ -15889,7 +15899,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/readdirp/node_modules/picomatch": {
|
||||
"version": "2.3.1",
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=8.6"
|
||||
|
||||
@@ -37,6 +37,7 @@ export interface Settings {
|
||||
// User Knowledge settings
|
||||
user_knowledge_enabled?: boolean;
|
||||
user_file_max_upload_size_mb?: number | null;
|
||||
file_token_count_threshold_k?: number | null;
|
||||
|
||||
// Connector settings
|
||||
show_extra_connectors?: boolean;
|
||||
@@ -68,6 +69,12 @@ export interface Settings {
|
||||
|
||||
// Application version from the ONYX_VERSION env var on the server.
|
||||
version?: string | null;
|
||||
// Hard ceiling for user_file_max_upload_size_mb, derived from env var.
|
||||
max_allowed_upload_size_mb?: number;
|
||||
|
||||
// Factory defaults for the restore button.
|
||||
default_user_file_max_upload_size_mb?: number;
|
||||
default_file_token_count_threshold_k?: number;
|
||||
}
|
||||
|
||||
export enum NotificationType {
|
||||
|
||||
@@ -85,8 +85,6 @@ function buildFileKey(file: File): string {
|
||||
return `${file.size}|${namePrefix}`;
|
||||
}
|
||||
|
||||
const DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = 50;
|
||||
|
||||
interface ProjectsContextType {
|
||||
projects: Project[];
|
||||
recentFiles: ProjectFile[];
|
||||
@@ -341,21 +339,20 @@ export function ProjectsProvider({ children }: ProjectsProviderProps) {
|
||||
onFailure?: (failedTempIds: string[]) => void
|
||||
): Promise<ProjectFile[]> => {
|
||||
const rawMax = settingsContext?.settings?.user_file_max_upload_size_mb;
|
||||
const maxUploadSizeMb =
|
||||
rawMax && rawMax > 0 ? rawMax : DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB;
|
||||
const maxUploadSizeBytes = maxUploadSizeMb * 1024 * 1024;
|
||||
|
||||
const oversizedFiles = files.filter(
|
||||
(file) => file.size > maxUploadSizeBytes
|
||||
);
|
||||
const validFiles = files.filter(
|
||||
(file) => file.size <= maxUploadSizeBytes
|
||||
);
|
||||
const oversizedFiles =
|
||||
rawMax && rawMax > 0
|
||||
? files.filter((file) => file.size > rawMax * 1024 * 1024)
|
||||
: [];
|
||||
const validFiles =
|
||||
rawMax && rawMax > 0
|
||||
? files.filter((file) => file.size <= rawMax * 1024 * 1024)
|
||||
: files;
|
||||
|
||||
if (oversizedFiles.length > 0) {
|
||||
const skippedNames = oversizedFiles.map((file) => file.name).join(", ");
|
||||
toast.warning(
|
||||
`Skipped ${oversizedFiles.length} oversized file(s) (>${maxUploadSizeMb} MB): ${skippedNames}`
|
||||
`Skipped ${oversizedFiles.length} oversized file(s) (>${rawMax} MB): ${skippedNames}`
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ import {
|
||||
SvgFold,
|
||||
SvgExternalLink,
|
||||
SvgAlertCircle,
|
||||
SvgRefreshCw,
|
||||
} from "@opal/icons";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import { Content } from "@opal/layouts";
|
||||
@@ -54,6 +55,7 @@ import * as ExpandableCard from "@/layouts/expandable-card-layouts";
|
||||
import * as ActionsLayouts from "@/layouts/actions-layouts";
|
||||
import { getActionIcon } from "@/lib/tools/mcpUtils";
|
||||
import { Disabled } from "@opal/core";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import useFilter from "@/hooks/useFilter";
|
||||
import { MCPServer } from "@/lib/tools/interfaces";
|
||||
@@ -81,6 +83,10 @@ interface ChatPreferencesFormValues {
|
||||
maximum_chat_retention_days: string;
|
||||
anonymous_user_enabled: boolean;
|
||||
disable_default_assistant: boolean;
|
||||
|
||||
// File limits
|
||||
user_file_max_upload_size_mb: string;
|
||||
file_token_count_threshold_k: string;
|
||||
}
|
||||
|
||||
interface MCPServerCardTool {
|
||||
@@ -185,6 +191,173 @@ function MCPServerCard({
|
||||
);
|
||||
}
|
||||
|
||||
type FileLimitFieldName =
|
||||
| "user_file_max_upload_size_mb"
|
||||
| "file_token_count_threshold_k";
|
||||
|
||||
interface NumericLimitFieldProps {
|
||||
name: FileLimitFieldName;
|
||||
defaultValue: string;
|
||||
saveSettings: (updates: Partial<Settings>) => Promise<void>;
|
||||
maxValue?: number;
|
||||
allowZero?: boolean;
|
||||
}
|
||||
|
||||
function NumericLimitField({
|
||||
name,
|
||||
defaultValue,
|
||||
saveSettings,
|
||||
maxValue,
|
||||
allowZero = false,
|
||||
}: NumericLimitFieldProps) {
|
||||
const { values, setFieldValue } =
|
||||
useFormikContext<ChatPreferencesFormValues>();
|
||||
const initialValue = useRef(values[name]);
|
||||
const restoringRef = useRef(false);
|
||||
const value = values[name];
|
||||
|
||||
const parsed = parseInt(value, 10);
|
||||
const isOverMax =
|
||||
maxValue !== undefined && !isNaN(parsed) && parsed > maxValue;
|
||||
|
||||
const handleRestore = () => {
|
||||
restoringRef.current = true;
|
||||
initialValue.current = defaultValue;
|
||||
void setFieldValue(name, defaultValue);
|
||||
void saveSettings({ [name]: parseInt(defaultValue, 10) });
|
||||
};
|
||||
|
||||
const handleBlur = () => {
|
||||
// The restore button triggers a blur — skip since handleRestore already saved.
|
||||
if (restoringRef.current) {
|
||||
restoringRef.current = false;
|
||||
return;
|
||||
}
|
||||
|
||||
const parsed = parseInt(value, 10);
|
||||
const isValid = !isNaN(parsed) && (allowZero ? parsed >= 0 : parsed > 0);
|
||||
|
||||
// Revert invalid input (empty, NaN, negative).
|
||||
if (!isValid) {
|
||||
if (allowZero) {
|
||||
// Empty/invalid means "no limit" — persist 0 and clear the field.
|
||||
void setFieldValue(name, "");
|
||||
void saveSettings({ [name]: 0 });
|
||||
initialValue.current = "";
|
||||
} else {
|
||||
void setFieldValue(name, initialValue.current);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Block save when the value exceeds the hard ceiling.
|
||||
if (maxValue !== undefined && parsed > maxValue) {
|
||||
return;
|
||||
}
|
||||
|
||||
// For allowZero fields, 0 means "no limit" — clear the display
|
||||
// so the "No limit" placeholder is visible, but still persist 0.
|
||||
if (allowZero && parsed === 0) {
|
||||
void setFieldValue(name, "");
|
||||
if (initialValue.current !== "") {
|
||||
void saveSettings({ [name]: 0 });
|
||||
initialValue.current = "";
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const normalizedDisplay = String(parsed);
|
||||
|
||||
// Update the display to the canonical form (e.g. strip leading zeros).
|
||||
if (value !== normalizedDisplay) {
|
||||
void setFieldValue(name, normalizedDisplay);
|
||||
}
|
||||
|
||||
// Persist only when the value actually changed.
|
||||
if (normalizedDisplay !== initialValue.current) {
|
||||
void saveSettings({ [name]: parsed });
|
||||
initialValue.current = normalizedDisplay;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="group w-full">
|
||||
<InputTypeInField
|
||||
name={name}
|
||||
inputMode="numeric"
|
||||
showClearButton={false}
|
||||
pattern="[0-9]*"
|
||||
placeholder={allowZero ? "No limit" : `Default: ${defaultValue}`}
|
||||
variant={isOverMax ? "error" : undefined}
|
||||
rightSection={
|
||||
(value || "") !== defaultValue ? (
|
||||
<div className="opacity-0 group-hover:opacity-100 group-focus-within:opacity-100 transition-opacity">
|
||||
<IconButton
|
||||
icon={SvgRefreshCw}
|
||||
tooltip="Restore default"
|
||||
internal
|
||||
type="button"
|
||||
onClick={handleRestore}
|
||||
/>
|
||||
</div>
|
||||
) : undefined
|
||||
}
|
||||
onBlur={handleBlur}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface FileSizeLimitFieldsProps {
|
||||
saveSettings: (updates: Partial<Settings>) => Promise<void>;
|
||||
defaultUploadSizeMb: string;
|
||||
defaultTokenThresholdK: string;
|
||||
maxAllowedUploadSizeMb?: number;
|
||||
}
|
||||
|
||||
function FileSizeLimitFields({
|
||||
saveSettings,
|
||||
defaultUploadSizeMb,
|
||||
defaultTokenThresholdK,
|
||||
maxAllowedUploadSizeMb,
|
||||
}: FileSizeLimitFieldsProps) {
|
||||
return (
|
||||
<div className="flex gap-4 w-full items-start">
|
||||
<div className="flex-1">
|
||||
<InputLayouts.Vertical
|
||||
title="File Size Limit (MB)"
|
||||
subDescription={
|
||||
maxAllowedUploadSizeMb
|
||||
? `Max: ${maxAllowedUploadSizeMb} MB`
|
||||
: undefined
|
||||
}
|
||||
nonInteractive
|
||||
>
|
||||
<NumericLimitField
|
||||
name="user_file_max_upload_size_mb"
|
||||
defaultValue={defaultUploadSizeMb}
|
||||
saveSettings={saveSettings}
|
||||
maxValue={maxAllowedUploadSizeMb}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<InputLayouts.Vertical
|
||||
title="File Token Limit (thousand tokens)"
|
||||
nonInteractive
|
||||
>
|
||||
<NumericLimitField
|
||||
name="file_token_count_threshold_k"
|
||||
defaultValue={defaultTokenThresholdK}
|
||||
saveSettings={saveSettings}
|
||||
allowZero
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Inner form component that uses useFormikContext to access values
|
||||
* and create save handlers for settings fields.
|
||||
@@ -201,6 +374,7 @@ function ChatPreferencesForm() {
|
||||
// Tools availability
|
||||
const { tools: availableTools } = useAvailableTools();
|
||||
const vectorDbEnabled = useVectorDbEnabled();
|
||||
|
||||
const searchTool = availableTools.find(
|
||||
(t) => t.in_code_tool_id === SEARCH_TOOL_ID
|
||||
);
|
||||
@@ -723,6 +897,28 @@ function ChatPreferencesForm() {
|
||||
</InputLayouts.Horizontal>
|
||||
</Card>
|
||||
|
||||
<Card>
|
||||
<InputLayouts.Vertical
|
||||
title="File Attachment Size Limit"
|
||||
description="Files attached in chats and projects must fit within both limits to be accepted. Larger files increase latency, memory usage, and token costs."
|
||||
>
|
||||
<FileSizeLimitFields
|
||||
saveSettings={saveSettings}
|
||||
defaultUploadSizeMb={
|
||||
settings?.settings.default_user_file_max_upload_size_mb?.toString() ??
|
||||
"100"
|
||||
}
|
||||
defaultTokenThresholdK={
|
||||
settings?.settings.default_file_token_count_threshold_k?.toString() ??
|
||||
"200"
|
||||
}
|
||||
maxAllowedUploadSizeMb={
|
||||
settings?.settings.max_allowed_upload_size_mb
|
||||
}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</Card>
|
||||
|
||||
<Card>
|
||||
<InputLayouts.Horizontal
|
||||
title="Allow Anonymous Users"
|
||||
@@ -862,6 +1058,21 @@ export default function ChatPreferencesPage() {
|
||||
anonymous_user_enabled: settings.settings.anonymous_user_enabled ?? false,
|
||||
disable_default_assistant:
|
||||
settings.settings.disable_default_assistant ?? false,
|
||||
|
||||
// File limits — for upload size: 0/null means "use default";
|
||||
// for token threshold: null means "use default", 0 means "no limit".
|
||||
user_file_max_upload_size_mb:
|
||||
(settings.settings.user_file_max_upload_size_mb ?? 0) <= 0
|
||||
? settings.settings.default_user_file_max_upload_size_mb?.toString() ??
|
||||
"100"
|
||||
: settings.settings.user_file_max_upload_size_mb!.toString(),
|
||||
file_token_count_threshold_k:
|
||||
settings.settings.file_token_count_threshold_k == null
|
||||
? settings.settings.default_file_token_count_threshold_k?.toString() ??
|
||||
"200"
|
||||
: settings.settings.file_token_count_threshold_k === 0
|
||||
? ""
|
||||
: settings.settings.file_token_count_threshold_k.toString(),
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
420
web/src/refresh-pages/admin/HooksPage/ConnectedHookCard.tsx
Normal file
420
web/src/refresh-pages/admin/HooksPage/ConnectedHookCard.tsx
Normal file
@@ -0,0 +1,420 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import {
|
||||
SvgCheckCircle,
|
||||
SvgExternalLink,
|
||||
SvgPlug,
|
||||
SvgRefreshCw,
|
||||
SvgSettings,
|
||||
SvgTrash,
|
||||
SvgUnplug,
|
||||
} from "@opal/icons";
|
||||
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
|
||||
import type {
|
||||
HookPointMeta,
|
||||
HookResponse,
|
||||
} from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
import {
|
||||
activateHook,
|
||||
deactivateHook,
|
||||
deleteHook,
|
||||
validateHook,
|
||||
} from "@/refresh-pages/admin/HooksPage/svc";
|
||||
import { getHookPointIcon } from "@/refresh-pages/admin/HooksPage/hookPointIcons";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sub-component: disconnect confirmation modal
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface DisconnectConfirmModalProps {
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
hook: HookResponse;
|
||||
onDisconnect: () => void;
|
||||
onDisconnectAndDelete: () => void;
|
||||
}
|
||||
|
||||
function DisconnectConfirmModal({
|
||||
open,
|
||||
onOpenChange,
|
||||
hook,
|
||||
onDisconnect,
|
||||
onDisconnectAndDelete,
|
||||
}: DisconnectConfirmModalProps) {
|
||||
return (
|
||||
<Modal open={open} onOpenChange={onOpenChange}>
|
||||
<Modal.Content width="md" height="fit">
|
||||
<Modal.Header
|
||||
icon={(props) => (
|
||||
<SvgUnplug {...props} className="text-action-danger-05" />
|
||||
)}
|
||||
title={`Disconnect ${hook.name}`}
|
||||
onClose={() => onOpenChange(false)}
|
||||
/>
|
||||
<Modal.Body>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Text mainUiBody text03>
|
||||
Onyx will stop calling this endpoint for hook{" "}
|
||||
<strong>
|
||||
<em>{hook.name}</em>
|
||||
</strong>
|
||||
. In-flight requests will continue to run. The external endpoint
|
||||
may still retain data previously sent to it. You can reconnect
|
||||
this hook later if needed.
|
||||
</Text>
|
||||
<Text mainUiBody text03>
|
||||
You can also delete this hook. Deletion cannot be undone.
|
||||
</Text>
|
||||
</div>
|
||||
</Modal.Body>
|
||||
<Modal.Footer>
|
||||
<BasicModalFooter
|
||||
cancel={
|
||||
<Button
|
||||
prominence="secondary"
|
||||
onClick={() => onOpenChange(false)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
}
|
||||
submit={
|
||||
<div className="flex items-center gap-2">
|
||||
<Button
|
||||
variant="danger"
|
||||
prominence="secondary"
|
||||
onClick={onDisconnectAndDelete}
|
||||
>
|
||||
Disconnect & Delete
|
||||
</Button>
|
||||
<Button
|
||||
variant="danger"
|
||||
prominence="primary"
|
||||
onClick={onDisconnect}
|
||||
>
|
||||
Disconnect
|
||||
</Button>
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sub-component: delete confirmation modal
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface DeleteConfirmModalProps {
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
hook: HookResponse;
|
||||
onDelete: () => void;
|
||||
}
|
||||
|
||||
function DeleteConfirmModal({
|
||||
open,
|
||||
onOpenChange,
|
||||
hook,
|
||||
onDelete,
|
||||
}: DeleteConfirmModalProps) {
|
||||
return (
|
||||
<Modal open={open} onOpenChange={onOpenChange}>
|
||||
<Modal.Content width="md" height="fit">
|
||||
<Modal.Header
|
||||
icon={(props) => (
|
||||
<SvgTrash {...props} className="text-action-danger-05" />
|
||||
)}
|
||||
title={`Delete ${hook.name}`}
|
||||
onClose={() => onOpenChange(false)}
|
||||
/>
|
||||
<Modal.Body>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Text mainUiBody text03>
|
||||
Hook{" "}
|
||||
<strong>
|
||||
<em>{hook.name}</em>
|
||||
</strong>{" "}
|
||||
will be permanently removed from this hook point. The external
|
||||
endpoint may still retain data previously sent to it.
|
||||
</Text>
|
||||
<Text mainUiBody text03>
|
||||
Deletion cannot be undone.
|
||||
</Text>
|
||||
</div>
|
||||
</Modal.Body>
|
||||
<Modal.Footer>
|
||||
<BasicModalFooter
|
||||
cancel={
|
||||
<Button
|
||||
prominence="secondary"
|
||||
onClick={() => onOpenChange(false)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
}
|
||||
submit={
|
||||
<Button variant="danger" prominence="primary" onClick={onDelete}>
|
||||
Delete
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ConnectedHookCard
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export interface ConnectedHookCardProps {
|
||||
hook: HookResponse;
|
||||
spec: HookPointMeta | undefined;
|
||||
onEdit: () => void;
|
||||
onDeleted: () => void;
|
||||
onToggled: (updated: HookResponse) => void;
|
||||
}
|
||||
|
||||
export default function ConnectedHookCard({
|
||||
hook,
|
||||
spec,
|
||||
onEdit,
|
||||
onDeleted,
|
||||
onToggled,
|
||||
}: ConnectedHookCardProps) {
|
||||
const [isBusy, setIsBusy] = useState(false);
|
||||
const [disconnectConfirmOpen, setDisconnectConfirmOpen] = useState(false);
|
||||
const [deleteConfirmOpen, setDeleteConfirmOpen] = useState(false);
|
||||
|
||||
async function handleDelete() {
|
||||
setDeleteConfirmOpen(false);
|
||||
setIsBusy(true);
|
||||
try {
|
||||
await deleteHook(hook.id);
|
||||
onDeleted();
|
||||
} catch (err) {
|
||||
console.error("Failed to delete hook:", err);
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to delete hook."
|
||||
);
|
||||
} finally {
|
||||
setIsBusy(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleActivate() {
|
||||
setIsBusy(true);
|
||||
try {
|
||||
const updated = await activateHook(hook.id);
|
||||
onToggled(updated);
|
||||
} catch (err) {
|
||||
console.error("Failed to reconnect hook:", err);
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to reconnect hook."
|
||||
);
|
||||
} finally {
|
||||
setIsBusy(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDeactivate() {
|
||||
setDisconnectConfirmOpen(false);
|
||||
setIsBusy(true);
|
||||
try {
|
||||
const updated = await deactivateHook(hook.id);
|
||||
onToggled(updated);
|
||||
} catch (err) {
|
||||
console.error("Failed to deactivate hook:", err);
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to deactivate hook."
|
||||
);
|
||||
} finally {
|
||||
setIsBusy(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDisconnectAndDelete() {
|
||||
setDisconnectConfirmOpen(false);
|
||||
setIsBusy(true);
|
||||
try {
|
||||
const deactivated = await deactivateHook(hook.id);
|
||||
onToggled(deactivated);
|
||||
await deleteHook(hook.id);
|
||||
onDeleted();
|
||||
} catch (err) {
|
||||
console.error("Failed to disconnect hook:", err);
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to disconnect hook."
|
||||
);
|
||||
} finally {
|
||||
setIsBusy(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleValidate() {
|
||||
setIsBusy(true);
|
||||
try {
|
||||
const result = await validateHook(hook.id);
|
||||
if (result.status === "passed") {
|
||||
toast.success("Hook validated successfully.");
|
||||
} else {
|
||||
toast.error(
|
||||
result.error_message ?? `Validation failed: ${result.status}`
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Failed to validate hook:", err);
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to validate hook."
|
||||
);
|
||||
} finally {
|
||||
setIsBusy(false);
|
||||
}
|
||||
}
|
||||
|
||||
const HookIcon = getHookPointIcon(hook.hook_point);
|
||||
|
||||
return (
|
||||
<>
|
||||
<DisconnectConfirmModal
|
||||
open={disconnectConfirmOpen}
|
||||
onOpenChange={setDisconnectConfirmOpen}
|
||||
hook={hook}
|
||||
onDisconnect={handleDeactivate}
|
||||
onDisconnectAndDelete={handleDisconnectAndDelete}
|
||||
/>
|
||||
<DeleteConfirmModal
|
||||
open={deleteConfirmOpen}
|
||||
onOpenChange={setDeleteConfirmOpen}
|
||||
hook={hook}
|
||||
onDelete={handleDelete}
|
||||
/>
|
||||
<Card
|
||||
variant="primary"
|
||||
padding={0.5}
|
||||
gap={0}
|
||||
className={cn(
|
||||
"hover:border-border-02",
|
||||
!hook.is_active && "!bg-background-neutral-02"
|
||||
)}
|
||||
>
|
||||
<ContentAction
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
paddingVariant="sm"
|
||||
icon={HookIcon}
|
||||
title={hook.name}
|
||||
titleClassName={!hook.is_active ? "line-through" : undefined}
|
||||
iconClassName="text-text-04"
|
||||
description={`Hook Point: ${spec?.display_name ?? hook.hook_point}`}
|
||||
bottomChildren={
|
||||
spec?.docs_url ? (
|
||||
<a
|
||||
href={spec.docs_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="flex items-center gap-1 w-fit font-secondary-body text-text-03"
|
||||
>
|
||||
<span className="underline">Documentation</span>
|
||||
<SvgExternalLink size={12} className="shrink-0" />
|
||||
</a>
|
||||
) : undefined
|
||||
}
|
||||
rightChildren={
|
||||
<Section
|
||||
flexDirection="column"
|
||||
alignItems="end"
|
||||
width="fit"
|
||||
height="fit"
|
||||
gap={0}
|
||||
>
|
||||
<div className="flex items-center gap-1 p-2">
|
||||
{hook.is_active ? (
|
||||
<>
|
||||
<Text mainUiAction text03>
|
||||
Connected
|
||||
</Text>
|
||||
<SvgCheckCircle
|
||||
size={16}
|
||||
className="text-status-success-05"
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center gap-1",
|
||||
isBusy
|
||||
? "opacity-50 pointer-events-none"
|
||||
: "cursor-pointer"
|
||||
)}
|
||||
onClick={handleActivate}
|
||||
>
|
||||
<Text mainUiAction text03>
|
||||
Reconnect
|
||||
</Text>
|
||||
<SvgPlug size={16} className="text-text-03 shrink-0" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<Disabled disabled={isBusy}>
|
||||
{/* Plain div instead of Section: Section applies style={{ padding }} inline which
|
||||
overrides Tailwind padding classes, making per-side padding (pl/pr/pb) ineffective. */}
|
||||
<div className="flex items-center gap-0.5 pl-1 pr-1 pb-1">
|
||||
{hook.is_active ? (
|
||||
<>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgUnplug}
|
||||
onClick={() => setDisconnectConfirmOpen(true)}
|
||||
tooltip="Disconnect Hook"
|
||||
aria-label="Deactivate hook"
|
||||
/>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgRefreshCw}
|
||||
onClick={handleValidate}
|
||||
tooltip="Test Connection"
|
||||
aria-label="Re-validate hook"
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgTrash}
|
||||
onClick={() => setDeleteConfirmOpen(true)}
|
||||
tooltip="Delete"
|
||||
aria-label="Delete hook"
|
||||
/>
|
||||
)}
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgSettings}
|
||||
onClick={onEdit}
|
||||
tooltip="Manage"
|
||||
aria-label="Configure hook"
|
||||
/>
|
||||
</div>
|
||||
</Disabled>
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
</Card>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,117 +1,211 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { useState } from "react";
|
||||
import { useHookSpecs } from "@/hooks/useHookSpecs";
|
||||
import { useHooks } from "@/hooks/useHooks";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { Button } from "@opal/components";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import InputSearch from "@/refresh-components/inputs/InputSearch";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgBubbleText,
|
||||
SvgExternalLink,
|
||||
SvgFileBroadcast,
|
||||
SvgHookNodes,
|
||||
} from "@opal/icons";
|
||||
import { IconFunctionComponent } from "@opal/types";
|
||||
import { SvgArrowExchange, SvgExternalLink } from "@opal/icons";
|
||||
import HookFormModal from "@/refresh-pages/admin/HooksPage/HookFormModal";
|
||||
import ConnectedHookCard from "@/refresh-pages/admin/HooksPage/ConnectedHookCard";
|
||||
import { getHookPointIcon } from "@/refresh-pages/admin/HooksPage/hookPointIcons";
|
||||
import type {
|
||||
HookPointMeta,
|
||||
HookResponse,
|
||||
} from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
const HOOK_POINT_ICONS: Record<string, IconFunctionComponent> = {
|
||||
document_ingestion: SvgFileBroadcast,
|
||||
query_processing: SvgBubbleText,
|
||||
};
|
||||
|
||||
function getHookPointIcon(hookPoint: string): IconFunctionComponent {
|
||||
return HOOK_POINT_ICONS[hookPoint] ?? SvgHookNodes;
|
||||
}
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export default function HooksContent() {
|
||||
const [search, setSearch] = useState("");
|
||||
const [connectSpec, setConnectSpec] = useState<HookPointMeta | null>(null);
|
||||
const [editHook, setEditHook] = useState<HookResponse | null>(null);
|
||||
|
||||
const { specs, isLoading, error } = useHookSpecs();
|
||||
const { specs, isLoading: specsLoading, error: specsError } = useHookSpecs();
|
||||
const {
|
||||
hooks,
|
||||
isLoading: hooksLoading,
|
||||
error: hooksError,
|
||||
mutate,
|
||||
} = useHooks();
|
||||
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
toast.error("Failed to load hook specifications.");
|
||||
}
|
||||
}, [error]);
|
||||
|
||||
if (isLoading) {
|
||||
if (specsLoading || hooksLoading) {
|
||||
return <SimpleLoader />;
|
||||
}
|
||||
|
||||
if (error) {
|
||||
if (specsError || hooksError) {
|
||||
return (
|
||||
<Text text03 secondaryBody>
|
||||
Failed to load hook specifications. Please refresh the page.
|
||||
Failed to load{specsError ? " hook specifications" : " hooks"}. Please
|
||||
refresh the page.
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
const filtered = (specs ?? []).filter(
|
||||
(spec) =>
|
||||
spec.display_name.toLowerCase().includes(search.toLowerCase()) ||
|
||||
spec.description.toLowerCase().includes(search.toLowerCase())
|
||||
);
|
||||
const hooksByPoint: Record<string, HookResponse[]> = {};
|
||||
for (const hook of hooks ?? []) {
|
||||
(hooksByPoint[hook.hook_point] ??= []).push(hook);
|
||||
}
|
||||
|
||||
const searchLower = search.toLowerCase();
|
||||
|
||||
// Connected hooks sorted alphabetically by hook name
|
||||
const connectedHooks = (hooks ?? [])
|
||||
.filter(
|
||||
(hook) =>
|
||||
!searchLower ||
|
||||
hook.name.toLowerCase().includes(searchLower) ||
|
||||
specs
|
||||
?.find((s) => s.hook_point === hook.hook_point)
|
||||
?.display_name.toLowerCase()
|
||||
.includes(searchLower)
|
||||
)
|
||||
.sort((a, b) => a.name.localeCompare(b.name));
|
||||
|
||||
// Unconnected hook point specs sorted alphabetically
|
||||
const unconnectedSpecs = (specs ?? [])
|
||||
.filter(
|
||||
(spec) =>
|
||||
(hooksByPoint[spec.hook_point]?.length ?? 0) === 0 &&
|
||||
(!searchLower ||
|
||||
spec.display_name.toLowerCase().includes(searchLower) ||
|
||||
spec.description.toLowerCase().includes(searchLower))
|
||||
)
|
||||
.sort((a, b) => a.display_name.localeCompare(b.display_name));
|
||||
|
||||
function handleHookSuccess(updated: HookResponse) {
|
||||
mutate((prev) => {
|
||||
if (!prev) return [updated];
|
||||
const idx = prev.findIndex((h) => h.id === updated.id);
|
||||
if (idx >= 0) {
|
||||
const next = [...prev];
|
||||
next[idx] = updated;
|
||||
return next;
|
||||
}
|
||||
return [...prev, updated];
|
||||
});
|
||||
}
|
||||
|
||||
function handleHookDeleted(id: number) {
|
||||
mutate((prev) => prev?.filter((h) => h.id !== id));
|
||||
}
|
||||
|
||||
const connectSpec_ =
|
||||
connectSpec ??
|
||||
(editHook
|
||||
? specs?.find((s) => s.hook_point === editHook.hook_point)
|
||||
: undefined);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-6">
|
||||
<InputSearch
|
||||
placeholder="Search hooks..."
|
||||
value={search}
|
||||
onChange={(e) => setSearch(e.target.value)}
|
||||
<>
|
||||
<div className="flex flex-col gap-6">
|
||||
<InputSearch
|
||||
placeholder="Search hooks..."
|
||||
value={search}
|
||||
onChange={(e) => setSearch(e.target.value)}
|
||||
/>
|
||||
|
||||
<div className="flex flex-col gap-2">
|
||||
{connectedHooks.length === 0 && unconnectedSpecs.length === 0 ? (
|
||||
<Text text03 secondaryBody>
|
||||
{search
|
||||
? "No hooks match your search."
|
||||
: "No hook points are available."}
|
||||
</Text>
|
||||
) : (
|
||||
<>
|
||||
{connectedHooks.map((hook) => {
|
||||
const spec = specs?.find(
|
||||
(s) => s.hook_point === hook.hook_point
|
||||
);
|
||||
return (
|
||||
<ConnectedHookCard
|
||||
key={hook.id}
|
||||
hook={hook}
|
||||
spec={spec}
|
||||
onEdit={() => setEditHook(hook)}
|
||||
onDeleted={() => handleHookDeleted(hook.id)}
|
||||
onToggled={handleHookSuccess}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{unconnectedSpecs.map((spec) => {
|
||||
const UnconnectedIcon = getHookPointIcon(spec.hook_point);
|
||||
return (
|
||||
<Card
|
||||
key={spec.hook_point}
|
||||
variant="secondary"
|
||||
padding={0.5}
|
||||
gap={0}
|
||||
className="hover:border-border-02"
|
||||
>
|
||||
<ContentAction
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
paddingVariant="sm"
|
||||
icon={UnconnectedIcon}
|
||||
title={spec.display_name}
|
||||
iconClassName="text-text-04"
|
||||
description={spec.description}
|
||||
bottomChildren={
|
||||
spec.docs_url ? (
|
||||
<a
|
||||
href={spec.docs_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="flex items-center gap-1 w-fit font-secondary-body text-text-03"
|
||||
>
|
||||
<span className="underline">Documentation</span>
|
||||
<SvgExternalLink size={12} className="shrink-0" />
|
||||
</a>
|
||||
) : undefined
|
||||
}
|
||||
rightChildren={
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
rightIcon={SvgArrowExchange}
|
||||
onClick={() => setConnectSpec(spec)}
|
||||
>
|
||||
Connect
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
</Card>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Create modal */}
|
||||
<HookFormModal
|
||||
key={connectSpec?.hook_point ?? "create"}
|
||||
open={!!connectSpec}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) setConnectSpec(null);
|
||||
}}
|
||||
spec={connectSpec ?? undefined}
|
||||
onSuccess={handleHookSuccess}
|
||||
/>
|
||||
|
||||
<div className="flex flex-col gap-2">
|
||||
{filtered.length === 0 ? (
|
||||
<Text text03 secondaryBody>
|
||||
{search
|
||||
? "No hooks match your search."
|
||||
: "No hook points are available."}
|
||||
</Text>
|
||||
) : (
|
||||
filtered.map((spec) => (
|
||||
<Card
|
||||
key={spec.hook_point}
|
||||
variant="secondary"
|
||||
padding={0.5}
|
||||
gap={0}
|
||||
>
|
||||
<ContentAction
|
||||
icon={getHookPointIcon(spec.hook_point)}
|
||||
title={spec.display_name}
|
||||
description={spec.description}
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
paddingVariant="fit"
|
||||
rightChildren={
|
||||
// TODO(Bo-Onyx): wire up Connect — open modal to create/edit hook
|
||||
<Button prominence="tertiary" rightIcon={SvgArrowExchange}>
|
||||
Connect
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
{spec.docs_url && (
|
||||
<div className="pl-7 pt-1">
|
||||
<a
|
||||
href={spec.docs_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="flex items-center gap-1 w-fit text-text-03"
|
||||
>
|
||||
<Text as="span" secondaryBody text03 className="underline">
|
||||
Documentation
|
||||
</Text>
|
||||
<SvgExternalLink size={16} className="text-text-02" />
|
||||
</a>
|
||||
</div>
|
||||
)}
|
||||
</Card>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{/* Edit modal */}
|
||||
<HookFormModal
|
||||
key={editHook?.id ?? "edit"}
|
||||
open={!!editHook}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) setEditHook(null);
|
||||
}}
|
||||
hook={editHook ?? undefined}
|
||||
spec={connectSpec_ ?? undefined}
|
||||
onSuccess={handleHookSuccess}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
13
web/src/refresh-pages/admin/HooksPage/hookPointIcons.ts
Normal file
13
web/src/refresh-pages/admin/HooksPage/hookPointIcons.ts
Normal file
@@ -0,0 +1,13 @@
|
||||
import { SvgBubbleText, SvgFileBroadcast, SvgHookNodes } from "@opal/icons";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
|
||||
const HOOK_POINT_ICONS: Record<string, IconFunctionComponent> = {
|
||||
document_ingestion: SvgFileBroadcast,
|
||||
query_processing: SvgBubbleText,
|
||||
};
|
||||
|
||||
function getHookPointIcon(hookPoint: string): IconFunctionComponent {
|
||||
return HOOK_POINT_ICONS[hookPoint] ?? SvgHookNodes;
|
||||
}
|
||||
|
||||
export { HOOK_POINT_ICONS, getHookPointIcon };
|
||||
Reference in New Issue
Block a user