mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-21 09:26:44 +00:00
Compare commits
3 Commits
v3.1.5
...
refactor/r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c079cd867a | ||
|
|
f1c1473903 | ||
|
|
4a67cc0a09 |
@@ -1,35 +0,0 @@
|
||||
"""remove voice_provider deleted column
|
||||
|
||||
Revision ID: 1d78c0ca7853
|
||||
Revises: a3f8b2c1d4e5
|
||||
Create Date: 2026-03-26 11:30:53.883127
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1d78c0ca7853"
|
||||
down_revision = "a3f8b2c1d4e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Hard-delete any soft-deleted rows before dropping the column
|
||||
op.execute("DELETE FROM voice_provider WHERE deleted = true")
|
||||
op.drop_column("voice_provider", "deleted")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"voice_provider",
|
||||
sa.Column(
|
||||
"deleted",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
@@ -28,7 +28,6 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -188,6 +187,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
@@ -227,7 +227,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_permission_sync_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -29,7 +29,6 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
@@ -163,6 +162,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
@@ -221,7 +221,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_external_group_sync_fences(
|
||||
tenant_id, self.app, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -11,8 +11,6 @@ require a valid SCIM bearer token.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -24,7 +22,6 @@ from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -62,25 +59,9 @@ from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
# Namespace prefix for the seat-allocation advisory lock. Hashed together
|
||||
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
|
||||
# never block each other) and cannot collide with unrelated advisory locks.
|
||||
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
|
||||
|
||||
|
||||
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
|
||||
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
|
||||
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
|
||||
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
|
||||
return struct.unpack("q", digest[:8])[0]
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -219,37 +200,12 @@ def _apply_exclusions(
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None.
|
||||
|
||||
Acquires a transaction-scoped advisory lock so that concurrent
|
||||
SCIM requests are serialized. IdPs like Okta send provisioning
|
||||
requests in parallel batches — without serialization the check is
|
||||
vulnerable to a TOCTOU race where N concurrent requests each see
|
||||
"seats available", all insert, and the tenant ends up over its
|
||||
seat limit.
|
||||
|
||||
The lock is held until the caller's next COMMIT or ROLLBACK, which
|
||||
means the seat count cannot change between the check here and the
|
||||
subsequent INSERT/UPDATE. Each call site in this module follows
|
||||
the pattern: _check_seat_availability → write → dal.commit()
|
||||
(which releases the lock for the next waiting request).
|
||||
"""
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)
|
||||
if check_fn is None:
|
||||
return None
|
||||
|
||||
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
|
||||
# The lock id is derived from the tenant so unrelated tenants never block
|
||||
# each other, and from a namespace string so it cannot collide with
|
||||
# unrelated advisory locks elsewhere in the codebase.
|
||||
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
|
||||
dal.session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(:lock_id)"),
|
||||
{"lock_id": lock_id},
|
||||
)
|
||||
|
||||
result = check_fn(dal.session, seats_needed=1)
|
||||
if not result.available:
|
||||
return result.error_message or "Seat limit reached"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -8,59 +7,7 @@ from celery import Celery
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
|
||||
|
||||
_broker_client: Redis | None = None
|
||||
_broker_url: str | None = None
|
||||
_broker_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def celery_get_broker_client(app: Celery) -> Redis:
|
||||
"""Return a shared Redis client connected to the Celery broker DB.
|
||||
|
||||
Uses a module-level singleton so all tasks on a worker share one
|
||||
connection instead of creating a new one per call. The client
|
||||
connects directly to the broker Redis DB (parsed from the broker URL).
|
||||
|
||||
Thread-safe via lock — safe for use in Celery thread-pool workers.
|
||||
|
||||
Usage:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
length = celery_get_queue_length(queue, r_celery)
|
||||
"""
|
||||
global _broker_client, _broker_url
|
||||
with _broker_client_lock:
|
||||
url = app.conf.broker_url
|
||||
if _broker_client is not None and _broker_url == url:
|
||||
try:
|
||||
_broker_client.ping()
|
||||
return _broker_client
|
||||
except Exception:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
elif _broker_client is not None:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
|
||||
_broker_url = url
|
||||
_broker_client = Redis.from_url(
|
||||
url,
|
||||
decode_responses=False,
|
||||
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
|
||||
socket_keepalive=True,
|
||||
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
return _broker_client
|
||||
|
||||
|
||||
def celery_get_unacked_length(r: Redis) -> int:
|
||||
|
||||
@@ -14,7 +14,6 @@ from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
@@ -133,6 +132,7 @@ def revoke_tasks_blocking_deletion(
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -149,7 +149,6 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
|
||||
# clear fences that don't have associated celery tasks in progress
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_connector_deletion_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -450,7 +449,7 @@ def check_indexing_completion(
|
||||
):
|
||||
# Check if the task exists in the celery queue
|
||||
# This handles the case where Redis dies after task creation but before task execution
|
||||
redis_celery = celery_get_broker_client(task.app)
|
||||
redis_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
task_exists = celery_find_task(
|
||||
attempt.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
@@ -18,7 +19,6 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -698,27 +698,31 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Get Redis client for Celery broker
|
||||
redis_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_std = get_redis_client()
|
||||
|
||||
# Collect queue metrics with broker connection
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_metrics = _collect_queue_metrics(r_celery)
|
||||
# Define metric collection functions and their dependencies
|
||||
metric_functions: list[Callable[[], list[Metric]]] = [
|
||||
lambda: _collect_queue_metrics(redis_celery),
|
||||
lambda: _collect_connector_metrics(db_session, redis_std),
|
||||
lambda: _collect_sync_metrics(db_session, redis_std),
|
||||
]
|
||||
|
||||
# Collect remaining metrics (no broker connection needed)
|
||||
# Collect and log each metric
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
all_metrics: list[Metric] = queue_metrics
|
||||
all_metrics.extend(_collect_connector_metrics(db_session, redis_std))
|
||||
all_metrics.extend(_collect_sync_metrics(db_session, redis_std))
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
# double check to make sure we aren't double-emitting metrics
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
|
||||
for metric in all_metrics:
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
|
||||
task_logger.info("Successfully collected background metrics")
|
||||
except SoftTimeLimitExceeded:
|
||||
@@ -886,7 +890,7 @@ def monitor_celery_queues_helper(
|
||||
) -> None:
|
||||
"""A task to monitor all celery queue lengths."""
|
||||
|
||||
r_celery = celery_get_broker_client(task.app)
|
||||
r_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
|
||||
n_docfetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
@@ -1076,7 +1080,7 @@ def cloud_monitor_celery_pidbox(
|
||||
num_deleted = 0
|
||||
|
||||
MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
for key in r_celery.scan_iter("*.reply.celery.pidbox"):
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
|
||||
@@ -17,7 +17,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -204,6 +203,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
@@ -261,7 +261,6 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating pruning fences")
|
||||
|
||||
@@ -16,7 +16,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import build_access_for_user_files
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
@@ -106,7 +105,7 @@ def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
|
||||
redis_celery = celery_get_broker_client(celery_app)
|
||||
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
|
||||
return celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
|
||||
)
|
||||
@@ -239,7 +238,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
@@ -592,7 +591,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
# NOTE: must use the broker's Redis client (not redis_client) because
|
||||
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
|
||||
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
|
||||
@@ -816,29 +816,6 @@ MAX_FILE_SIZE_BYTES = int(
|
||||
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
|
||||
) # 2GB in bytes
|
||||
|
||||
# Maximum embedded images allowed in a single file. PDFs (and other formats)
|
||||
# with thousands of embedded images can OOM the user-file-processing worker
|
||||
# because every image is decoded with PIL and then sent to the vision LLM.
|
||||
# Enforced both at upload time (rejects the file) and during extraction
|
||||
# (defense-in-depth: caps the number of images materialized).
|
||||
#
|
||||
# Clamped to >= 0; a negative env value would turn upload validation into
|
||||
# always-fail and extraction into always-stop, which is never desired. 0
|
||||
# disables image extraction entirely, which is a valid (if aggressive) setting.
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE = max(
|
||||
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_FILE") or 500)
|
||||
)
|
||||
|
||||
# Maximum embedded images allowed across all files in a single upload batch.
|
||||
# Protects against the scenario where a user uploads many files that each
|
||||
# fall under MAX_EMBEDDED_IMAGES_PER_FILE but aggregate to enough work
|
||||
# (serial-ish celery fan-out plus per-image vision-LLM calls) to OOM the
|
||||
# worker under concurrency or run up surprise latency/cost. Also clamped
|
||||
# to >= 0.
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD = max(
|
||||
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_UPLOAD") or 1000)
|
||||
)
|
||||
|
||||
# Use document summary for contextual rag
|
||||
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
|
||||
# Use chunk summary for contextual rag
|
||||
|
||||
@@ -12,11 +12,6 @@ SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
|
||||
# The mask_string() function in encryption.py uses "•" (U+2022 BULLET) to mask secrets.
|
||||
MASK_CREDENTIAL_CHAR = "\u2022"
|
||||
# Pattern produced by mask_string for strings >= 14 chars: "abcd...wxyz" (exactly 11 chars)
|
||||
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.{3}.{4}$")
|
||||
|
||||
SOURCE_TYPE = "source_type"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
|
||||
@@ -10,7 +10,6 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
from jira.resources import Issue
|
||||
@@ -240,53 +239,29 @@ def enhanced_search_ids(
|
||||
)
|
||||
|
||||
|
||||
def _bulk_fetch_request(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO: move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
|
||||
|
||||
# Prepare the payload according to Jira API v3 specification
|
||||
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
|
||||
|
||||
# Only restrict fields if specified, might want to explicitly do this in the future
|
||||
# to avoid reading unnecessary data
|
||||
payload["fields"] = fields.split(",") if fields else ["*all"]
|
||||
|
||||
resp = jira_client._session.post(bulk_fetch_path, json=payload)
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
try:
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
|
||||
f"be decoded as JSON (response too large or truncated)."
|
||||
)
|
||||
raise
|
||||
|
||||
mid = len(issue_ids) // 2
|
||||
logger.warning(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
raise e
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
for issue in raw_issues
|
||||
for issue in response["issues"]
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -7,14 +6,6 @@ from pydantic import BaseModel
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DirectThreadFetch:
|
||||
"""Request to fetch a Slack thread directly by channel and timestamp."""
|
||||
|
||||
channel_id: str
|
||||
thread_ts: str
|
||||
|
||||
|
||||
class ChannelMetadata(TypedDict):
|
||||
"""Type definition for cached channel metadata."""
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.models import SlackMessage
|
||||
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
|
||||
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
|
||||
@@ -50,6 +49,7 @@ from onyx.server.federated.models import FederatedConnectorDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -58,6 +58,7 @@ HIGHLIGHT_END_CHAR = "\ue001"
|
||||
|
||||
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
|
||||
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
|
||||
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
|
||||
|
||||
@@ -420,94 +421,6 @@ class SlackQueryResult(BaseModel):
|
||||
filtered_channels: list[str] # Channels filtered out during this query
|
||||
|
||||
|
||||
def _fetch_thread_from_url(
|
||||
thread_fetch: DirectThreadFetch,
|
||||
access_token: str,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
"""Fetch a thread directly from a Slack URL via conversations.replies."""
|
||||
channel_id = thread_fetch.channel_id
|
||||
thread_ts = thread_fetch.thread_ts
|
||||
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
response = slack_client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
)
|
||||
response.validate()
|
||||
messages: list[dict[str, Any]] = response.get("messages", [])
|
||||
except SlackApiError as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
if not messages:
|
||||
logger.warning(
|
||||
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
# Build thread text from all messages
|
||||
thread_text = _build_thread_text(messages, access_token, None, slack_client)
|
||||
|
||||
# Get channel name from metadata cache or API
|
||||
channel_name = "unknown"
|
||||
if channel_metadata_dict and channel_id in channel_metadata_dict:
|
||||
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
|
||||
else:
|
||||
try:
|
||||
ch_response = slack_client.conversations_info(channel=channel_id)
|
||||
ch_response.validate()
|
||||
channel_info: dict[str, Any] = ch_response.get("channel", {})
|
||||
channel_name = channel_info.get("name", "unknown")
|
||||
except SlackApiError:
|
||||
pass
|
||||
|
||||
# Build the SlackMessage
|
||||
parent_msg = messages[0]
|
||||
message_ts = parent_msg.get("ts", thread_ts)
|
||||
username = parent_msg.get("user", "unknown_user")
|
||||
parent_text = parent_msg.get("text", "")
|
||||
snippet = (
|
||||
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
|
||||
).replace("\n", " ")
|
||||
|
||||
doc_time = datetime.fromtimestamp(float(message_ts))
|
||||
decay_factor = DOC_TIME_DECAY
|
||||
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
|
||||
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
|
||||
|
||||
permalink = (
|
||||
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
|
||||
)
|
||||
|
||||
slack_message = SlackMessage(
|
||||
document_id=f"{channel_id}_{message_ts}",
|
||||
channel_id=channel_id,
|
||||
message_id=message_ts,
|
||||
thread_id=None, # Prevent double-enrichment in thread context fetch
|
||||
link=permalink,
|
||||
metadata={
|
||||
"channel": channel_name,
|
||||
"time": doc_time.isoformat(),
|
||||
},
|
||||
timestamp=doc_time,
|
||||
recency_bias=recency_bias,
|
||||
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
|
||||
text=thread_text,
|
||||
highlighted_texts=set(),
|
||||
slack_score=100000.0, # High priority — user explicitly asked for this thread
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
|
||||
)
|
||||
|
||||
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
|
||||
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
access_token: str,
|
||||
@@ -519,6 +432,7 @@ def query_slack(
|
||||
available_channels: list[str] | None = None,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
|
||||
# Check if query has channel override (user specified channels in query)
|
||||
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
|
||||
|
||||
@@ -748,6 +662,7 @@ def _fetch_thread_context(
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
message_id = message.message_id
|
||||
|
||||
# If not a thread, return original text as success
|
||||
if thread_id is None:
|
||||
@@ -780,37 +695,62 @@ def _fetch_thread_context(
|
||||
if len(messages) <= 1:
|
||||
return ThreadContextResult.success(message.text)
|
||||
|
||||
# Build thread text from thread starter + all replies
|
||||
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
|
||||
# Build thread text from thread starter + context window around matched message
|
||||
thread_text = _build_thread_text(
|
||||
messages, message_id, thread_id, access_token, team_id, slack_client
|
||||
)
|
||||
return ThreadContextResult.success(thread_text)
|
||||
|
||||
|
||||
def _build_thread_text(
|
||||
messages: list[dict[str, Any]],
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
slack_client: WebClient,
|
||||
) -> str:
|
||||
"""Build thread text including all replies.
|
||||
|
||||
Includes the thread parent message followed by all replies in order.
|
||||
"""
|
||||
"""Build the thread text from messages."""
|
||||
msg_text = messages[0].get("text", "")
|
||||
msg_sender = messages[0].get("user", "")
|
||||
thread_text = f"<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# All messages after index 0 are replies
|
||||
replies = messages[1:]
|
||||
if not replies:
|
||||
return thread_text
|
||||
|
||||
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
|
||||
thread_text += "\n\nReplies:"
|
||||
if thread_id == message_id:
|
||||
message_id_idx = 0
|
||||
else:
|
||||
message_id_idx = next(
|
||||
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
|
||||
)
|
||||
if not message_id_idx:
|
||||
return thread_text
|
||||
|
||||
for msg in replies:
|
||||
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
|
||||
|
||||
if start_idx > 1:
|
||||
thread_text += "\n..."
|
||||
|
||||
for i in range(start_idx, message_id_idx):
|
||||
msg_text = messages[i].get("text", "")
|
||||
msg_sender = messages[i].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
msg_text = messages[message_id_idx].get("text", "")
|
||||
msg_sender = messages[message_id_idx].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add following replies
|
||||
len_replies = 0
|
||||
for msg in messages[message_id_idx + 1 :]:
|
||||
msg_text = msg.get("text", "")
|
||||
msg_sender = msg.get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
reply = f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
thread_text += reply
|
||||
|
||||
len_replies += len(reply)
|
||||
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
|
||||
thread_text += "\n..."
|
||||
break
|
||||
|
||||
# Replace user IDs with names using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
@@ -1036,16 +976,7 @@ def slack_retrieval(
|
||||
|
||||
# Query slack with entity filtering
|
||||
llm = get_default_llm()
|
||||
query_items = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Partition into direct thread fetches and search query strings
|
||||
direct_fetches: list[DirectThreadFetch] = []
|
||||
query_strings: list[str] = []
|
||||
for item in query_items:
|
||||
if isinstance(item, DirectThreadFetch):
|
||||
direct_fetches.append(item)
|
||||
else:
|
||||
query_strings.append(item)
|
||||
query_strings = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Determine filtering based on entities OR context (bot)
|
||||
include_dm = False
|
||||
@@ -1062,16 +993,8 @@ def slack_retrieval(
|
||||
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
|
||||
)
|
||||
|
||||
# Build search tasks — direct thread fetches + keyword searches
|
||||
search_tasks: list[tuple] = [
|
||||
(
|
||||
_fetch_thread_from_url,
|
||||
(fetch, access_token, channel_metadata_dict),
|
||||
)
|
||||
for fetch in direct_fetches
|
||||
]
|
||||
|
||||
search_tasks.extend(
|
||||
# Build search tasks
|
||||
search_tasks = [
|
||||
(
|
||||
query_slack,
|
||||
(
|
||||
@@ -1087,7 +1010,7 @@ def slack_retrieval(
|
||||
),
|
||||
)
|
||||
for query_string in query_strings
|
||||
)
|
||||
]
|
||||
|
||||
# If include_dm is True AND we're not already searching all channels,
|
||||
# add additional searches without channel filters.
|
||||
|
||||
@@ -10,7 +10,6 @@ from pydantic import ValidationError
|
||||
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -639,38 +638,12 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
return [query_text]
|
||||
|
||||
|
||||
SLACK_URL_PATTERN = re.compile(
|
||||
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
|
||||
)
|
||||
|
||||
|
||||
def extract_slack_message_urls(
|
||||
query_text: str,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Extract Slack message URLs from query text.
|
||||
|
||||
Parses URLs like:
|
||||
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
|
||||
|
||||
Returns list of (channel_id, thread_ts) tuples.
|
||||
The 16-digit timestamp is converted to Slack ts format (with dot).
|
||||
"""
|
||||
results = []
|
||||
for match in SLACK_URL_PATTERN.finditer(query_text):
|
||||
channel_id = match.group(1)
|
||||
raw_ts = match.group(2)
|
||||
# Convert p1775491616524769 -> 1775491616.524769
|
||||
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
|
||||
results.append((channel_id, thread_ts))
|
||||
return results
|
||||
|
||||
|
||||
def build_slack_queries(
|
||||
query: ChunkIndexRequest,
|
||||
llm: LLM,
|
||||
entities: dict[str, Any] | None = None,
|
||||
available_channels: list[str] | None = None,
|
||||
) -> list[str | DirectThreadFetch]:
|
||||
) -> list[str]:
|
||||
"""Build Slack query strings with date filtering and query expansion."""
|
||||
default_search_days = 30
|
||||
if entities:
|
||||
@@ -695,15 +668,6 @@ def build_slack_queries(
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
|
||||
|
||||
# Check for Slack message URLs — if found, add direct fetch requests
|
||||
url_fetches: list[DirectThreadFetch] = []
|
||||
slack_urls = extract_slack_message_urls(query.query)
|
||||
for channel_id, thread_ts in slack_urls:
|
||||
url_fetches.append(
|
||||
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
|
||||
)
|
||||
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
|
||||
|
||||
# ALWAYS extract channel references from the query (not just for recency queries)
|
||||
channel_references = extract_channel_references_from_query(query.query)
|
||||
|
||||
@@ -720,9 +684,7 @@ def build_slack_queries(
|
||||
|
||||
# If valid channels detected, use ONLY those channels with NO keywords
|
||||
# Return query with ONLY time filter + channel filter (no keywords)
|
||||
return url_fetches + [
|
||||
build_channel_override_query(channel_references, time_filter)
|
||||
]
|
||||
return [build_channel_override_query(channel_references, time_filter)]
|
||||
except ValueError as e:
|
||||
# If validation fails, log the error and continue with normal flow
|
||||
logger.warning(f"Channel reference validation failed: {e}")
|
||||
@@ -740,8 +702,7 @@ def build_slack_queries(
|
||||
rephrased_queries = expand_query_with_llm(query.query, llm)
|
||||
|
||||
# Build final query strings with time filters
|
||||
search_queries = [
|
||||
return [
|
||||
rephrased_query.strip() + time_filter
|
||||
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
|
||||
]
|
||||
return url_fetches + search_queries
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import ApiKeyDescriptor
|
||||
@@ -54,6 +55,7 @@ async def fetch_user_for_api_key(
|
||||
select(User)
|
||||
.join(ApiKey, ApiKey.user_id == User.id)
|
||||
.where(ApiKey.hashed_api_key == hashed_api_key)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -97,6 +98,11 @@ async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
|
||||
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
|
||||
async def _get_user(self, statement: Select) -> UP | None:
|
||||
statement = statement.options(selectinload(User.memories))
|
||||
results = await self.session.execute(statement)
|
||||
return results.unique().scalar_one_or_none()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
create_dict: Dict[str, Any],
|
||||
|
||||
@@ -8,6 +8,7 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
@@ -131,47 +132,32 @@ def get_chat_sessions_by_user(
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(ChatSession.project_id == project_id)
|
||||
elif only_non_project_chats:
|
||||
stmt = stmt.where(ChatSession.project_id.is_(None))
|
||||
|
||||
# When filtering out failed chats, we apply the limit in Python after
|
||||
# filtering rather than in SQL, since the post-filter may remove rows.
|
||||
if limit and include_failed_chats:
|
||||
stmt = stmt.limit(limit)
|
||||
if not include_failed_chats:
|
||||
non_system_message_exists_subq = (
|
||||
exists()
|
||||
.where(ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.correlate(ChatSession)
|
||||
)
|
||||
|
||||
# Leeway for newly created chats that don't have messages yet
|
||||
time = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
recently_created = ChatSession.time_created >= time
|
||||
|
||||
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = list(result.scalars().all())
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
if not include_failed_chats and chat_sessions:
|
||||
# Filter out "failed" sessions (those with only SYSTEM messages)
|
||||
# using a separate efficient query instead of a correlated EXISTS
|
||||
# subquery, which causes full sequential scans of chat_message.
|
||||
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
|
||||
|
||||
if session_ids:
|
||||
valid_session_ids_stmt = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.where(ChatMessage.chat_session_id.in_(session_ids))
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.distinct()
|
||||
)
|
||||
valid_session_ids = set(
|
||||
db_session.execute(valid_session_ids_stmt).scalars().all()
|
||||
)
|
||||
|
||||
chat_sessions = [
|
||||
cs
|
||||
for cs in chat_sessions
|
||||
if cs.time_created >= leeway or cs.id in valid_session_ids
|
||||
]
|
||||
|
||||
if limit:
|
||||
chat_sessions = chat_sessions[:limit]
|
||||
|
||||
return chat_sessions
|
||||
return list(chat_sessions)
|
||||
|
||||
|
||||
def delete_orphaned_search_docs(db_session: Session) -> None:
|
||||
|
||||
@@ -8,8 +8,6 @@ from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_LONG_RE
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector
|
||||
@@ -47,23 +45,6 @@ def fetch_all_federated_connectors_parallel() -> list[FederatedConnector]:
|
||||
return fetch_all_federated_connectors(db_session)
|
||||
|
||||
|
||||
def _reject_masked_credentials(credentials: dict[str, Any]) -> None:
|
||||
"""Raise if any credential string value contains mask placeholder characters.
|
||||
|
||||
mask_string() has two output formats:
|
||||
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
|
||||
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
|
||||
Both must be rejected.
|
||||
"""
|
||||
for key, val in credentials.items():
|
||||
if isinstance(val, str) and (
|
||||
MASK_CREDENTIAL_CHAR in val or MASK_CREDENTIAL_LONG_RE.match(val)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Credential field '{key}' contains masked placeholder characters. Please provide the actual credential value."
|
||||
)
|
||||
|
||||
|
||||
def validate_federated_connector_credentials(
|
||||
source: FederatedConnectorSource,
|
||||
credentials: dict[str, Any],
|
||||
@@ -85,8 +66,6 @@ def create_federated_connector(
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> FederatedConnector:
|
||||
"""Create a new federated connector with credential and config validation."""
|
||||
_reject_masked_credentials(credentials)
|
||||
|
||||
# Validate credentials before creating
|
||||
if not validate_federated_connector_credentials(source, credentials):
|
||||
raise ValueError(
|
||||
@@ -298,8 +277,6 @@ def update_federated_connector(
|
||||
)
|
||||
|
||||
if credentials is not None:
|
||||
_reject_masked_credentials(credentials)
|
||||
|
||||
# Validate credentials before updating
|
||||
if not validate_federated_connector_credentials(
|
||||
federated_connector.source, credentials
|
||||
|
||||
@@ -3135,6 +3135,8 @@ class VoiceProvider(Base):
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.pat import build_displayable_pat
|
||||
@@ -46,6 +47,7 @@ async def fetch_user_for_pat(
|
||||
(PersonalAccessToken.expires_at.is_(None))
|
||||
| (PersonalAccessToken.expires_at > now)
|
||||
)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
@@ -229,9 +229,7 @@ def get_memories_for_user(
|
||||
user_id: UUID,
|
||||
db_session: Session,
|
||||
) -> Sequence[Memory]:
|
||||
return db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.desc())
|
||||
).all()
|
||||
return db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all()
|
||||
|
||||
|
||||
def update_user_pinned_assistants(
|
||||
|
||||
@@ -17,30 +17,39 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
|
||||
"""Fetch all voice providers."""
|
||||
return list(
|
||||
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
|
||||
db_session.scalars(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
.order_by(VoiceProvider.name)
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
|
||||
db_session: Session, provider_id: int
|
||||
db_session: Session, provider_id: int, include_deleted: bool = False
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by ID."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
)
|
||||
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(VoiceProvider.deleted.is_(False))
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default STT provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_stt.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default TTS provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_tts.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
)
|
||||
|
||||
|
||||
@@ -49,7 +58,9 @@ def fetch_voice_provider_by_type(
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by type."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.provider_type == provider_type)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
)
|
||||
|
||||
|
||||
@@ -108,10 +119,10 @@ def upsert_voice_provider(
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
|
||||
"""Delete a voice provider by ID."""
|
||||
"""Soft-delete a voice provider by ID."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider:
|
||||
db_session.delete(provider)
|
||||
provider.deleted = True
|
||||
db_session.flush()
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ import chardet
|
||||
import openpyxl
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
@@ -50,21 +49,9 @@ KNOWN_OPENPYXL_BUGS = [
|
||||
|
||||
def get_markitdown_converter() -> "MarkItDown":
|
||||
global _MARKITDOWN_CONVERTER
|
||||
from markitdown import MarkItDown
|
||||
|
||||
if _MARKITDOWN_CONVERTER is None:
|
||||
from markitdown import MarkItDown
|
||||
|
||||
# Patch this function to effectively no-op because we were seeing this
|
||||
# module take an inordinate amount of time to convert charts to markdown,
|
||||
# making some powerpoint files with many or complicated charts nearly
|
||||
# unindexable.
|
||||
from markitdown.converters._pptx_converter import PptxConverter
|
||||
|
||||
setattr(
|
||||
PptxConverter,
|
||||
"_convert_chart_to_markdown",
|
||||
lambda self, chart: "\n\n[chart omitted]\n\n", # noqa: ARG005
|
||||
)
|
||||
_MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
|
||||
return _MARKITDOWN_CONVERTER
|
||||
|
||||
@@ -189,56 +176,6 @@ def read_text_file(
|
||||
return file_content_raw, metadata
|
||||
|
||||
|
||||
def count_pdf_embedded_images(file: IO[Any], cap: int) -> int:
|
||||
"""Return the number of embedded images in a PDF, short-circuiting at cap+1.
|
||||
|
||||
Used to reject PDFs whose image count would OOM the user-file-processing
|
||||
worker during indexing. Returns a value > cap as a sentinel once the count
|
||||
exceeds the cap, so callers do not iterate thousands of image objects just
|
||||
to report a number. Returns 0 if the PDF cannot be parsed.
|
||||
|
||||
Owner-password-only PDFs (permission restrictions but no open password) are
|
||||
counted normally — they decrypt with an empty string. Truly password-locked
|
||||
PDFs are skipped (return 0) since we can't inspect them; the caller should
|
||||
ensure the password-protected check runs first.
|
||||
|
||||
Always restores the file pointer to its original position before returning.
|
||||
"""
|
||||
from pypdf import PdfReader
|
||||
|
||||
try:
|
||||
start_pos = file.tell()
|
||||
except Exception:
|
||||
start_pos = None
|
||||
try:
|
||||
if start_pos is not None:
|
||||
file.seek(0)
|
||||
reader = PdfReader(file)
|
||||
if reader.is_encrypted:
|
||||
# Try empty password first (owner-password-only PDFs); give up if that fails.
|
||||
try:
|
||||
if reader.decrypt("") == 0:
|
||||
return 0
|
||||
except Exception:
|
||||
return 0
|
||||
count = 0
|
||||
for page in reader.pages:
|
||||
for _ in page.images:
|
||||
count += 1
|
||||
if count > cap:
|
||||
return count
|
||||
return count
|
||||
except Exception:
|
||||
logger.warning("Failed to count embedded images in PDF", exc_info=True)
|
||||
return 0
|
||||
finally:
|
||||
if start_pos is not None:
|
||||
try:
|
||||
file.seek(start_pos)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
|
||||
"""
|
||||
Extract text from a PDF. For embedded images, a more complex approach is needed.
|
||||
@@ -265,26 +202,18 @@ def read_pdf_file(
|
||||
try:
|
||||
pdf_reader = PdfReader(file)
|
||||
|
||||
if pdf_reader.is_encrypted:
|
||||
# Try the explicit password first, then fall back to an empty
|
||||
# string. Owner-password-only PDFs (permission restrictions but
|
||||
# no open password) decrypt successfully with "".
|
||||
# See https://github.com/onyx-dot-app/onyx/issues/9754
|
||||
passwords = [p for p in [pdf_pass, ""] if p is not None]
|
||||
if pdf_reader.is_encrypted and pdf_pass is not None:
|
||||
decrypt_success = False
|
||||
for pw in passwords:
|
||||
try:
|
||||
if pdf_reader.decrypt(pw) != 0:
|
||||
decrypt_success = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
|
||||
except Exception:
|
||||
logger.error("Unable to decrypt pdf")
|
||||
|
||||
if not decrypt_success:
|
||||
logger.error(
|
||||
"Encrypted PDF could not be decrypted, returning empty text."
|
||||
)
|
||||
return "", metadata, []
|
||||
elif pdf_reader.is_encrypted:
|
||||
logger.warning("No Password for an encrypted PDF, returning empty text.")
|
||||
return "", metadata, []
|
||||
|
||||
# Basic PDF metadata
|
||||
if pdf_reader.metadata is not None:
|
||||
@@ -302,27 +231,8 @@ def read_pdf_file(
|
||||
)
|
||||
|
||||
if extract_images:
|
||||
image_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
images_processed = 0
|
||||
cap_reached = False
|
||||
for page_num, page in enumerate(pdf_reader.pages):
|
||||
if cap_reached:
|
||||
break
|
||||
for image_file_object in page.images:
|
||||
if images_processed >= image_cap:
|
||||
# Defense-in-depth backstop. Upload-time validation
|
||||
# should have rejected files exceeding the cap, but
|
||||
# we also break here so a single oversized file can
|
||||
# never pin a worker.
|
||||
logger.warning(
|
||||
"PDF embedded image cap reached (%d). "
|
||||
"Skipping remaining images on page %d and beyond.",
|
||||
image_cap,
|
||||
page_num + 1,
|
||||
)
|
||||
cap_reached = True
|
||||
break
|
||||
|
||||
image = Image.open(io.BytesIO(image_file_object.data))
|
||||
img_byte_arr = io.BytesIO()
|
||||
image.save(img_byte_arr, format=image.format)
|
||||
@@ -335,7 +245,6 @@ def read_pdf_file(
|
||||
image_callback(img_bytes, image_name)
|
||||
else:
|
||||
extracted_images.append((img_bytes, image_name))
|
||||
images_processed += 1
|
||||
|
||||
return text, metadata, extracted_images
|
||||
|
||||
|
||||
@@ -33,20 +33,8 @@ def is_pdf_protected(file: IO[Any]) -> bool:
|
||||
|
||||
with preserve_position(file):
|
||||
reader = PdfReader(file)
|
||||
if not reader.is_encrypted:
|
||||
return False
|
||||
|
||||
# PDFs with only an owner password (permission restrictions like
|
||||
# print/copy disabled) use an empty user password — any viewer can open
|
||||
# them without prompting. decrypt("") returns 0 only when a real user
|
||||
# password is required. See https://github.com/onyx-dot-app/onyx/issues/9754
|
||||
try:
|
||||
return reader.decrypt("") == 0
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to evaluate PDF encryption; treating as password protected"
|
||||
)
|
||||
return True
|
||||
return bool(reader.is_encrypted)
|
||||
|
||||
|
||||
def is_docx_protected(file: IO[Any]) -> bool:
|
||||
|
||||
@@ -26,7 +26,6 @@ class LlmProviderNames(str, Enum):
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
BIFROST = "bifrost"
|
||||
OPENAI_COMPATIBLE = "openai_compatible"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Needed so things like:
|
||||
@@ -47,7 +46,6 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
]
|
||||
|
||||
|
||||
@@ -66,7 +64,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
LlmProviderNames.BIFROST: "Bifrost",
|
||||
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI Compatible",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -119,7 +116,6 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -175,28 +175,6 @@ def _strip_tool_content_from_messages(
|
||||
return result
|
||||
|
||||
|
||||
def _fix_tool_user_message_ordering(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Insert a synthetic assistant message between tool and user messages.
|
||||
|
||||
Some models (e.g. Mistral on Azure) require strict message ordering where
|
||||
a user message cannot immediately follow a tool message. This function
|
||||
inserts a minimal assistant message to bridge the gap.
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
return messages
|
||||
|
||||
result: list[dict[str, Any]] = [messages[0]]
|
||||
for msg in messages[1:]:
|
||||
prev_role = result[-1].get("role")
|
||||
curr_role = msg.get("role")
|
||||
if prev_role == "tool" and curr_role == "user":
|
||||
result.append({"role": "assistant", "content": "Noted. Continuing."})
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
|
||||
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Check if any messages contain tool-related content blocks."""
|
||||
for msg in messages:
|
||||
@@ -207,21 +185,6 @@ def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool:
|
||||
"""Check if the prompt contains any assistant messages with tool_calls.
|
||||
|
||||
When Anthropic's extended thinking is enabled, the API requires every
|
||||
assistant message to start with a thinking block before any tool_use
|
||||
blocks. Since we don't preserve thinking_blocks (they carry
|
||||
cryptographic signatures that can't be reconstructed), we must skip
|
||||
the thinking param whenever history contains prior tool-calling turns.
|
||||
"""
|
||||
from onyx.llm.models import AssistantMessage
|
||||
|
||||
msgs = prompt if isinstance(prompt, list) else [prompt]
|
||||
return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs)
|
||||
|
||||
|
||||
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
@@ -327,19 +290,12 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
|
||||
|
||||
# Bifrost and OpenAI-compatible: OpenAI-compatible proxies that send
|
||||
# model names directly to the endpoint. We route through LiteLLM's
|
||||
# openai provider with the server's base URL, and ensure /v1 is appended.
|
||||
if model_provider in (
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
):
|
||||
# Bifrost: OpenAI-compatible proxy that expects model names in
|
||||
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
|
||||
# We route through LiteLLM's openai provider with the Bifrost base URL,
|
||||
# and ensure /v1 is appended.
|
||||
if model_provider == LlmProviderNames.BIFROST:
|
||||
self._custom_llm_provider = "openai"
|
||||
# LiteLLM's OpenAI client requires an api_key to be set.
|
||||
# Many OpenAI-compatible servers don't need auth, so supply a
|
||||
# placeholder to prevent LiteLLM from raising AuthenticationError.
|
||||
if not self._api_key:
|
||||
model_kwargs.setdefault("api_key", "not-needed")
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
@@ -456,20 +412,17 @@ class LitellmLLM(LLM):
|
||||
optional_kwargs: dict[str, Any] = {}
|
||||
|
||||
# Model name
|
||||
is_openai_compatible_proxy = self._model_provider in (
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
)
|
||||
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
|
||||
model_provider = (
|
||||
f"{self.config.model_provider}/responses"
|
||||
if is_openai_model # Uses litellm's completions -> responses bridge
|
||||
else self.config.model_provider
|
||||
)
|
||||
if is_openai_compatible_proxy:
|
||||
# OpenAI-compatible proxies (Bifrost, generic OpenAI-compatible
|
||||
# servers) expect model names sent directly to their endpoint.
|
||||
# We use custom_llm_provider="openai" so LiteLLM doesn't try
|
||||
# to route based on the provider prefix.
|
||||
if is_bifrost:
|
||||
# Bifrost expects model names in provider/model format
|
||||
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
|
||||
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
|
||||
# so LiteLLM doesn't try to route based on the provider prefix.
|
||||
model = self.config.deployment_name or self.config.model_name
|
||||
else:
|
||||
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
@@ -513,20 +466,7 @@ class LitellmLLM(LLM):
|
||||
reasoning_effort
|
||||
)
|
||||
|
||||
# Anthropic requires every assistant message with tool_use
|
||||
# blocks to start with a thinking block that carries a
|
||||
# cryptographic signature. We don't preserve those blocks
|
||||
# across turns, so skip thinking when the history already
|
||||
# contains tool-calling assistant messages. LiteLLM's
|
||||
# modify_params workaround doesn't cover all providers
|
||||
# (notably Bedrock).
|
||||
can_enable_thinking = (
|
||||
budget_tokens is not None
|
||||
and not _prompt_contains_tool_call_history(prompt)
|
||||
)
|
||||
|
||||
if can_enable_thinking:
|
||||
assert budget_tokens is not None # mypy
|
||||
if budget_tokens is not None:
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
@@ -560,10 +500,7 @@ class LitellmLLM(LLM):
|
||||
if structured_response_format:
|
||||
optional_kwargs["response_format"] = structured_response_format
|
||||
|
||||
if (
|
||||
not (is_claude_model or is_ollama or is_mistral)
|
||||
or is_openai_compatible_proxy
|
||||
):
|
||||
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
|
||||
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
|
||||
# However, this param breaks Anthropic and Mistral models,
|
||||
# so it must be conditionally included unless the request is
|
||||
@@ -611,18 +548,6 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
messages = _strip_tool_content_from_messages(messages)
|
||||
|
||||
# Some models (e.g. Mistral) reject a user message
|
||||
# immediately after a tool message. Insert a synthetic
|
||||
# assistant bridge message to satisfy the ordering
|
||||
# constraint. Check both the provider and the deployment/
|
||||
# model name to catch Mistral hosted on Azure.
|
||||
model_or_deployment = (
|
||||
self._deployment_name or self._model_version or ""
|
||||
).lower()
|
||||
is_mistral_model = is_mistral or "mistral" in model_or_deployment
|
||||
if is_mistral_model:
|
||||
messages = _fix_tool_user_message_ordering(messages)
|
||||
|
||||
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
|
||||
# reject requests where tool_choice is explicitly null.
|
||||
if tools and tool_choice is not None:
|
||||
|
||||
@@ -15,8 +15,6 @@ LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
|
||||
|
||||
BIFROST_PROVIDER_NAME = "bifrost"
|
||||
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME = "openai_compatible"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
|
||||
@@ -19,7 +19,6 @@ from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_COMPATIBLE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
|
||||
@@ -52,7 +51,6 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
|
||||
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: [], # Dynamic - fetched from OpenAI-compatible API
|
||||
}
|
||||
|
||||
|
||||
@@ -338,7 +336,6 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
|
||||
OPENROUTER_PROVIDER_NAME: "OpenRouter",
|
||||
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI Compatible",
|
||||
}
|
||||
|
||||
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:
|
||||
|
||||
@@ -6,7 +6,6 @@ from onyx.configs.app_configs import MCP_SERVER_ENABLED
|
||||
from onyx.configs.app_configs import MCP_SERVER_HOST
|
||||
from onyx.configs.app_configs import MCP_SERVER_PORT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -17,7 +16,6 @@ def main() -> None:
|
||||
logger.info("MCP server is disabled (MCP_SERVER_ENABLED=false)")
|
||||
return
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
|
||||
|
||||
from onyx.mcp_server.api import mcp_app
|
||||
|
||||
@@ -40,8 +40,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.versioned_apps.client import app as celery_app
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -52,9 +50,6 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILE_SIZE_BYTES
|
||||
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILES_PER_UPLOAD
|
||||
from onyx.server.features.build.configs import USER_LIBRARY_MAX_TOTAL_SIZE_BYTES
|
||||
@@ -132,49 +127,6 @@ class DeleteFileResponse(BaseModel):
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _looks_like_pdf(filename: str, content_type: str | None) -> bool:
|
||||
"""True if either the filename or the content-type indicates a PDF.
|
||||
|
||||
Client-supplied ``content_type`` can be spoofed (e.g. a PDF uploaded with
|
||||
``Content-Type: application/octet-stream``), so we also fall back to
|
||||
extension-based detection via ``mimetypes.guess_type`` on the filename.
|
||||
"""
|
||||
if content_type == "application/pdf":
|
||||
return True
|
||||
guessed, _ = mimetypes.guess_type(filename)
|
||||
return guessed == "application/pdf"
|
||||
|
||||
|
||||
def _check_pdf_image_caps(
|
||||
filename: str, content: bytes, content_type: str | None, batch_total: int
|
||||
) -> int:
|
||||
"""Enforce per-file and per-batch embedded-image caps for PDFs.
|
||||
|
||||
Returns the number of embedded images in this file (0 for non-PDFs) so
|
||||
callers can update their running batch total. Raises OnyxError(INVALID_INPUT)
|
||||
if either cap is exceeded.
|
||||
"""
|
||||
if not _looks_like_pdf(filename, content_type):
|
||||
return 0
|
||||
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
# Short-circuit at the larger cap so we get a useful count for both checks.
|
||||
count = count_pdf_embedded_images(BytesIO(content), max(file_cap, batch_cap))
|
||||
if count > file_cap:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"PDF '{filename}' contains too many embedded images "
|
||||
f"(more than {file_cap}). Try splitting the document into smaller files.",
|
||||
)
|
||||
if batch_total + count > batch_cap:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"Upload would exceed the {batch_cap}-image limit across all "
|
||||
f"files in this batch. Try uploading fewer image-heavy files at once.",
|
||||
)
|
||||
return count
|
||||
|
||||
|
||||
def _sanitize_path(path: str) -> str:
|
||||
"""Sanitize a file path, removing traversal attempts and normalizing.
|
||||
|
||||
@@ -403,7 +355,6 @@ async def upload_files(
|
||||
|
||||
uploaded_entries: list[LibraryEntryResponse] = []
|
||||
total_size = 0
|
||||
batch_image_total = 0
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Sanitize the base path
|
||||
@@ -423,14 +374,6 @@ async def upload_files(
|
||||
detail=f"File '{file.filename}' exceeds maximum size of {USER_LIBRARY_MAX_FILE_SIZE_BYTES // (1024 * 1024)}MB",
|
||||
)
|
||||
|
||||
# Reject PDFs with an unreasonable per-file or per-batch image count
|
||||
batch_image_total += _check_pdf_image_caps(
|
||||
filename=file.filename or "unnamed",
|
||||
content=content,
|
||||
content_type=file.content_type,
|
||||
batch_total=batch_image_total,
|
||||
)
|
||||
|
||||
# Validate cumulative storage (existing + this upload batch)
|
||||
total_size += file_size
|
||||
if existing_usage + total_size > USER_LIBRARY_MAX_TOTAL_SIZE_BYTES:
|
||||
@@ -529,7 +472,6 @@ async def upload_zip(
|
||||
|
||||
uploaded_entries: list[LibraryEntryResponse] = []
|
||||
total_size = 0
|
||||
batch_image_total = 0
|
||||
|
||||
# Extract zip contents into a subfolder named after the zip file
|
||||
zip_name = api_sanitize_filename(file.filename or "upload")
|
||||
@@ -568,36 +510,6 @@ async def upload_zip(
|
||||
logger.warning(f"Skipping '{zip_info.filename}' - exceeds max size")
|
||||
continue
|
||||
|
||||
# Skip PDFs that would trip the per-file or per-batch image
|
||||
# cap (would OOM the user-file-processing worker). Matches
|
||||
# /upload behavior but uses skip-and-warn to stay consistent
|
||||
# with the zip path's handling of oversized files.
|
||||
zip_file_name = zip_info.filename.split("/")[-1]
|
||||
zip_content_type, _ = mimetypes.guess_type(zip_file_name)
|
||||
if zip_content_type == "application/pdf":
|
||||
image_count = count_pdf_embedded_images(
|
||||
BytesIO(file_content),
|
||||
max(
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE,
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
|
||||
),
|
||||
)
|
||||
if image_count > MAX_EMBEDDED_IMAGES_PER_FILE:
|
||||
logger.warning(
|
||||
"Skipping '%s' - exceeds %d per-file embedded-image cap",
|
||||
zip_info.filename,
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE,
|
||||
)
|
||||
continue
|
||||
if batch_image_total + image_count > MAX_EMBEDDED_IMAGES_PER_UPLOAD:
|
||||
logger.warning(
|
||||
"Skipping '%s' - would exceed %d per-batch embedded-image cap",
|
||||
zip_info.filename,
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
|
||||
)
|
||||
continue
|
||||
batch_image_total += image_count
|
||||
|
||||
total_size += file_size
|
||||
|
||||
# Validate cumulative storage
|
||||
|
||||
@@ -44,12 +44,11 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
|
||||
"""Raise OnyxError if endpoint_url could be used for SSRF.
|
||||
|
||||
Delegates to validate_outbound_http_url with https_only=True.
|
||||
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
|
||||
"""
|
||||
try:
|
||||
validate_outbound_http_url(endpoint_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -142,11 +141,19 @@ def _validate_endpoint(
|
||||
)
|
||||
return HookValidateResponse(status=HookValidateStatus.passed)
|
||||
except httpx.TimeoutException as exc:
|
||||
# Any timeout (connect, read, or write) means the configured timeout_seconds
|
||||
# is too low for this endpoint. Report as timeout so the UI directs the user
|
||||
# to increase the timeout setting.
|
||||
# ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
|
||||
if isinstance(exc, httpx.ConnectTimeout):
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connect timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
logger.warning(
|
||||
"Hook endpoint validation: timeout for %s",
|
||||
"Hook endpoint validation: read/write timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
|
||||
@@ -10,12 +10,9 @@ from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
@@ -195,11 +192,6 @@ def categorize_uploaded_files(
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to get current tenant ID: {str(e)}")
|
||||
|
||||
# Running total of embedded images across PDFs in this batch. Once the
|
||||
# aggregate cap is reached, subsequent PDFs in the same upload are
|
||||
# rejected even if they'd individually fit under MAX_EMBEDDED_IMAGES_PER_FILE.
|
||||
batch_image_total = 0
|
||||
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
@@ -261,47 +253,6 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
# Reject PDFs with an unreasonable number of embedded images
|
||||
# (either per-file or accumulated across this upload batch).
|
||||
# A PDF with thousands of embedded images can OOM the
|
||||
# user-file-processing celery worker because every image is
|
||||
# decoded with PIL and then sent to the vision LLM.
|
||||
if extension == ".pdf":
|
||||
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
# Use the larger of the two caps as the short-circuit
|
||||
# threshold so we get a useful count for both checks.
|
||||
# count_pdf_embedded_images restores the stream position.
|
||||
count = count_pdf_embedded_images(
|
||||
upload.file, max(file_cap, batch_cap)
|
||||
)
|
||||
if count > file_cap:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=(
|
||||
f"PDF contains too many embedded images "
|
||||
f"(more than {file_cap}). Try splitting "
|
||||
f"the document into smaller files."
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
if batch_image_total + count > batch_cap:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=(
|
||||
f"Upload would exceed the "
|
||||
f"{batch_cap}-image limit across all "
|
||||
f"files in this batch. Try uploading "
|
||||
f"fewer image-heavy files at once."
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
batch_image_total += count
|
||||
|
||||
text_content = extract_file_text(
|
||||
file=upload.file,
|
||||
file_name=filename,
|
||||
|
||||
@@ -74,8 +74,6 @@ from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenAICompatibleFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenAICompatibleModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenRouterModelDetails
|
||||
from onyx.server.manage.llm.models import OpenRouterModelsRequest
|
||||
@@ -1526,7 +1524,6 @@ def get_bifrost_available_models(
|
||||
display_name=model_name,
|
||||
max_input_tokens=model.get("context_length"),
|
||||
supports_image_input=infer_vision_support(model_id),
|
||||
supports_reasoning=is_reasoning_model(model_id, model_name),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1577,95 +1574,3 @@ def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> d
|
||||
source_name="Bifrost",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/openai-compatible/available-models")
|
||||
def get_openai_compatible_server_available_models(
|
||||
request: OpenAICompatibleModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OpenAICompatibleFinalModelResponse]:
|
||||
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
|
||||
response_json = _get_openai_compatible_server_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your OpenAI-compatible endpoint",
|
||||
)
|
||||
|
||||
results: list[OpenAICompatibleFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_id = model.get("id", "")
|
||||
model_name = model.get("name", model_id)
|
||||
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
# Skip embedding models
|
||||
if is_embedding_model(model_id):
|
||||
continue
|
||||
|
||||
results.append(
|
||||
OpenAICompatibleFinalModelResponse(
|
||||
name=model_id,
|
||||
display_name=model_name,
|
||||
max_input_tokens=model.get("context_length"),
|
||||
supports_image_input=infer_vision_support(model_id),
|
||||
supports_reasoning=is_reasoning_model(model_id, model_name),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse OpenAI-compatible model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from OpenAI-compatible endpoint",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="OpenAI Compatible",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openai_compatible_server_response(
|
||||
api_base: str, api_key: str | None = None
|
||||
) -> dict:
|
||||
"""Perform GET to an OpenAI-compatible /v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
# Ensure we hit /v1/models
|
||||
if cleaned_api_base.endswith("/v1"):
|
||||
url = f"{cleaned_api_base}/models"
|
||||
else:
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
return _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="OpenAI Compatible",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
@@ -463,19 +463,3 @@ class BifrostFinalModelResponse(BaseModel):
|
||||
display_name: str # Human-readable name from Bifrost API
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
|
||||
# OpenAI Compatible dynamic models fetch
|
||||
class OpenAICompatibleModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
api_key: str | None = None
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class OpenAICompatibleFinalModelResponse(BaseModel):
|
||||
name: str # Model ID (e.g. "meta-llama/Llama-3-8B-Instruct")
|
||||
display_name: str # Human-readable name from API
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
@@ -26,7 +26,6 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -147,7 +147,6 @@ class UserInfo(BaseModel):
|
||||
is_anonymous_user: bool | None = None,
|
||||
tenant_info: TenantInfo | None = None,
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None,
|
||||
memories: list[MemoryItem] | None = None,
|
||||
) -> "UserInfo":
|
||||
return cls(
|
||||
id=str(user.id),
|
||||
@@ -192,7 +191,10 @@ class UserInfo(BaseModel):
|
||||
role=user.personal_role or "",
|
||||
use_memories=user.use_memories,
|
||||
enable_memory_tool=user.enable_memory_tool,
|
||||
memories=memories or [],
|
||||
memories=[
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in (user.memories or [])
|
||||
],
|
||||
user_preferences=user.user_preferences or "",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -57,7 +57,6 @@ from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.user_preferences import deactivate_user
|
||||
from onyx.db.user_preferences import get_all_user_assistant_specific_configs
|
||||
from onyx.db.user_preferences import get_latest_access_token_for_user
|
||||
from onyx.db.user_preferences import get_memories_for_user
|
||||
from onyx.db.user_preferences import update_assistant_preferences
|
||||
from onyx.db.user_preferences import update_user_assistant_visibility
|
||||
from onyx.db.user_preferences import update_user_auto_scroll
|
||||
@@ -824,11 +823,6 @@ def verify_user_logged_in(
|
||||
[],
|
||||
),
|
||||
)
|
||||
memories = [
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in get_memories_for_user(user.id, db_session)
|
||||
]
|
||||
|
||||
user_info = UserInfo.from_model(
|
||||
user,
|
||||
current_token_created_at=token_created_at,
|
||||
@@ -839,7 +833,6 @@ def verify_user_logged_in(
|
||||
new_tenant=new_tenant,
|
||||
invitation=tenant_invitation,
|
||||
),
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
return user_info
|
||||
@@ -937,8 +930,7 @@ def update_user_personalization_api(
|
||||
else user.enable_memory_tool
|
||||
)
|
||||
existing_memories = [
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in get_memories_for_user(user.id, db_session)
|
||||
MemoryItem(id=memory.id, content=memory.memory_text) for memory in user.memories
|
||||
]
|
||||
new_memories = (
|
||||
request.memories if request.memories is not None else existing_memories
|
||||
|
||||
@@ -12,6 +12,7 @@ stale, which is fine for monitoring dashboards.
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -103,23 +104,25 @@ class _CachedCollector(Collector):
|
||||
|
||||
|
||||
class QueueDepthCollector(_CachedCollector):
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape."""
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape.
|
||||
|
||||
Uses a Redis client factory (callable) rather than a stored client
|
||||
reference so the connection is always fresh from Celery's pool.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._celery_app: Any | None = None
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._celery_app is None:
|
||||
if self._get_redis is None:
|
||||
return []
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
redis_client = self._get_redis()
|
||||
|
||||
depth = GaugeMetricFamily(
|
||||
"onyx_queue_depth",
|
||||
@@ -401,19 +404,17 @@ class RedisHealthCollector(_CachedCollector):
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._celery_app: Any | None = None
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._celery_app is None:
|
||||
if self._get_redis is None:
|
||||
return []
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
redis_client = self._get_redis()
|
||||
|
||||
memory_used = GaugeMetricFamily(
|
||||
"onyx_redis_memory_used_bytes",
|
||||
|
||||
@@ -3,8 +3,12 @@
|
||||
Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
from redis import Redis
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
@@ -17,7 +21,7 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
# Module-level singletons — these are lightweight objects (no connections or DB
|
||||
# state) until configure() / set_celery_app() is called. Keeping them at
|
||||
# state) until configure() / set_redis_factory() is called. Keeping them at
|
||||
# module level ensures they survive the lifetime of the worker process and are
|
||||
# only registered with the Prometheus registry once.
|
||||
_queue_collector = QueueDepthCollector()
|
||||
@@ -28,15 +32,72 @@ _worker_health_collector = WorkerHealthCollector()
|
||||
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
|
||||
|
||||
|
||||
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
|
||||
"""Create a factory that returns a cached broker Redis client.
|
||||
|
||||
Reuses a single connection across scrapes to avoid leaking connections.
|
||||
Reconnects automatically if the cached connection becomes stale.
|
||||
"""
|
||||
_cached_client: list[Redis | None] = [None]
|
||||
# Keep a reference to the Kombu Connection so we can close it on
|
||||
# reconnect (the raw Redis client outlives the Kombu wrapper).
|
||||
_cached_kombu_conn: list[Any] = [None]
|
||||
|
||||
def _close_client(client: Redis) -> None:
|
||||
"""Best-effort close of a Redis client."""
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close stale Redis client", exc_info=True)
|
||||
|
||||
def _close_kombu_conn() -> None:
|
||||
"""Best-effort close of the cached Kombu Connection."""
|
||||
conn = _cached_kombu_conn[0]
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close Kombu connection", exc_info=True)
|
||||
_cached_kombu_conn[0] = None
|
||||
|
||||
def _get_broker_redis() -> Redis:
|
||||
client = _cached_client[0]
|
||||
if client is not None:
|
||||
try:
|
||||
client.ping()
|
||||
return client
|
||||
except Exception:
|
||||
logger.debug("Cached Redis client stale, reconnecting")
|
||||
_close_client(client)
|
||||
_cached_client[0] = None
|
||||
_close_kombu_conn()
|
||||
|
||||
# Get a fresh Redis client from the broker connection.
|
||||
# We hold this client long-term (cached above) rather than using a
|
||||
# context manager, because we need it to persist across scrapes.
|
||||
# The caching logic above ensures we only ever hold one connection,
|
||||
# and we close it explicitly on reconnect.
|
||||
conn = celery_app.broker_connection()
|
||||
# kombu's Channel exposes .client at runtime (the underlying Redis
|
||||
# client) but the type stubs don't declare it.
|
||||
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
|
||||
_cached_client[0] = new_client
|
||||
_cached_kombu_conn[0] = conn
|
||||
return new_client
|
||||
|
||||
return _get_broker_redis
|
||||
|
||||
|
||||
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
"""Register all indexing pipeline collectors with the default registry.
|
||||
|
||||
Args:
|
||||
celery_app: The Celery application instance. Used to obtain a
|
||||
celery_app: The Celery application instance. Used to obtain a fresh
|
||||
broker Redis client on each scrape for queue depth metrics.
|
||||
"""
|
||||
_queue_collector.set_celery_app(celery_app)
|
||||
_redis_health_collector.set_celery_app(celery_app)
|
||||
redis_factory = _make_broker_redis_factory(celery_app)
|
||||
_queue_collector.set_redis_factory(redis_factory)
|
||||
_redis_health_collector.set_redis_factory(redis_factory)
|
||||
|
||||
# Start the heartbeat monitor daemon thread — uses a single persistent
|
||||
# connection to receive worker-heartbeat events.
|
||||
|
||||
@@ -129,10 +129,6 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(_PATCH_QUEUE_DEPTH, return_value=0),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -88,22 +88,10 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
|
||||
Also patches ``celery_get_broker_client`` so the mock app doesn't need
|
||||
a real broker URL.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -90,17 +90,8 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -103,11 +103,6 @@ _EXPECTED_CONFLUENCE_GROUPS = [
|
||||
user_emails={"oauth@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="no yuhong allowed",
|
||||
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
|
||||
from onyx.db.federated import _reject_masked_credentials
|
||||
|
||||
|
||||
class TestRejectMaskedCredentials:
|
||||
"""Verify that masked credential values are never accepted for DB writes.
|
||||
|
||||
mask_string() has two output formats:
|
||||
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
|
||||
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
|
||||
_reject_masked_credentials must catch both.
|
||||
"""
|
||||
|
||||
def test_rejects_fully_masked_value(self) -> None:
|
||||
masked = MASK_CREDENTIAL_CHAR * 12 # "••••••••••••"
|
||||
with pytest.raises(ValueError, match="masked placeholder"):
|
||||
_reject_masked_credentials({"client_id": masked})
|
||||
|
||||
def test_rejects_long_string_masked_value(self) -> None:
|
||||
"""mask_string returns 'first4...last4' for long strings — the real
|
||||
format used for OAuth credentials like client_id and client_secret."""
|
||||
with pytest.raises(ValueError, match="masked placeholder"):
|
||||
_reject_masked_credentials({"client_id": "1234...7890"})
|
||||
|
||||
def test_rejects_when_any_field_is_masked(self) -> None:
|
||||
"""Even if client_id is real, a masked client_secret must be caught."""
|
||||
with pytest.raises(ValueError, match="client_secret"):
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "1234567890.1234567890",
|
||||
"client_secret": MASK_CREDENTIAL_CHAR * 12,
|
||||
}
|
||||
)
|
||||
|
||||
def test_accepts_real_credentials(self) -> None:
|
||||
# Should not raise
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "1234567890.1234567890",
|
||||
"client_secret": "test_client_secret_value",
|
||||
}
|
||||
)
|
||||
|
||||
def test_accepts_empty_dict(self) -> None:
|
||||
# Should not raise — empty credentials are handled elsewhere
|
||||
_reject_masked_credentials({})
|
||||
|
||||
def test_ignores_non_string_values(self) -> None:
|
||||
# Non-string values (None, bool, int) should pass through
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "real_value",
|
||||
"redirect_uri": None,
|
||||
"some_flag": True,
|
||||
}
|
||||
)
|
||||
@@ -1,87 +0,0 @@
|
||||
"""Tests for celery_get_broker_client singleton."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.celery import celery_redis
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton() -> Iterator[None]:
|
||||
"""Reset the module-level singleton between tests."""
|
||||
celery_redis._broker_client = None
|
||||
celery_redis._broker_url = None
|
||||
yield
|
||||
celery_redis._broker_client = None
|
||||
celery_redis._broker_url = None
|
||||
|
||||
|
||||
def _make_mock_app(broker_url: str = "redis://localhost:6379/15") -> MagicMock:
|
||||
app = MagicMock()
|
||||
app.conf.broker_url = broker_url
|
||||
return app
|
||||
|
||||
|
||||
class TestCeleryGetBrokerClient:
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_creates_client_on_first_call(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_redis_cls.from_url.return_value = mock_client
|
||||
|
||||
app = _make_mock_app()
|
||||
result = celery_redis.celery_get_broker_client(app)
|
||||
|
||||
assert result is mock_client
|
||||
call_args = mock_redis_cls.from_url.call_args
|
||||
assert call_args[0][0] == "redis://localhost:6379/15"
|
||||
assert call_args[1]["decode_responses"] is False
|
||||
assert call_args[1]["socket_keepalive"] is True
|
||||
assert call_args[1]["retry_on_timeout"] is True
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_reuses_cached_client(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_cls.from_url.return_value = mock_client
|
||||
|
||||
app = _make_mock_app()
|
||||
client1 = celery_redis.celery_get_broker_client(app)
|
||||
client2 = celery_redis.celery_get_broker_client(app)
|
||||
|
||||
assert client1 is client2
|
||||
# from_url called only once
|
||||
assert mock_redis_cls.from_url.call_count == 1
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_reconnects_on_ping_failure(self, mock_redis_cls: MagicMock) -> None:
|
||||
stale_client = MagicMock()
|
||||
stale_client.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
fresh_client = MagicMock()
|
||||
fresh_client.ping.return_value = True
|
||||
|
||||
mock_redis_cls.from_url.side_effect = [stale_client, fresh_client]
|
||||
|
||||
app = _make_mock_app()
|
||||
|
||||
# First call creates stale_client
|
||||
client1 = celery_redis.celery_get_broker_client(app)
|
||||
assert client1 is stale_client
|
||||
|
||||
# Second call: ping fails, creates fresh_client
|
||||
client2 = celery_redis.celery_get_broker_client(app)
|
||||
assert client2 is fresh_client
|
||||
assert mock_redis_cls.from_url.call_count == 2
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_uses_broker_url_from_app_config(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_redis_cls.from_url.return_value = MagicMock()
|
||||
|
||||
app = _make_mock_app("redis://custom-host:6380/3")
|
||||
celery_redis.celery_get_broker_client(app)
|
||||
|
||||
call_args = mock_redis_cls.from_url.call_args
|
||||
assert call_args[0][0] == "redis://custom-host:6380/3"
|
||||
@@ -1,147 +0,0 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.connectors.jira.connector import bulk_fetch_issues
|
||||
|
||||
|
||||
def _make_raw_issue(issue_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": issue_id,
|
||||
"key": f"TEST-{issue_id}",
|
||||
"fields": {"summary": f"Issue {issue_id}"},
|
||||
}
|
||||
|
||||
|
||||
def _mock_jira_client() -> MagicMock:
|
||||
mock = MagicMock(spec=JIRA)
|
||||
mock._options = {"server": "https://jira.example.com"}
|
||||
mock._session = MagicMock()
|
||||
mock._get_url = MagicMock(
|
||||
return_value="https://jira.example.com/rest/api/3/issue/bulkfetch"
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
def test_bulk_fetch_success() -> None:
|
||||
"""Happy path: all issues fetched in one request."""
|
||||
client = _mock_jira_client()
|
||||
raw = [_make_raw_issue("1"), _make_raw_issue("2"), _make_raw_issue("3")]
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": raw}
|
||||
client._session.post.return_value = resp
|
||||
|
||||
result = bulk_fetch_issues(client, ["1", "2", "3"])
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(r, Issue) for r in result)
|
||||
client._session.post.assert_called_once()
|
||||
|
||||
|
||||
def test_bulk_fetch_splits_on_json_error() -> None:
|
||||
"""When the full batch fails with JSONDecodeError, sub-batches succeed."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
call_count = 0
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if len(ids) > 2:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"Expecting ',' delimiter", "doc", 2294125
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
result = bulk_fetch_issues(client, ["1", "2", "3", "4"])
|
||||
assert len(result) == 4
|
||||
returned_ids = {r.raw["id"] for r in result}
|
||||
assert returned_ids == {"1", "2", "3", "4"}
|
||||
assert call_count > 1
|
||||
|
||||
|
||||
def test_bulk_fetch_raises_on_single_unfetchable_issue() -> None:
|
||||
"""A single issue that always fails JSON decode raises after splitting."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if "bad" in ids:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"Expecting ',' delimiter", "doc", 100
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "bad", "2"])
|
||||
|
||||
|
||||
def test_bulk_fetch_non_json_error_propagates() -> None:
|
||||
"""Non-JSONDecodeError exceptions still propagate."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = ValueError("something else broke")
|
||||
client._session.post.return_value = resp
|
||||
|
||||
try:
|
||||
bulk_fetch_issues(client, ["1"])
|
||||
assert False, "Expected ValueError to propagate"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def test_bulk_fetch_with_fields() -> None:
|
||||
"""Fields parameter is forwarded correctly."""
|
||||
client = _mock_jira_client()
|
||||
raw = [_make_raw_issue("1")]
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": raw}
|
||||
client._session.post.return_value = resp
|
||||
|
||||
bulk_fetch_issues(client, ["1"], fields="summary,description")
|
||||
|
||||
call_payload = client._session.post.call_args[1]["json"]
|
||||
assert call_payload["fields"] == ["summary", "description"]
|
||||
|
||||
|
||||
def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
|
||||
"""With a 6-issue batch where one is bad, recursion isolates it and raises."""
|
||||
client = _mock_jira_client()
|
||||
bad_id = "BAD"
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if bad_id in ids:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"truncated", "doc", 999
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])
|
||||
@@ -1,67 +0,0 @@
|
||||
"""Tests for _build_thread_text function."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.slack_search import _build_thread_text
|
||||
|
||||
|
||||
def _make_msg(user: str, text: str, ts: str) -> dict[str, str]:
|
||||
return {"user": user, "text": text, "ts": ts}
|
||||
|
||||
|
||||
class TestBuildThreadText:
|
||||
"""Verify _build_thread_text includes full thread replies up to cap."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_includes_all_replies(self, mock_profiles: MagicMock) -> None:
|
||||
"""All replies within cap are included in output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "parent msg", "1000.0"),
|
||||
_make_msg("U2", "reply 1", "1001.0"),
|
||||
_make_msg("U3", "reply 2", "1002.0"),
|
||||
_make_msg("U4", "reply 3", "1003.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "parent msg" in result
|
||||
assert "reply 1" in result
|
||||
assert "reply 2" in result
|
||||
assert "reply 3" in result
|
||||
assert "..." not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_non_thread_returns_parent_only(self, mock_profiles: MagicMock) -> None:
|
||||
"""Single message (no replies) returns just the parent text."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [_make_msg("U1", "just a message", "1000.0")]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "just a message" in result
|
||||
assert "Replies:" not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_parent_always_first(self, mock_profiles: MagicMock) -> None:
|
||||
"""Thread parent message is always the first line of output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "I am the parent", "1000.0"),
|
||||
_make_msg("U2", "I am a reply", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
parent_pos = result.index("I am the parent")
|
||||
reply_pos = result.index("I am a reply")
|
||||
assert parent_pos < reply_pos
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_user_profiles_resolved(self, mock_profiles: MagicMock) -> None:
|
||||
"""User IDs in thread text are replaced with display names."""
|
||||
mock_profiles.return_value = {"U1": "Alice", "U2": "Bob"}
|
||||
messages = [
|
||||
_make_msg("U1", "hello", "1000.0"),
|
||||
_make_msg("U2", "world", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "Alice" in result
|
||||
assert "Bob" in result
|
||||
assert "<@U1>" not in result
|
||||
assert "<@U2>" not in result
|
||||
@@ -1,108 +0,0 @@
|
||||
"""Tests for Slack URL parsing and direct thread fetch via URL override."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.slack_search import _fetch_thread_from_url
|
||||
from onyx.context.search.federated.slack_search_utils import extract_slack_message_urls
|
||||
|
||||
|
||||
class TestExtractSlackMessageUrls:
|
||||
"""Verify URL parsing extracts channel_id and timestamp correctly."""
|
||||
|
||||
def test_standard_url(self) -> None:
|
||||
query = "summarize https://mycompany.slack.com/archives/C097NBWMY8Y/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 1
|
||||
assert results[0] == ("C097NBWMY8Y", "1775491616.524769")
|
||||
|
||||
def test_multiple_urls(self) -> None:
|
||||
query = (
|
||||
"compare https://co.slack.com/archives/C111/p1234567890123456 "
|
||||
"and https://co.slack.com/archives/C222/p9876543210987654"
|
||||
)
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 2
|
||||
assert results[0] == ("C111", "1234567890.123456")
|
||||
assert results[1] == ("C222", "9876543210.987654")
|
||||
|
||||
def test_no_urls(self) -> None:
|
||||
query = "what happened in #general last week?"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_non_slack_url_ignored(self) -> None:
|
||||
query = "check https://google.com/archives/C111/p1234567890123456"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_timestamp_conversion(self) -> None:
|
||||
"""p prefix removed, dot inserted after 10th digit."""
|
||||
query = "https://x.slack.com/archives/CABC123/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
channel_id, ts = results[0]
|
||||
assert channel_id == "CABC123"
|
||||
assert ts == "1775491616.524769"
|
||||
assert not ts.startswith("p")
|
||||
assert "." in ts
|
||||
|
||||
|
||||
class TestFetchThreadFromUrl:
|
||||
"""Verify _fetch_thread_from_url calls conversations.replies and returns SlackMessage."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search._build_thread_text")
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_successful_fetch(
|
||||
self, mock_webclient_cls: MagicMock, mock_build_thread: MagicMock
|
||||
) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
|
||||
# Mock conversations_replies
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = [
|
||||
{"user": "U1", "text": "parent", "ts": "1775491616.524769"},
|
||||
{"user": "U2", "text": "reply 1", "ts": "1775491617.000000"},
|
||||
{"user": "U3", "text": "reply 2", "ts": "1775491618.000000"},
|
||||
]
|
||||
mock_client.conversations_replies.return_value = mock_response
|
||||
|
||||
# Mock channel info
|
||||
mock_ch_response = MagicMock()
|
||||
mock_ch_response.get.return_value = {"name": "general"}
|
||||
mock_client.conversations_info.return_value = mock_ch_response
|
||||
|
||||
mock_build_thread.return_value = (
|
||||
"U1: parent\n\nReplies:\n\nU2: reply 1\n\nU3: reply 2"
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(
|
||||
channel_id="C097NBWMY8Y", thread_ts="1775491616.524769"
|
||||
)
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
|
||||
assert len(result.messages) == 1
|
||||
msg = result.messages[0]
|
||||
assert msg.channel_id == "C097NBWMY8Y"
|
||||
assert msg.thread_id is None # Prevents double-enrichment
|
||||
assert msg.slack_score == 100000.0
|
||||
assert "parent" in msg.text
|
||||
mock_client.conversations_replies.assert_called_once_with(
|
||||
channel="C097NBWMY8Y", ts="1775491616.524769"
|
||||
)
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_api_error_returns_empty(self, mock_webclient_cls: MagicMock) -> None:
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
mock_client.conversations_replies.side_effect = SlackApiError(
|
||||
message="channel_not_found",
|
||||
response=MagicMock(status_code=404),
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(channel_id="CBAD", thread_ts="1234567890.123456")
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
assert len(result.messages) == 0
|
||||
@@ -1,225 +0,0 @@
|
||||
"""Tests for get_chat_sessions_by_user filtering behavior.
|
||||
|
||||
Verifies that failed chat sessions (those with only SYSTEM messages) are
|
||||
correctly filtered out while preserving recently created sessions, matching
|
||||
the behavior specified in PR #7233.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
|
||||
def _make_session(
|
||||
user_id: UUID,
|
||||
time_created: datetime | None = None,
|
||||
time_updated: datetime | None = None,
|
||||
description: str = "",
|
||||
) -> MagicMock:
|
||||
"""Create a mock ChatSession with the given attributes."""
|
||||
session = MagicMock(spec=ChatSession)
|
||||
session.id = uuid4()
|
||||
session.user_id = user_id
|
||||
session.time_created = time_created or datetime.now(timezone.utc)
|
||||
session.time_updated = time_updated or session.time_created
|
||||
session.description = description
|
||||
session.deleted = False
|
||||
session.onyxbot_flow = False
|
||||
session.project_id = None
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_id() -> UUID:
|
||||
return uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def old_time() -> datetime:
|
||||
"""A timestamp well outside the 5-minute leeway window."""
|
||||
return datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recent_time() -> datetime:
|
||||
"""A timestamp within the 5-minute leeway window."""
|
||||
return datetime.now(timezone.utc) - timedelta(minutes=2)
|
||||
|
||||
|
||||
class TestGetChatSessionsByUser:
|
||||
"""Tests for the failed chat filtering logic in get_chat_sessions_by_user."""
|
||||
|
||||
def test_filters_out_failed_sessions(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""Sessions with only SYSTEM messages should be excluded."""
|
||||
valid_session = _make_session(user_id, time_created=old_time)
|
||||
failed_session = _make_session(user_id, time_created=old_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
# First execute: returns all sessions
|
||||
# Second execute: returns only the valid session's ID (has non-system msgs)
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [
|
||||
valid_session,
|
||||
failed_session,
|
||||
]
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = [valid_session.id]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == valid_session.id
|
||||
|
||||
def test_keeps_recent_sessions_without_messages(
|
||||
self, user_id: UUID, recent_time: datetime
|
||||
) -> None:
|
||||
"""Recently created sessions should be kept even without messages."""
|
||||
recent_session = _make_session(user_id, time_created=recent_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [recent_session]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == recent_session.id
|
||||
# Should only have been called once — no second query needed
|
||||
# because the recent session is within the leeway window
|
||||
assert db_session.execute.call_count == 1
|
||||
|
||||
def test_include_failed_chats_skips_filtering(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""When include_failed_chats=True, no filtering should occur."""
|
||||
session_a = _make_session(user_id, time_created=old_time)
|
||||
session_b = _make_session(user_id, time_created=old_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [session_a, session_b]
|
||||
|
||||
db_session.execute.side_effect = [mock_result]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=True,
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
# Only one DB call — no second query for message validation
|
||||
assert db_session.execute.call_count == 1
|
||||
|
||||
def test_limit_applied_after_filtering(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""Limit should be applied after filtering, not before."""
|
||||
sessions = [_make_session(user_id, time_created=old_time) for _ in range(5)]
|
||||
valid_ids = [s.id for s in sessions[:3]]
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = sessions
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = valid_ids
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
# Should be the first 2 valid sessions (order preserved)
|
||||
assert result[0].id == sessions[0].id
|
||||
assert result[1].id == sessions[1].id
|
||||
|
||||
def test_mixed_recent_and_old_sessions(
|
||||
self, user_id: UUID, old_time: datetime, recent_time: datetime
|
||||
) -> None:
|
||||
"""Mix of recent and old sessions should filter correctly."""
|
||||
old_valid = _make_session(user_id, time_created=old_time)
|
||||
old_failed = _make_session(user_id, time_created=old_time)
|
||||
recent_no_msgs = _make_session(user_id, time_created=recent_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [
|
||||
old_valid,
|
||||
old_failed,
|
||||
recent_no_msgs,
|
||||
]
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = [old_valid.id]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
result_ids = {cs.id for cs in result}
|
||||
assert old_valid.id in result_ids
|
||||
assert recent_no_msgs.id in result_ids
|
||||
assert old_failed.id not in result_ids
|
||||
|
||||
def test_empty_result(self, user_id: UUID) -> None:
|
||||
"""No sessions should return empty list without errors."""
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
db_session.execute.side_effect = [mock_result]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert db_session.execute.call_count == 1
|
||||
@@ -272,13 +272,13 @@ class TestUpsertVoiceProvider:
|
||||
class TestDeleteVoiceProvider:
|
||||
"""Tests for delete_voice_provider."""
|
||||
|
||||
def test_hard_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
delete_voice_provider(mock_db_session, 1)
|
||||
|
||||
mock_db_session.delete.assert_called_once_with(provider)
|
||||
assert provider.deleted is True
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_does_nothing_when_provider_not_found(
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
%PDF-1.3
|
||||
%<25><><EFBFBD><EFBFBD>
|
||||
1 0 obj
|
||||
<<
|
||||
/Producer <1083d595b1>
|
||||
>>
|
||||
endobj
|
||||
2 0 obj
|
||||
<<
|
||||
/Type /Pages
|
||||
/Count 1
|
||||
/Kids [ 4 0 R ]
|
||||
>>
|
||||
endobj
|
||||
3 0 obj
|
||||
<<
|
||||
/Type /Catalog
|
||||
/Pages 2 0 R
|
||||
>>
|
||||
endobj
|
||||
4 0 obj
|
||||
<<
|
||||
/Type /Page
|
||||
/Resources <<
|
||||
/Font <<
|
||||
/F1 <<
|
||||
/Type /Font
|
||||
/Subtype /Type1
|
||||
/BaseFont /Helvetica
|
||||
>>
|
||||
>>
|
||||
>>
|
||||
/MediaBox [ 0.0 0.0 200 200 ]
|
||||
/Contents 5 0 R
|
||||
/Parent 2 0 R
|
||||
>>
|
||||
endobj
|
||||
5 0 obj
|
||||
<<
|
||||
/Length 42
|
||||
>>
|
||||
stream
|
||||
,N<><6~<7E>)<29><><EFBFBD><EFBFBD><EFBFBD>u<EFBFBD><0C><><EFBFBD>Zc'<27><>>8g<38><67><EFBFBD>n<EFBFBD><6E><EFBFBD><EFBFBD><EFBFBD>9"
|
||||
endstream
|
||||
endobj
|
||||
6 0 obj
|
||||
<<
|
||||
/V 2
|
||||
/R 3
|
||||
/Length 128
|
||||
/P 4294967292
|
||||
/Filter /Standard
|
||||
/O <6a340a292629053da84a6d8b19a5d505953b8b3fdac3d2d389fde0e354528d44>
|
||||
/U <d6f0dc91c7b9de264a8d708515468e6528bf4e5e4e758a4164004e56fffa0108>
|
||||
>>
|
||||
endobj
|
||||
xref
|
||||
0 7
|
||||
0000000000 65535 f
|
||||
0000000015 00000 n
|
||||
0000000059 00000 n
|
||||
0000000118 00000 n
|
||||
0000000167 00000 n
|
||||
0000000348 00000 n
|
||||
0000000440 00000 n
|
||||
trailer
|
||||
<<
|
||||
/Size 7
|
||||
/Root 3 0 R
|
||||
/Info 1 0 R
|
||||
/ID [ <6364336635356135633239323638353039306635656133623165313637366430> <6364336635356135633239323638353039306635656133623165313637366430> ]
|
||||
/Encrypt 6 0 R
|
||||
>>
|
||||
startxref
|
||||
655
|
||||
%%EOF
|
||||
@@ -12,10 +12,6 @@ dependency on pypdf internals (pypdf.generic).
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.file_processing import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.file_processing.extract_file_text import pdf_to_text
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.password_validation import is_pdf_protected
|
||||
@@ -58,12 +54,6 @@ class TestReadPdfFile:
|
||||
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="wrong")
|
||||
assert text == ""
|
||||
|
||||
def test_owner_password_only_pdf_extracts_text(self) -> None:
|
||||
"""A PDF encrypted with only an owner password (no user password)
|
||||
should still yield its text content. Regression for #9754."""
|
||||
text, _, _ = read_pdf_file(_load("owner_protected.pdf"))
|
||||
assert "Hello World" in text
|
||||
|
||||
def test_empty_pdf(self) -> None:
|
||||
text, _, _ = read_pdf_file(_load("empty.pdf"))
|
||||
assert text.strip() == ""
|
||||
@@ -100,80 +90,6 @@ class TestReadPdfFile:
|
||||
# Returned list is empty when callback is used
|
||||
assert images == []
|
||||
|
||||
def test_image_cap_skips_images_above_limit(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""When the embedded-image cap is exceeded, remaining images are skipped.
|
||||
|
||||
The cap protects the user-file-processing worker from OOMing on PDFs
|
||||
with thousands of embedded images. Setting the cap to 0 should yield
|
||||
zero extracted images even though the fixture has one.
|
||||
"""
|
||||
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
|
||||
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
|
||||
assert images == []
|
||||
|
||||
def test_image_cap_at_limit_extracts_up_to_cap(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A cap >= image count behaves identically to the uncapped path."""
|
||||
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 100)
|
||||
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
|
||||
assert len(images) == 1
|
||||
|
||||
def test_image_cap_with_callback_stops_streaming_at_limit(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""The cap also short-circuits the streaming callback path."""
|
||||
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
|
||||
collected: list[tuple[bytes, str]] = []
|
||||
|
||||
def callback(data: bytes, name: str) -> None:
|
||||
collected.append((data, name))
|
||||
|
||||
read_pdf_file(
|
||||
_load("with_image.pdf"), extract_images=True, image_callback=callback
|
||||
)
|
||||
assert collected == []
|
||||
|
||||
|
||||
# ── count_pdf_embedded_images ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCountPdfEmbeddedImages:
|
||||
def test_returns_count_for_normal_pdf(self) -> None:
|
||||
assert count_pdf_embedded_images(_load("with_image.pdf"), cap=10) == 1
|
||||
|
||||
def test_short_circuits_above_cap(self) -> None:
|
||||
# with_image.pdf has 1 image. cap=0 means "anything > 0 is over cap" —
|
||||
# function returns on first increment as the over-cap sentinel.
|
||||
assert count_pdf_embedded_images(_load("with_image.pdf"), cap=0) == 1
|
||||
|
||||
def test_returns_zero_for_pdf_without_images(self) -> None:
|
||||
assert count_pdf_embedded_images(_load("simple.pdf"), cap=10) == 0
|
||||
|
||||
def test_returns_zero_for_invalid_pdf(self) -> None:
|
||||
assert count_pdf_embedded_images(BytesIO(b"not a pdf"), cap=10) == 0
|
||||
|
||||
def test_returns_zero_for_password_locked_pdf(self) -> None:
|
||||
# encrypted.pdf has an open password; we can't inspect without it, so
|
||||
# the helper returns 0 — callers rely on the password-protected check
|
||||
# that runs earlier in the upload pipeline.
|
||||
assert count_pdf_embedded_images(_load("encrypted.pdf"), cap=10) == 0
|
||||
|
||||
def test_inspects_owner_password_only_pdf(self) -> None:
|
||||
# owner_protected.pdf is encrypted but has no open password. It should
|
||||
# decrypt with an empty string and count images normally. The fixture
|
||||
# has zero images, so 0 is a real count (not the "bail on encrypted"
|
||||
# path).
|
||||
assert count_pdf_embedded_images(_load("owner_protected.pdf"), cap=10) == 0
|
||||
|
||||
def test_preserves_file_position(self) -> None:
|
||||
pdf = _load("with_image.pdf")
|
||||
pdf.seek(42)
|
||||
count_pdf_embedded_images(pdf, cap=10)
|
||||
assert pdf.tell() == 42
|
||||
|
||||
|
||||
# ── pdf_to_text ──────────────────────────────────────────────────────────
|
||||
|
||||
@@ -201,12 +117,6 @@ class TestIsPdfProtected:
|
||||
def test_protected_pdf(self) -> None:
|
||||
assert is_pdf_protected(_load("encrypted.pdf")) is True
|
||||
|
||||
def test_owner_password_only_is_not_protected(self) -> None:
|
||||
"""A PDF with only an owner password (permission restrictions) but no
|
||||
user password should NOT be considered protected — any viewer can open
|
||||
it without prompting for a password."""
|
||||
assert is_pdf_protected(_load("owner_protected.pdf")) is False
|
||||
|
||||
def test_preserves_file_position(self) -> None:
|
||||
pdf = _load("simple.pdf")
|
||||
pdf.seek(42)
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
import io
|
||||
|
||||
from pptx import Presentation # type: ignore[import-untyped]
|
||||
from pptx.chart.data import CategoryChartData # type: ignore[import-untyped]
|
||||
from pptx.enum.chart import XL_CHART_TYPE # type: ignore[import-untyped]
|
||||
from pptx.util import Inches # type: ignore[import-untyped]
|
||||
|
||||
from onyx.file_processing.extract_file_text import pptx_to_text
|
||||
|
||||
|
||||
def _make_pptx_with_chart() -> io.BytesIO:
|
||||
"""Create an in-memory pptx with one text slide and one chart slide."""
|
||||
prs = Presentation()
|
||||
|
||||
# Slide 1: text only
|
||||
slide1 = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
slide1.shapes.title.text = "Introduction"
|
||||
slide1.placeholders[1].text = "This is the first slide."
|
||||
|
||||
# Slide 2: chart
|
||||
slide2 = prs.slides.add_slide(prs.slide_layouts[5]) # Blank layout
|
||||
chart_data = CategoryChartData()
|
||||
chart_data.categories = ["Q1", "Q2", "Q3"]
|
||||
chart_data.add_series("Revenue", (100, 200, 300))
|
||||
slide2.shapes.add_chart(
|
||||
XL_CHART_TYPE.COLUMN_CLUSTERED,
|
||||
Inches(1),
|
||||
Inches(1),
|
||||
Inches(6),
|
||||
Inches(4),
|
||||
chart_data,
|
||||
)
|
||||
|
||||
buf = io.BytesIO()
|
||||
prs.save(buf)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
|
||||
def _make_pptx_without_chart() -> io.BytesIO:
|
||||
"""Create an in-memory pptx with a single text-only slide."""
|
||||
prs = Presentation()
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
slide.shapes.title.text = "Hello World"
|
||||
slide.placeholders[1].text = "Some content here."
|
||||
|
||||
buf = io.BytesIO()
|
||||
prs.save(buf)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
|
||||
class TestPptxToText:
|
||||
def test_chart_is_omitted(self) -> None:
|
||||
# Precondition
|
||||
pptx_file = _make_pptx_with_chart()
|
||||
|
||||
# Under test
|
||||
result = pptx_to_text(pptx_file)
|
||||
|
||||
# Postcondition
|
||||
assert "Introduction" in result
|
||||
assert "first slide" in result
|
||||
assert "[chart omitted]" in result
|
||||
# The actual chart data should NOT appear in the output.
|
||||
assert "Revenue" not in result
|
||||
assert "Q1" not in result
|
||||
|
||||
def test_text_only_pptx(self) -> None:
|
||||
# Precondition
|
||||
pptx_file = _make_pptx_without_chart()
|
||||
|
||||
# Under test
|
||||
result = pptx_to_text(pptx_file)
|
||||
|
||||
# Postcondition
|
||||
assert "Hello World" in result
|
||||
assert "Some content" in result
|
||||
assert "[chart omitted]" not in result
|
||||
@@ -11,7 +11,6 @@ from litellm.types.utils import ChatCompletionDeltaToolCall
|
||||
from litellm.types.utils import Delta
|
||||
from litellm.types.utils import Function as LiteLLMFunction
|
||||
|
||||
import onyx.llm.models
|
||||
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
@@ -1480,147 +1479,6 @@ def test_bifrost_normalizes_api_base_in_model_kwargs() -> None:
|
||||
assert llm._model_kwargs["api_base"] == "https://bifrost.example.com/v1"
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_true() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="tc_1",
|
||||
function=FunctionCall(name="get_weather", arguments="{}"),
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
assert _prompt_contains_tool_call_history(messages) is True
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_false_no_tools() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="Hello"),
|
||||
AssistantMessage(content="Hi there!"),
|
||||
]
|
||||
assert _prompt_contains_tool_call_history(messages) is False
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_false_user_only() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hello")]
|
||||
assert _prompt_contains_tool_call_history(messages) is False
|
||||
|
||||
|
||||
def test_bedrock_claude_drops_thinking_when_thinking_blocks_missing() -> None:
|
||||
"""When thinking is enabled but assistant messages with tool_calls lack
|
||||
thinking_blocks, the thinking param must be dropped to avoid the Bedrock
|
||||
BadRequestError about missing thinking blocks."""
|
||||
llm = LitellmLLM(
|
||||
api_key=None,
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BEDROCK,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="tc_1",
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Paris"}',
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
onyx.llm.models.ToolMessage(
|
||||
content="22°C sunny",
|
||||
tool_call_id="tc_1",
|
||||
),
|
||||
]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
|
||||
):
|
||||
mock_completion.return_value = []
|
||||
|
||||
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert "thinking" not in kwargs, (
|
||||
"thinking param should be dropped when thinking_blocks are missing "
|
||||
"from assistant messages with tool_calls"
|
||||
)
|
||||
|
||||
|
||||
def test_bedrock_claude_keeps_thinking_when_no_tool_history() -> None:
|
||||
"""When thinking is enabled and there are no historical assistant messages
|
||||
with tool_calls, the thinking param should be preserved."""
|
||||
llm = LitellmLLM(
|
||||
api_key=None,
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BEDROCK,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
|
||||
):
|
||||
mock_completion.return_value = []
|
||||
|
||||
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert "thinking" in kwargs, (
|
||||
"thinking param should be preserved when no assistant messages "
|
||||
"with tool_calls exist in history"
|
||||
)
|
||||
assert kwargs["thinking"]["type"] == "enabled"
|
||||
|
||||
|
||||
def test_bifrost_claude_includes_allowed_openai_params() -> None:
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Covers:
|
||||
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
|
||||
- _validate_endpoint: httpx exception → HookValidateStatus mapping
|
||||
ConnectTimeout → timeout (any timeout directs user to increase timeout_seconds)
|
||||
ConnectTimeout → cannot_connect (TCP handshake never completed)
|
||||
ConnectError → cannot_connect (DNS / TLS failure)
|
||||
ReadTimeout et al. → timeout (TCP connected, server slow)
|
||||
Any other exc → cannot_connect
|
||||
@@ -61,7 +61,7 @@ class TestCheckSsrfSafety:
|
||||
def test_non_https_scheme_rejected(self, url: str) -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self._call(url)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert "https" in (exc_info.value.detail or "").lower()
|
||||
|
||||
# --- private IP blocklist ---
|
||||
@@ -87,7 +87,7 @@ class TestCheckSsrfSafety:
|
||||
):
|
||||
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
|
||||
self._call("https://internal.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert ip in (exc_info.value.detail or "")
|
||||
|
||||
def test_public_ip_is_allowed(self) -> None:
|
||||
@@ -106,7 +106,7 @@ class TestCheckSsrfSafety:
|
||||
pytest.raises(OnyxError) as exc_info,
|
||||
):
|
||||
self._call("https://no-such-host.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -158,11 +158,13 @@ class TestValidateEndpoint:
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
|
||||
def test_connect_timeout_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectTimeout("timed out")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.timeout
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -10,9 +9,7 @@ from uuid import uuid4
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.server.scim.api import _check_seat_availability
|
||||
from ee.onyx.server.scim.api import _scim_name_to_str
|
||||
from ee.onyx.server.scim.api import _seat_lock_id_for_tenant
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
@@ -744,80 +741,3 @@ class TestEmailCasePreservation:
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
|
||||
class TestSeatLock:
|
||||
"""Tests for the advisory lock in _check_seat_availability."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_abc")
|
||||
def test_acquires_advisory_lock_before_checking(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The advisory lock must be acquired before the seat check runs."""
|
||||
call_order: list[str] = []
|
||||
|
||||
def track_execute(stmt: Any, _params: Any = None) -> None:
|
||||
if "pg_advisory_xact_lock" in str(stmt):
|
||||
call_order.append("lock")
|
||||
|
||||
mock_dal.session.execute.side_effect = track_execute
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop"
|
||||
) as mock_fetch:
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_fn = MagicMock(return_value=mock_result)
|
||||
mock_fetch.return_value = mock_fn
|
||||
|
||||
def track_check(*_args: Any, **_kwargs: Any) -> Any:
|
||||
call_order.append("check")
|
||||
return mock_result
|
||||
|
||||
mock_fn.side_effect = track_check
|
||||
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
assert call_order == ["lock", "check"]
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_xyz")
|
||||
def test_lock_uses_tenant_scoped_key(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The lock id must be derived from the tenant via _seat_lock_id_for_tenant."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_check = MagicMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=mock_check,
|
||||
):
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
mock_dal.session.execute.assert_called_once()
|
||||
params = mock_dal.session.execute.call_args[0][1]
|
||||
assert params["lock_id"] == _seat_lock_id_for_tenant("tenant_xyz")
|
||||
|
||||
def test_seat_lock_id_is_stable_and_tenant_scoped(self) -> None:
|
||||
"""Lock id must be deterministic and differ across tenants."""
|
||||
assert _seat_lock_id_for_tenant("t1") == _seat_lock_id_for_tenant("t1")
|
||||
assert _seat_lock_id_for_tenant("t1") != _seat_lock_id_for_tenant("t2")
|
||||
|
||||
def test_no_lock_when_ee_absent(
|
||||
self,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""No advisory lock should be acquired when the EE check is absent."""
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=None,
|
||||
):
|
||||
result = _check_seat_availability(mock_dal)
|
||||
|
||||
assert result is None
|
||||
mock_dal.session.execute.assert_not_called()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Tests for indexing pipeline Prometheus collectors."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -14,16 +13,6 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_broker_client() -> Iterator[None]:
|
||||
"""Patch celery_get_broker_client for all collector tests."""
|
||||
with patch(
|
||||
"onyx.background.celery.celery_redis.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestQueueDepthCollector:
|
||||
def test_returns_empty_when_factory_not_set(self) -> None:
|
||||
collector = QueueDepthCollector()
|
||||
@@ -35,7 +24,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_collects_queue_depths(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
collector.set_celery_app(MagicMock())
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -70,8 +60,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_handles_redis_error_gracefully(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
@@ -84,8 +74,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_caching_returns_stale_within_ttl(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=60)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -108,10 +98,31 @@ class TestQueueDepthCollector:
|
||||
|
||||
assert first is second # Same object, from cache
|
||||
|
||||
def test_factory_called_each_scrape(self) -> None:
|
||||
"""Verify the Redis factory is called on each fresh collect, not cached."""
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
factory = MagicMock(return_value=MagicMock())
|
||||
collector.set_redis_factory(factory)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=0,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
collector.collect()
|
||||
collector.collect()
|
||||
|
||||
assert factory.call_count == 2
|
||||
|
||||
def test_error_returns_stale_cache(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
# First call succeeds
|
||||
with (
|
||||
|
||||
@@ -1,22 +1,96 @@
|
||||
"""Tests for indexing pipeline setup."""
|
||||
"""Tests for indexing pipeline setup (Redis factory caching)."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
|
||||
|
||||
|
||||
class TestCollectorCeleryAppSetup:
|
||||
def test_queue_depth_collector_uses_celery_app(self) -> None:
|
||||
"""QueueDepthCollector.set_celery_app stores the app for broker access."""
|
||||
collector = QueueDepthCollector()
|
||||
mock_app = MagicMock()
|
||||
collector.set_celery_app(mock_app)
|
||||
assert collector._celery_app is mock_app
|
||||
def _make_mock_app(client: MagicMock) -> MagicMock:
|
||||
"""Create a mock Celery app whose broker_connection().channel().client
|
||||
returns the given client."""
|
||||
mock_app = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.channel.return_value.client = client
|
||||
|
||||
def test_redis_health_collector_uses_celery_app(self) -> None:
|
||||
"""RedisHealthCollector.set_celery_app stores the app for broker access."""
|
||||
collector = RedisHealthCollector()
|
||||
mock_app = MagicMock()
|
||||
collector.set_celery_app(mock_app)
|
||||
assert collector._celery_app is mock_app
|
||||
mock_app.broker_connection.return_value = mock_conn
|
||||
|
||||
return mock_app
|
||||
|
||||
|
||||
class TestMakeBrokerRedisFactory:
|
||||
def test_caches_redis_client_across_calls(self) -> None:
|
||||
"""Factory should reuse the same client on subsequent calls."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
client1 = factory()
|
||||
client2 = factory()
|
||||
|
||||
assert client1 is client2
|
||||
# broker_connection should only be called once
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
def test_reconnects_when_ping_fails(self) -> None:
|
||||
"""Factory should create a new client if ping fails (stale connection)."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
client1 = factory()
|
||||
assert client1 is mock_client_stale
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
# Switch to fresh client for next connection
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails on stale, reconnects
|
||||
client2 = factory()
|
||||
assert client2 is mock_client_fresh
|
||||
assert mock_app.broker_connection.call_count == 2
|
||||
|
||||
def test_reconnect_closes_stale_client(self) -> None:
|
||||
"""When ping fails, the old client should be closed before reconnecting."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
factory()
|
||||
|
||||
# Switch to fresh client
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails, should close stale client
|
||||
factory()
|
||||
mock_client_stale.close.assert_called_once()
|
||||
|
||||
def test_first_call_creates_connection(self) -> None:
|
||||
"""First call should always create a new connection."""
|
||||
mock_client = MagicMock()
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
client = factory()
|
||||
|
||||
assert client is mock_client
|
||||
mock_app.broker_connection.assert_called_once()
|
||||
|
||||
@@ -70,10 +70,6 @@ backend = [
|
||||
"lazy_imports==1.0.1",
|
||||
"lxml==5.3.0",
|
||||
"Mako==1.2.4",
|
||||
# NOTE: Do not update without understanding the patching behavior in
|
||||
# get_markitdown_converter in
|
||||
# backend/onyx/file_processing/extract_file_text.py and what impacts
|
||||
# updating might have on this behavior.
|
||||
"markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2",
|
||||
"mcp[cli]==1.26.0",
|
||||
"msal==1.34.0",
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import { cn } from "@opal/utils";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgBifrost = ({ size, className, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 37 46"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className={cn(className, "text-[#33C19E] dark:text-white")}
|
||||
{...props}
|
||||
>
|
||||
<title>Bifrost</title>
|
||||
<path
|
||||
d="M27.6219 46H0V36.8H27.6219V46ZM36.8268 36.8H27.6219V27.6H36.8268V36.8ZM18.4146 27.6H9.2073V18.4H18.4146V27.6ZM36.8268 18.4H27.6219V9.2H36.8268V18.4ZM27.6219 9.2H0V0H27.6219V9.2Z"
|
||||
fill="currentColor"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgBifrost;
|
||||
@@ -24,7 +24,6 @@ export { default as SvgAzure } from "@opal/icons/azure";
|
||||
export { default as SvgBarChart } from "@opal/icons/bar-chart";
|
||||
export { default as SvgBarChartSmall } from "@opal/icons/bar-chart-small";
|
||||
export { default as SvgBell } from "@opal/icons/bell";
|
||||
export { default as SvgBifrost } from "@opal/icons/bifrost";
|
||||
export { default as SvgBlocks } from "@opal/icons/blocks";
|
||||
export { default as SvgBookOpen } from "@opal/icons/book-open";
|
||||
export { default as SvgBookmark } from "@opal/icons/bookmark";
|
||||
|
||||
@@ -31,6 +31,7 @@ import { Credential } from "@/lib/connectors/credentials";
|
||||
import { SettingsContext } from "@/providers/SettingsProvider";
|
||||
import SourceTile from "@/components/SourceTile";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import { createApiKey, updateApiKey } from "./lib";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Text } from "@opal/components";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { FormikField } from "@/refresh-components/form/FormikField";
|
||||
@@ -79,7 +79,7 @@ export default function OnyxApiKeyForm({
|
||||
{({ isSubmitting }) => (
|
||||
<Form className="w-full overflow-visible">
|
||||
<Modal.Body>
|
||||
<Text as="p">
|
||||
<Text as="p" color="text-05">
|
||||
Choose a memorable name for your API key. This is optional and
|
||||
can be added or changed later!
|
||||
</Text>
|
||||
|
||||
@@ -29,12 +29,10 @@ import {
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { Button } from "@opal/components";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import { SvgEdit, SvgInfo, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
|
||||
import { useBillingInformation } from "@/hooks/useBillingInformation";
|
||||
import { BillingStatus, hasActiveSubscription } from "@/lib/billing/interfaces";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTES.API_KEYS;
|
||||
@@ -47,11 +45,6 @@ function Main() {
|
||||
} = useSWR<APIKey[]>("/api/admin/api-key", errorHandlingFetcher);
|
||||
|
||||
const canCreateKeys = useCloudSubscription();
|
||||
const { data: billingData } = useBillingInformation();
|
||||
const isTrialing =
|
||||
billingData !== undefined &&
|
||||
hasActiveSubscription(billingData) &&
|
||||
billingData.status === BillingStatus.TRIALING;
|
||||
|
||||
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
|
||||
const [keyIsGenerating, setKeyIsGenerating] = useState(false);
|
||||
@@ -83,16 +76,6 @@ function Main() {
|
||||
|
||||
const introSection = (
|
||||
<div className="flex flex-col items-start gap-4">
|
||||
{isTrialing && (
|
||||
<Message
|
||||
static
|
||||
warning
|
||||
close={false}
|
||||
className="w-full"
|
||||
text="Upgrade to a paid plan to create API keys."
|
||||
description="Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
|
||||
/>
|
||||
)}
|
||||
<Text as="p">
|
||||
API Keys allow you to access Onyx APIs programmatically.
|
||||
{canCreateKeys
|
||||
@@ -103,9 +86,23 @@ function Main() {
|
||||
<CreateButton onClick={() => setShowCreateUpdateForm(true)}>
|
||||
Create API Key
|
||||
</CreateButton>
|
||||
) : isTrialing ? (
|
||||
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
|
||||
) : null}
|
||||
) : (
|
||||
<div className="flex flex-col gap-2 rounded-lg bg-background-tint-02 p-4">
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Text as="p" text04>
|
||||
Upgrade to a paid plan to create API keys.
|
||||
</Text>
|
||||
<Button
|
||||
variant="none"
|
||||
prominence="tertiary"
|
||||
size="2xs"
|
||||
icon={SvgInfo}
|
||||
tooltip="API keys enable programmatic access to Onyx for service accounts and integrations. Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
|
||||
/>
|
||||
</div>
|
||||
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import Card from "@/refresh-components/cards/Card";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import InfoBlock from "@/refresh-components/messages/InfoBlock";
|
||||
|
||||
@@ -5,6 +5,7 @@ import { Section } from "@/layouts/general-layouts";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
|
||||
@@ -4,6 +4,7 @@ import { useState } from "react";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import InputFile from "@/refresh-components/inputs/InputFile";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
@@ -22,6 +22,7 @@ import type { IconProps } from "@opal/types";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
|
||||
@@ -1,387 +0,0 @@
|
||||
/**
|
||||
* Tests for BillingPage handleBillingReturn retry logic.
|
||||
*
|
||||
* The retry logic retries claimLicense up to 3 times with 2s backoff
|
||||
* when returning from a Stripe checkout session. This prevents the user
|
||||
* from getting stranded when the Stripe webhook fires concurrently with
|
||||
* the browser redirect and the license isn't ready yet.
|
||||
*/
|
||||
import React from "react";
|
||||
import { render, screen, waitFor } from "@tests/setup/test-utils";
|
||||
import { act } from "@testing-library/react";
|
||||
|
||||
// ---- Stable mock objects (must be named with mock* prefix for jest hoisting) ----
|
||||
// useRouter and useSearchParams must return the SAME reference each call, otherwise
|
||||
// React's useEffect sees them as changed and re-runs the effect on every render.
|
||||
const mockRouter = {
|
||||
replace: jest.fn() as jest.Mock,
|
||||
refresh: jest.fn() as jest.Mock,
|
||||
};
|
||||
const mockSearchParams = {
|
||||
get: jest.fn() as jest.Mock,
|
||||
};
|
||||
const mockClaimLicense = jest.fn() as jest.Mock;
|
||||
const mockRefreshBilling = jest.fn() as jest.Mock;
|
||||
const mockRefreshLicense = jest.fn() as jest.Mock;
|
||||
|
||||
// ---- Mocks ----
|
||||
|
||||
jest.mock("next/navigation", () => ({
|
||||
useRouter: () => mockRouter,
|
||||
useSearchParams: () => mockSearchParams,
|
||||
}));
|
||||
|
||||
jest.mock("@/layouts/settings-layouts", () => ({
|
||||
Root: ({ children }: { children: React.ReactNode }) => (
|
||||
<div data-testid="settings-root">{children}</div>
|
||||
),
|
||||
Header: () => <div data-testid="settings-header" />,
|
||||
Body: ({ children }: { children: React.ReactNode }) => (
|
||||
<div data-testid="settings-body">{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@/layouts/general-layouts", () => ({
|
||||
Section: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@opal/icons", () => ({
|
||||
SvgArrowUpCircle: () => <svg />,
|
||||
SvgWallet: () => <svg />,
|
||||
}));
|
||||
|
||||
jest.mock("./PlansView", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="plans-view" />,
|
||||
}));
|
||||
jest.mock("./CheckoutView", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="checkout-view" />,
|
||||
}));
|
||||
jest.mock("./BillingDetailsView", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="billing-details-view" />,
|
||||
}));
|
||||
jest.mock("./LicenseActivationCard", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="license-activation-card" />,
|
||||
}));
|
||||
|
||||
jest.mock("@/refresh-components/messages/Message", () => ({
|
||||
__esModule: true,
|
||||
default: ({
|
||||
text,
|
||||
description,
|
||||
onClose,
|
||||
}: {
|
||||
text: string;
|
||||
description?: string;
|
||||
onClose?: () => void;
|
||||
}) => (
|
||||
<div data-testid="activating-banner">
|
||||
<span data-testid="activating-banner-text">{text}</span>
|
||||
{description && (
|
||||
<span data-testid="activating-banner-description">{description}</span>
|
||||
)}
|
||||
{onClose && (
|
||||
<button data-testid="activating-banner-close" onClick={onClose}>
|
||||
Close
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@/lib/billing", () => ({
|
||||
useBillingInformation: jest.fn(),
|
||||
useLicense: jest.fn(),
|
||||
hasActiveSubscription: jest.fn().mockReturnValue(false),
|
||||
claimLicense: (...args: unknown[]) => mockClaimLicense(...args),
|
||||
}));
|
||||
|
||||
jest.mock("@/lib/constants", () => ({
|
||||
NEXT_PUBLIC_CLOUD_ENABLED: false,
|
||||
}));
|
||||
|
||||
// ---- Import after mocks ----
|
||||
import BillingPage from "./page";
|
||||
import { useBillingInformation, useLicense } from "@/lib/billing";
|
||||
|
||||
// ---- Test helpers ----
|
||||
|
||||
function setupHooks() {
|
||||
(useBillingInformation as jest.Mock).mockReturnValue({
|
||||
data: null,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
refresh: mockRefreshBilling,
|
||||
});
|
||||
(useLicense as jest.Mock).mockReturnValue({
|
||||
data: null,
|
||||
isLoading: false,
|
||||
refresh: mockRefreshLicense,
|
||||
});
|
||||
}
|
||||
|
||||
// ---- Tests ----
|
||||
|
||||
describe("BillingPage — handleBillingReturn retry logic", () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
jest.useFakeTimers();
|
||||
setupHooks();
|
||||
// Default: no billing-return params
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
// Clear any activating state from prior tests
|
||||
sessionStorage.clear();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
test("calls claimLicense once and refreshes on first-attempt success", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_test_123" : null
|
||||
);
|
||||
mockClaimLicense.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
|
||||
expect(mockClaimLicense).toHaveBeenCalledWith("cs_test_123");
|
||||
});
|
||||
expect(mockRouter.refresh).toHaveBeenCalled();
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
// URL cleaned up after checkout return
|
||||
expect(mockRouter.replace).toHaveBeenCalledWith("/admin/billing", {
|
||||
scroll: false,
|
||||
});
|
||||
});
|
||||
|
||||
test("retries after first failure and succeeds on second attempt", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_retry_test" : null
|
||||
);
|
||||
mockClaimLicense
|
||||
.mockRejectedValueOnce(new Error("License not ready yet"))
|
||||
.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
// On eventual success, router and billing should be refreshed
|
||||
expect(mockRouter.refresh).toHaveBeenCalled();
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("retries all 3 times then navigates to details even on total failure", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_all_fail" : null
|
||||
);
|
||||
// All 3 attempts fail
|
||||
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
|
||||
|
||||
const consoleSpy = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
// User stays on plans view with the activating banner
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("plans-view")).toBeInTheDocument();
|
||||
});
|
||||
// refreshBilling still fires so billing state is up to date
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
// Failure is logged
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("Failed to sync license after billing return"),
|
||||
expect.any(Error)
|
||||
);
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("calls claimLicense without session_id on portal_return", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "portal_return" ? "true" : null
|
||||
);
|
||||
mockClaimLicense.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
|
||||
// No session_id for portal returns — called with undefined
|
||||
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
|
||||
});
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("does not call claimLicense when no billing-return params present", async () => {
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(mockClaimLicense).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("shows activating banner and sets sessionStorage on 3x retry failure", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_all_fail" : null
|
||||
);
|
||||
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
|
||||
|
||||
const consoleSpy = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
});
|
||||
expect(screen.getByTestId("activating-banner-text")).toHaveTextContent(
|
||||
"Your license is still activating"
|
||||
);
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).not.toBeNull();
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("banner not rendered when no activating state", async () => {
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("banner shown on mount when sessionStorage key is set and not expired", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
// Flush React effects — banner is visible from lazy state init, no timer advancement needed
|
||||
await act(async () => {});
|
||||
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("banner not shown on mount when sessionStorage key is expired", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() - 1000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
test("poll calls claimLicense after 15s and clears banner on success", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
// Poll attempt succeeds
|
||||
mockClaimLicense.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
// Flush effects — banner visible from lazy state init
|
||||
await act(async () => {});
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
|
||||
// Advance past one poll interval (15s)
|
||||
await act(async () => {
|
||||
await jest.advanceTimersByTimeAsync(15_000);
|
||||
});
|
||||
|
||||
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).toBeNull();
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
expect(mockRefreshLicense).toHaveBeenCalled();
|
||||
expect(mockRouter.refresh).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("close button removes banner and clears sessionStorage", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
// Flush effects — banner visible from lazy state init
|
||||
await act(async () => {});
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
|
||||
const closeButton = screen.getByTestId("activating-banner-close");
|
||||
await act(async () => {
|
||||
closeButton.click();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).toBeNull();
|
||||
});
|
||||
});
|
||||
@@ -5,6 +5,7 @@ import { useSearchParams, useRouter } from "next/navigation";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgArrowUpCircle, SvgWallet } from "@opal/icons";
|
||||
import type { IconProps } from "@opal/types";
|
||||
@@ -17,7 +18,6 @@ import {
|
||||
} from "@/lib/billing";
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
|
||||
import PlansView from "./PlansView";
|
||||
import CheckoutView from "./CheckoutView";
|
||||
@@ -25,9 +25,6 @@ import BillingDetailsView from "./BillingDetailsView";
|
||||
import LicenseActivationCard from "./LicenseActivationCard";
|
||||
import "./billing.css";
|
||||
|
||||
// sessionStorage key: value is a unix-ms expiry timestamp
|
||||
const BILLING_ACTIVATING_KEY = "billing_license_activating_until";
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Types
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -109,7 +106,6 @@ export default function BillingPage() {
|
||||
const [transitionType, setTransitionType] = useState<
|
||||
"expand" | "collapse" | "fade"
|
||||
>("fade");
|
||||
const [isActivating, setIsActivating] = useState<boolean>(false);
|
||||
|
||||
const {
|
||||
data: billingData,
|
||||
@@ -160,17 +156,6 @@ export default function BillingPage() {
|
||||
view,
|
||||
]);
|
||||
|
||||
// Read activating state from sessionStorage after mount (avoids SSR hydration mismatch)
|
||||
useEffect(() => {
|
||||
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
|
||||
if (!raw) return;
|
||||
if (Number(raw) > Date.now()) {
|
||||
setIsActivating(true);
|
||||
} else {
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Show license activation card when there's a Stripe error
|
||||
useEffect(() => {
|
||||
if (hasStripeError && !showLicenseActivationInput) {
|
||||
@@ -188,96 +173,24 @@ export default function BillingPage() {
|
||||
|
||||
router.replace("/admin/billing", { scroll: false });
|
||||
|
||||
let cancelled = false;
|
||||
|
||||
const handleBillingReturn = async () => {
|
||||
if (!NEXT_PUBLIC_CLOUD_ENABLED) {
|
||||
// Retry up to 3 times with 2s backoff. The license may not be available
|
||||
// immediately if the Stripe webhook hasn't finished processing yet
|
||||
// (redirect and webhook fire nearly simultaneously).
|
||||
let lastError: Error | null = null;
|
||||
for (let attempt = 0; attempt < 3; attempt++) {
|
||||
if (cancelled) return;
|
||||
try {
|
||||
// After checkout, exchange session_id for license; after portal, re-sync license
|
||||
await claimLicense(sessionId ?? undefined);
|
||||
if (cancelled) return;
|
||||
refreshLicense();
|
||||
// Refresh the page to update settings (including ee_features_enabled)
|
||||
router.refresh();
|
||||
// Navigate to billing details now that the license is active
|
||||
changeView("details");
|
||||
lastError = null;
|
||||
break;
|
||||
} catch (err) {
|
||||
lastError = err instanceof Error ? err : new Error("Unknown error");
|
||||
if (attempt < 2) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cancelled) return;
|
||||
if (lastError) {
|
||||
console.error(
|
||||
"Failed to sync license after billing return:",
|
||||
lastError
|
||||
);
|
||||
// Show an activating banner on the plans view and keep retrying in the background.
|
||||
sessionStorage.setItem(
|
||||
BILLING_ACTIVATING_KEY,
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
setIsActivating(true);
|
||||
changeView("plans");
|
||||
try {
|
||||
// After checkout, exchange session_id for license; after portal, re-sync license
|
||||
await claimLicense(sessionId ?? undefined);
|
||||
refreshLicense();
|
||||
// Refresh the page to update settings (including ee_features_enabled)
|
||||
router.refresh();
|
||||
// Navigate to billing details now that the license is active
|
||||
changeView("details");
|
||||
} catch (error) {
|
||||
console.error("Failed to sync license after billing return:", error);
|
||||
}
|
||||
}
|
||||
if (!cancelled) refreshBilling();
|
||||
refreshBilling();
|
||||
};
|
||||
handleBillingReturn();
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
// changeView intentionally omitted: it only calls stable state setters and the
|
||||
// effect runs at most once (when session_id/portal_return params are present).
|
||||
}, [searchParams, router, refreshBilling, refreshLicense]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
// Poll every 15s while activating, up to 2 minutes, to detect when the license arrives.
|
||||
useEffect(() => {
|
||||
if (!isActivating) return;
|
||||
|
||||
let requestInFlight = false;
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
if (requestInFlight) return;
|
||||
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
|
||||
if (!raw || Number(raw) <= Date.now()) {
|
||||
// Expired — stop immediately without waiting for React cleanup
|
||||
clearInterval(intervalId);
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
setIsActivating(false);
|
||||
return;
|
||||
}
|
||||
requestInFlight = true;
|
||||
try {
|
||||
await claimLicense(undefined);
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
setIsActivating(false);
|
||||
refreshLicense();
|
||||
refreshBilling();
|
||||
router.refresh();
|
||||
changeView("details");
|
||||
} catch (err) {
|
||||
// License not ready yet — keep polling. Log so unexpected failures
|
||||
// (network errors, 500s) are distinguishable from expected 404s.
|
||||
console.debug("License activation poll: will retry", err);
|
||||
} finally {
|
||||
requestInFlight = false;
|
||||
}
|
||||
}, 15_000);
|
||||
|
||||
return () => clearInterval(intervalId);
|
||||
}, [isActivating]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
}, [searchParams, router, refreshBilling, refreshLicense]);
|
||||
|
||||
const handleRefresh = async () => {
|
||||
await Promise.all([
|
||||
@@ -474,22 +387,6 @@ export default function BillingPage() {
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<div className="flex flex-col items-center gap-6">
|
||||
{isActivating && (
|
||||
<Message
|
||||
static
|
||||
warning
|
||||
large
|
||||
text="Your license is still activating"
|
||||
description="Your license is being processed. You'll be taken to billing details automatically once confirmed."
|
||||
icon
|
||||
close
|
||||
onClose={() => {
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
setIsActivating(false);
|
||||
}}
|
||||
className="w-full"
|
||||
/>
|
||||
)}
|
||||
{renderContent()}
|
||||
{renderFooter()}
|
||||
</div>
|
||||
|
||||
@@ -7,6 +7,7 @@ import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import useSWR from "swr";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { SvgLock } from "@opal/icons";
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useMemo, useEffect } from "react";
|
||||
import { useState, useMemo } from "react";
|
||||
import useSWR from "swr";
|
||||
import { Text } from "@opal/components";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
|
||||
import {
|
||||
@@ -17,17 +17,9 @@ import {
|
||||
ImageGenerationConfigView,
|
||||
setDefaultImageGenerationConfig,
|
||||
unsetDefaultImageGenerationConfig,
|
||||
deleteImageGenerationConfig,
|
||||
} from "@/lib/configuration/imageConfigurationService";
|
||||
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { markdown } from "@opal/utils";
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
export default function ImageGenerationContent() {
|
||||
const {
|
||||
@@ -55,11 +47,6 @@ export default function ImageGenerationContent() {
|
||||
);
|
||||
const [editConfig, setEditConfig] =
|
||||
useState<ImageGenerationConfigView | null>(null);
|
||||
const [disconnectProvider, setDisconnectProvider] =
|
||||
useState<ImageProvider | null>(null);
|
||||
const [replacementProviderId, setReplacementProviderId] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
|
||||
const connectedProviderIds = useMemo(() => {
|
||||
return new Set(configs.map((c) => c.image_provider_id));
|
||||
@@ -128,29 +115,6 @@ export default function ImageGenerationContent() {
|
||||
modal.toggle(true);
|
||||
};
|
||||
|
||||
const handleDisconnect = async () => {
|
||||
if (!disconnectProvider) return;
|
||||
try {
|
||||
// If a replacement was selected (not "No Default"), activate it first
|
||||
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
|
||||
await setDefaultImageGenerationConfig(replacementProviderId);
|
||||
}
|
||||
|
||||
await deleteImageGenerationConfig(disconnectProvider.image_provider_id);
|
||||
toast.success(`${disconnectProvider.title} disconnected`);
|
||||
refetchConfigs();
|
||||
refetchProviders();
|
||||
} catch (error) {
|
||||
console.error("Failed to disconnect image generation provider:", error);
|
||||
toast.error(
|
||||
error instanceof Error ? error.message : "Failed to disconnect"
|
||||
);
|
||||
} finally {
|
||||
setDisconnectProvider(null);
|
||||
setReplacementProviderId(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleModalSuccess = () => {
|
||||
toast.success("Provider configured successfully");
|
||||
setEditConfig(null);
|
||||
@@ -166,44 +130,12 @@ export default function ImageGenerationContent() {
|
||||
);
|
||||
}
|
||||
|
||||
// Compute replacement options when disconnecting an active provider
|
||||
const isDisconnectingDefault =
|
||||
disconnectProvider &&
|
||||
defaultConfig?.image_provider_id === disconnectProvider.image_provider_id;
|
||||
|
||||
// Group connected replacement models by provider (excluding the model being disconnected)
|
||||
const replacementGroups = useMemo(() => {
|
||||
if (!disconnectProvider) return [];
|
||||
return IMAGE_PROVIDER_GROUPS.map((group) => ({
|
||||
...group,
|
||||
providers: group.providers.filter(
|
||||
(p) =>
|
||||
p.image_provider_id !== disconnectProvider.image_provider_id &&
|
||||
connectedProviderIds.has(p.image_provider_id)
|
||||
),
|
||||
})).filter((g) => g.providers.length > 0);
|
||||
}, [disconnectProvider, connectedProviderIds]);
|
||||
|
||||
const needsReplacement = !!isDisconnectingDefault;
|
||||
const hasReplacements = replacementGroups.length > 0;
|
||||
|
||||
// Auto-select first replacement when modal opens
|
||||
useEffect(() => {
|
||||
if (needsReplacement && !replacementProviderId && hasReplacements) {
|
||||
const firstGroup = replacementGroups[0];
|
||||
const firstModel = firstGroup?.providers[0];
|
||||
if (firstModel) setReplacementProviderId(firstModel.image_provider_id);
|
||||
}
|
||||
}, [disconnectProvider]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="flex flex-col gap-6">
|
||||
{/* Section Header */}
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text font="main-content-emphasis" color="text-05">
|
||||
Image Generation Model
|
||||
</Text>
|
||||
<Text font="main-content-emphasis">Image Generation Model</Text>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
Select a model to generate images in chat.
|
||||
</Text>
|
||||
@@ -241,11 +173,6 @@ export default function ImageGenerationContent() {
|
||||
onSelect={() => handleSelect(provider)}
|
||||
onDeselect={() => handleDeselect(provider)}
|
||||
onEdit={() => handleEdit(provider)}
|
||||
onDisconnect={
|
||||
getStatus(provider) !== "disconnected"
|
||||
? () => setDisconnectProvider(provider)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
@@ -253,108 +180,6 @@ export default function ImageGenerationContent() {
|
||||
))}
|
||||
</div>
|
||||
|
||||
{disconnectProvider && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgUnplug}
|
||||
title={`Disconnect ${disconnectProvider.title}`}
|
||||
description="This will remove the stored credentials for this provider."
|
||||
onClose={() => {
|
||||
setDisconnectProvider(null);
|
||||
setReplacementProviderId(null);
|
||||
}}
|
||||
submit={
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={() => void handleDisconnect()}
|
||||
disabled={
|
||||
needsReplacement && hasReplacements && !replacementProviderId
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectProvider.title}** is currently the default image generation model. Session history will be preserved.`
|
||||
)}
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" color="text-04">
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
value={replacementProviderId ?? undefined}
|
||||
onValueChange={(v) => setReplacementProviderId(v)}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a replacement model" />
|
||||
<InputSelect.Content>
|
||||
{replacementGroups.map((group) => (
|
||||
<InputSelect.Group key={group.name}>
|
||||
<InputSelect.Label>{group.name}</InputSelect.Label>
|
||||
{group.providers.map((p) => (
|
||||
<InputSelect.Item
|
||||
key={p.image_provider_id}
|
||||
value={p.image_provider_id}
|
||||
icon={() => (
|
||||
<ProviderIcon
|
||||
provider={p.provider_name}
|
||||
size={16}
|
||||
/>
|
||||
)}
|
||||
>
|
||||
{p.title}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
</InputSelect.Group>
|
||||
))}
|
||||
<InputSelect.Separator />
|
||||
<InputSelect.Item
|
||||
value={NO_DEFAULT_VALUE}
|
||||
icon={SvgSlash}
|
||||
>
|
||||
<span>
|
||||
<b>No Default</b>
|
||||
<span className="text-text-03">
|
||||
{" "}
|
||||
(Disable Image Generation)
|
||||
</span>
|
||||
</span>
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Section>
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectProvider.title}** is currently the default image generation model.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" color="text-03">
|
||||
Connect another provider to continue using image generation.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectProvider.title}** models will no longer be used to generate images.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" color="text-03">
|
||||
Session history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
</ConfirmationModalLayout>
|
||||
)}
|
||||
|
||||
{activeProvider && (
|
||||
<modal.Provider>
|
||||
<ImageGenerationConnectionModal
|
||||
|
||||
@@ -8,6 +8,7 @@ import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { SvgX } from "@opal/icons";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
function ModelConfigurationRow({
|
||||
|
||||
@@ -1 +1,7 @@
|
||||
export { default } from "@/refresh-pages/admin/LLMProviderConfigurationPage";
|
||||
"use client";
|
||||
|
||||
import LLMConfigurationPage from "@/refresh-pages/admin/LLMConfigurationPage";
|
||||
|
||||
export default function Page() {
|
||||
return <LLMConfigurationPage />;
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ import {
|
||||
BedrockModelResponse,
|
||||
LMStudioModelResponse,
|
||||
LiteLLMProxyModelResponse,
|
||||
BifrostModelResponse,
|
||||
ModelConfiguration,
|
||||
LLMProviderName,
|
||||
BedrockFetchParams,
|
||||
@@ -31,11 +30,8 @@ import {
|
||||
LMStudioFetchParams,
|
||||
OpenRouterFetchParams,
|
||||
LiteLLMProxyFetchParams,
|
||||
BifrostFetchParams,
|
||||
OpenAICompatibleFetchParams,
|
||||
OpenAICompatibleModelResponse,
|
||||
} from "@/interfaces/llm";
|
||||
import { SvgAws, SvgBifrost, SvgOpenrouter, SvgPlug } from "@opal/icons";
|
||||
import { SvgAws, SvgOpenrouter } from "@opal/icons";
|
||||
|
||||
// Aggregator providers that host models from multiple vendors
|
||||
export const AGGREGATOR_PROVIDERS = new Set([
|
||||
@@ -45,8 +41,6 @@ export const AGGREGATOR_PROVIDERS = new Set([
|
||||
"ollama_chat",
|
||||
"lm_studio",
|
||||
"litellm_proxy",
|
||||
"bifrost",
|
||||
"openai_compatible",
|
||||
"vertex_ai",
|
||||
]);
|
||||
|
||||
@@ -84,8 +78,6 @@ export const getProviderIcon = (
|
||||
bedrock_converse: SvgAws,
|
||||
openrouter: SvgOpenrouter,
|
||||
litellm_proxy: LiteLLMIcon,
|
||||
bifrost: SvgBifrost,
|
||||
openai_compatible: SvgPlug,
|
||||
vertex_ai: GeminiIcon,
|
||||
};
|
||||
|
||||
@@ -271,11 +263,8 @@ export const fetchOpenRouterModels = async (
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch (jsonError) {
|
||||
console.warn(
|
||||
"Failed to parse OpenRouter model fetch error response",
|
||||
jsonError
|
||||
);
|
||||
} catch {
|
||||
// ignore JSON parsing errors
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
@@ -325,125 +314,6 @@ export const fetchLMStudioModels = async (
|
||||
signal: params.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = "Failed to fetch models";
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch (jsonError) {
|
||||
console.warn(
|
||||
"Failed to parse LM Studio model fetch error response",
|
||||
jsonError
|
||||
);
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
|
||||
const data: LMStudioModelResponse[] = await response.json();
|
||||
const models: ModelConfiguration[] = data.map((modelData) => ({
|
||||
name: modelData.name,
|
||||
display_name: modelData.display_name,
|
||||
is_visible: true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: modelData.supports_reasoning,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : "Unknown error";
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches Bifrost models directly without any form state dependencies.
|
||||
* Uses snake_case params to match API structure.
|
||||
*/
|
||||
export const fetchBifrostModels = async (
|
||||
params: BifrostFetchParams
|
||||
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
|
||||
const apiBase = params.api_base;
|
||||
if (!apiBase) {
|
||||
return { models: [], error: "API Base is required" };
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch("/api/admin/llm/bifrost/available-models", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
api_base: apiBase,
|
||||
api_key: params.api_key,
|
||||
provider_name: params.provider_name,
|
||||
}),
|
||||
signal: params.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = "Failed to fetch models";
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch (jsonError) {
|
||||
console.warn(
|
||||
"Failed to parse Bifrost model fetch error response",
|
||||
jsonError
|
||||
);
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
|
||||
const data: BifrostModelResponse[] = await response.json();
|
||||
const models: ModelConfiguration[] = data.map((modelData) => ({
|
||||
name: modelData.name,
|
||||
display_name: modelData.display_name,
|
||||
is_visible: true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: modelData.supports_reasoning,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : "Unknown error";
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches models from a generic OpenAI-compatible server.
|
||||
* Uses snake_case params to match API structure.
|
||||
*/
|
||||
export const fetchOpenAICompatibleModels = async (
|
||||
params: OpenAICompatibleFetchParams
|
||||
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
|
||||
const apiBase = params.api_base;
|
||||
if (!apiBase) {
|
||||
return { models: [], error: "API Base is required" };
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
"/api/admin/llm/openai-compatible/available-models",
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
api_base: apiBase,
|
||||
api_key: params.api_key,
|
||||
provider_name: params.provider_name,
|
||||
}),
|
||||
signal: params.signal,
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = "Failed to fetch models";
|
||||
try {
|
||||
@@ -455,7 +325,7 @@ export const fetchOpenAICompatibleModels = async (
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
|
||||
const data: OpenAICompatibleModelResponse[] = await response.json();
|
||||
const data: LMStudioModelResponse[] = await response.json();
|
||||
const models: ModelConfiguration[] = data.map((modelData) => ({
|
||||
name: modelData.name,
|
||||
display_name: modelData.display_name,
|
||||
@@ -586,20 +456,6 @@ export const fetchModels = async (
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
case LLMProviderName.BIFROST:
|
||||
return fetchBifrostModels({
|
||||
api_base: formValues.api_base,
|
||||
api_key: formValues.api_key,
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return fetchOpenAICompatibleModels({
|
||||
api_base: formValues.api_base,
|
||||
api_key: formValues.api_key,
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
default:
|
||||
return { models: [], error: `Unknown provider: ${providerName}` };
|
||||
}
|
||||
@@ -613,8 +469,6 @@ export function canProviderFetchModels(providerName?: string) {
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
case LLMProviderName.OPENROUTER:
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
case LLMProviderName.BIFROST:
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -1,25 +1,33 @@
|
||||
"use client";
|
||||
|
||||
import Image from "next/image";
|
||||
import { useEffect, useMemo, useState, useReducer } from "react";
|
||||
import { useMemo, useState, useReducer } from "react";
|
||||
import { InfoIcon } from "@/components/icons/icons";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { Content } from "@opal/layouts";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { SvgGlobe, SvgOnyxLogo, SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgArrowRightCircle,
|
||||
SvgCheckSquare,
|
||||
SvgEdit,
|
||||
SvgGlobe,
|
||||
SvgOnyxLogo,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
|
||||
const route = ADMIN_ROUTES.WEB_SEARCH;
|
||||
import {
|
||||
SEARCH_PROVIDERS_URL,
|
||||
SEARCH_PROVIDER_DETAILS,
|
||||
@@ -51,10 +59,6 @@ import {
|
||||
} from "@/app/admin/configuration/web-search/WebProviderModalReducer";
|
||||
import { connectProviderFlow } from "@/app/admin/configuration/web-search/connectProviderFlow";
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
const route = ADMIN_ROUTES.WEB_SEARCH;
|
||||
|
||||
interface WebSearchProviderView {
|
||||
id: number;
|
||||
name: string;
|
||||
@@ -73,151 +77,27 @@ interface WebContentProviderView {
|
||||
has_api_key: boolean;
|
||||
}
|
||||
|
||||
interface DisconnectTargetState {
|
||||
id: number;
|
||||
label: string;
|
||||
category: "search" | "content";
|
||||
providerType: string;
|
||||
interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
|
||||
isHovered: boolean;
|
||||
onMouseEnter: () => void;
|
||||
onMouseLeave: () => void;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
function WebSearchDisconnectModal({
|
||||
disconnectTarget,
|
||||
searchProviders,
|
||||
contentProviders,
|
||||
replacementProviderId,
|
||||
onReplacementChange,
|
||||
onClose,
|
||||
onDisconnect,
|
||||
}: {
|
||||
disconnectTarget: DisconnectTargetState;
|
||||
searchProviders: WebSearchProviderView[];
|
||||
contentProviders: WebContentProviderView[];
|
||||
replacementProviderId: string | null;
|
||||
onReplacementChange: (id: string | null) => void;
|
||||
onClose: () => void;
|
||||
onDisconnect: () => void;
|
||||
}) {
|
||||
const isSearch = disconnectTarget.category === "search";
|
||||
|
||||
// Determine if the target is currently the active/selected provider
|
||||
const isActive = isSearch
|
||||
? searchProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
|
||||
false
|
||||
: contentProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
|
||||
false;
|
||||
|
||||
// Find other configured providers as replacements
|
||||
const replacementOptions = isSearch
|
||||
? searchProviders.filter(
|
||||
(p) => p.id !== disconnectTarget.id && p.id > 0 && p.has_api_key
|
||||
)
|
||||
: contentProviders.filter(
|
||||
(p) =>
|
||||
p.id !== disconnectTarget.id &&
|
||||
p.provider_type !== "onyx_web_crawler" &&
|
||||
p.id > 0 &&
|
||||
p.has_api_key
|
||||
);
|
||||
|
||||
const needsReplacement = isActive;
|
||||
const hasReplacements = replacementOptions.length > 0;
|
||||
|
||||
const getLabel = (p: { name: string; provider_type: string }) => {
|
||||
if (isSearch) {
|
||||
const details =
|
||||
SEARCH_PROVIDER_DETAILS[p.provider_type as WebSearchProviderType];
|
||||
return details?.label ?? p.name ?? p.provider_type;
|
||||
}
|
||||
const details = CONTENT_PROVIDER_DETAILS[p.provider_type];
|
||||
return details?.label ?? p.name ?? p.provider_type;
|
||||
};
|
||||
|
||||
const categoryLabel = isSearch ? "search engine" : "web crawler";
|
||||
const featureLabel = isSearch ? "web search" : "web crawling";
|
||||
const disableLabel = isSearch ? "Disable Web Search" : "Disable Web Crawling";
|
||||
|
||||
// Auto-select first replacement when modal opens
|
||||
useEffect(() => {
|
||||
if (needsReplacement && hasReplacements && !replacementProviderId) {
|
||||
const first = replacementOptions[0];
|
||||
if (first) onReplacementChange(String(first.id));
|
||||
}
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
function HoverIconButton({
|
||||
isHovered,
|
||||
onMouseEnter,
|
||||
onMouseLeave,
|
||||
children,
|
||||
...buttonProps
|
||||
}: HoverIconButtonProps) {
|
||||
return (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgUnplug}
|
||||
title={`Disconnect ${disconnectTarget.label}`}
|
||||
description="This will remove the stored credentials for this provider."
|
||||
onClose={onClose}
|
||||
submit={
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={onDisconnect}
|
||||
disabled={
|
||||
needsReplacement && hasReplacements && !replacementProviderId
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.label}</b> is currently the active{" "}
|
||||
{categoryLabel}. Search history will be preserved.
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" secondaryBody text03>
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
value={replacementProviderId ?? undefined}
|
||||
onValueChange={(v) => onReplacementChange(v)}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a replacement provider" />
|
||||
<InputSelect.Content>
|
||||
{replacementOptions.map((p) => (
|
||||
<InputSelect.Item key={p.id} value={String(p.id)}>
|
||||
{getLabel(p)}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
<InputSelect.Separator />
|
||||
<InputSelect.Item value={NO_DEFAULT_VALUE} icon={SvgSlash}>
|
||||
<span>
|
||||
<b>No Default</b>
|
||||
<span className="text-text-03"> ({disableLabel})</span>
|
||||
</span>
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Section>
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.label}</b> is currently the active{" "}
|
||||
{categoryLabel}.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Connect another provider to continue using {featureLabel}.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
{isSearch ? "Web search" : "Web crawling"} will no longer be routed
|
||||
through <b>{disconnectTarget.label}</b>.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Search history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
</ConfirmationModalLayout>
|
||||
<div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
{/* TODO(@raunakab): migrate to opal Button once HoverIconButtonProps typing is resolved */}
|
||||
<Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
|
||||
{children}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -226,11 +106,6 @@ export default function Page() {
|
||||
WebProviderModalReducer,
|
||||
initialWebProviderModalState
|
||||
);
|
||||
const [disconnectTarget, setDisconnectTarget] =
|
||||
useState<DisconnectTargetState | null>(null);
|
||||
const [replacementProviderId, setReplacementProviderId] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
const [contentModal, dispatchContentModal] = useReducer(
|
||||
WebProviderModalReducer,
|
||||
initialWebProviderModalState
|
||||
@@ -239,6 +114,8 @@ export default function Page() {
|
||||
const [contentActivationError, setContentActivationError] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);
|
||||
|
||||
const {
|
||||
data: searchProvidersData,
|
||||
error: searchProvidersError,
|
||||
@@ -957,67 +834,6 @@ export default function Page() {
|
||||
});
|
||||
};
|
||||
|
||||
const handleDisconnectProvider = async () => {
|
||||
if (!disconnectTarget) return;
|
||||
const { id, category } = disconnectTarget;
|
||||
|
||||
try {
|
||||
// If a replacement was selected (not "No Default"), activate it first
|
||||
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
|
||||
const repId = Number(replacementProviderId);
|
||||
const activateEndpoint =
|
||||
category === "search"
|
||||
? `/api/admin/web-search/search-providers/${repId}/activate`
|
||||
: `/api/admin/web-search/content-providers/${repId}/activate`;
|
||||
const activateResp = await fetch(activateEndpoint, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
if (!activateResp.ok) {
|
||||
const errorBody = await activateResp.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to activate replacement provider."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
`/api/admin/web-search/${category}-providers/${id}`,
|
||||
{ method: "DELETE" }
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.json().catch((parseErr) => {
|
||||
console.error("Failed to parse disconnect error response:", parseErr);
|
||||
return {};
|
||||
});
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to disconnect provider."
|
||||
);
|
||||
}
|
||||
|
||||
toast.success(`${disconnectTarget.label} disconnected`);
|
||||
await mutateSearchProviders();
|
||||
await mutateContentProviders();
|
||||
} catch (error) {
|
||||
console.error("Failed to disconnect web search provider:", error);
|
||||
const message =
|
||||
error instanceof Error ? error.message : "Unexpected error occurred.";
|
||||
if (category === "search") {
|
||||
setActivationError(message);
|
||||
} else {
|
||||
setContentActivationError(message);
|
||||
}
|
||||
} finally {
|
||||
setDisconnectTarget(null);
|
||||
setReplacementProviderId(null);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<SettingsLayouts.Root>
|
||||
@@ -1079,79 +895,149 @@ export default function Page() {
|
||||
provider
|
||||
);
|
||||
const isActive = provider?.is_active ?? false;
|
||||
const isHighlighted = isActive;
|
||||
const providerId = provider?.id;
|
||||
const canOpenModal =
|
||||
isBuiltInSearchProviderType(providerType);
|
||||
|
||||
const status: "disconnected" | "connected" | "selected" =
|
||||
!isConfigured
|
||||
? "disconnected"
|
||||
: isActive
|
||||
? "selected"
|
||||
: "connected";
|
||||
|
||||
return (
|
||||
<Select
|
||||
key={`${key}-${providerType}`}
|
||||
icon={() =>
|
||||
logoSrc ? (
|
||||
<Image
|
||||
src={logoSrc}
|
||||
alt={`${label} logo`}
|
||||
width={16}
|
||||
height={16}
|
||||
/>
|
||||
) : (
|
||||
<SvgGlobe size={16} />
|
||||
)
|
||||
}
|
||||
title={label}
|
||||
description={subtitle}
|
||||
status={status}
|
||||
onConnect={
|
||||
canOpenModal
|
||||
const buttonState = (() => {
|
||||
if (!provider || !isConfigured) {
|
||||
return {
|
||||
label: "Connect",
|
||||
disabled: false,
|
||||
icon: "arrow" as const,
|
||||
onClick: canOpenModal
|
||||
? () => {
|
||||
openSearchModal(providerType, provider);
|
||||
setActivationError(null);
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
onSelect={
|
||||
providerId
|
||||
? () => {
|
||||
void handleActivateSearchProvider(providerId);
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
onDeselect={
|
||||
providerId
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
if (isActive) {
|
||||
return {
|
||||
label: "Current Default",
|
||||
disabled: false,
|
||||
icon: "check" as const,
|
||||
onClick: providerId
|
||||
? () => {
|
||||
void handleDeactivateSearchProvider(providerId);
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
onEdit={
|
||||
isConfigured && canOpenModal
|
||||
? () => {
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
label: "Set as Default",
|
||||
disabled: false,
|
||||
icon: "arrow-circle" as const,
|
||||
onClick: providerId
|
||||
? () => {
|
||||
void handleActivateSearchProvider(providerId);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
})();
|
||||
|
||||
const buttonKey = `search-${key}-${providerType}`;
|
||||
const isButtonHovered = hoveredButtonKey === buttonKey;
|
||||
const isCardClickable =
|
||||
buttonState.icon === "arrow" &&
|
||||
typeof buttonState.onClick === "function" &&
|
||||
!buttonState.disabled;
|
||||
|
||||
const handleCardClick = () => {
|
||||
if (isCardClickable) {
|
||||
buttonState.onClick?.();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
key={`${key}-${providerType}`}
|
||||
onClick={isCardClickable ? handleCardClick : undefined}
|
||||
className={cn(
|
||||
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
|
||||
isHighlighted
|
||||
? "border-action-link-05"
|
||||
: "border-border-01",
|
||||
isCardClickable &&
|
||||
"cursor-pointer hover:bg-background-tint-01 transition-colors"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-1 items-start gap-1 px-2 py-1">
|
||||
{renderLogo({
|
||||
logoSrc,
|
||||
alt: `${label} logo`,
|
||||
size: 16,
|
||||
isHighlighted,
|
||||
})}
|
||||
<Content
|
||||
title={label}
|
||||
description={subtitle}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{isConfigured && (
|
||||
<OpalButton
|
||||
icon={SvgEdit}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
if (!canOpenModal) return;
|
||||
openSearchModal(
|
||||
providerType as WebSearchProviderType,
|
||||
provider
|
||||
);
|
||||
}}
|
||||
aria-label={`Edit ${label}`}
|
||||
/>
|
||||
)}
|
||||
{buttonState.icon === "check" ? (
|
||||
<HoverIconButton
|
||||
isHovered={isButtonHovered}
|
||||
onMouseEnter={() => setHoveredButtonKey(buttonKey)}
|
||||
onMouseLeave={() => setHoveredButtonKey(null)}
|
||||
action={true}
|
||||
tertiary
|
||||
disabled={buttonState.disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
>
|
||||
{buttonState.label}
|
||||
</HoverIconButton>
|
||||
) : (
|
||||
<Disabled
|
||||
disabled={
|
||||
buttonState.disabled || !buttonState.onClick
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
onDisconnect={
|
||||
isConfigured && provider && provider.id > 0
|
||||
? () =>
|
||||
setDisconnectTarget({
|
||||
id: provider.id,
|
||||
label,
|
||||
category: "search",
|
||||
providerType,
|
||||
})
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
>
|
||||
<OpalButton
|
||||
prominence="tertiary"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
rightIcon={
|
||||
buttonState.icon === "arrow"
|
||||
? SvgArrowExchange
|
||||
: buttonState.icon === "arrow-circle"
|
||||
? SvgArrowRightCircle
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{buttonState.label}
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
)}
|
||||
@@ -1191,81 +1077,161 @@ export default function Page() {
|
||||
const isCurrentCrawler =
|
||||
provider.provider_type === currentContentProviderType;
|
||||
|
||||
const status: "disconnected" | "connected" | "selected" =
|
||||
!isConfigured
|
||||
? "disconnected"
|
||||
: isCurrentCrawler
|
||||
? "selected"
|
||||
: "connected";
|
||||
const buttonState = (() => {
|
||||
if (!isConfigured) {
|
||||
return {
|
||||
label: "Connect",
|
||||
icon: "arrow" as const,
|
||||
disabled: false,
|
||||
onClick: () => {
|
||||
openContentModal(provider.provider_type, provider);
|
||||
setContentActivationError(null);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const canActivate =
|
||||
providerId > 0 ||
|
||||
provider.provider_type === "onyx_web_crawler" ||
|
||||
isConfigured;
|
||||
if (isCurrentCrawler) {
|
||||
return {
|
||||
label: "Current Crawler",
|
||||
icon: "check" as const,
|
||||
disabled: false,
|
||||
onClick: () => {
|
||||
void handleDeactivateContentProvider(
|
||||
providerId,
|
||||
provider.provider_type
|
||||
);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const contentLogoSrc =
|
||||
CONTENT_PROVIDER_DETAILS[provider.provider_type]?.logoSrc;
|
||||
const canActivate =
|
||||
providerId > 0 ||
|
||||
provider.provider_type === "onyx_web_crawler" ||
|
||||
isConfigured;
|
||||
|
||||
return {
|
||||
label: "Set as Default",
|
||||
icon: "arrow-circle" as const,
|
||||
disabled: !canActivate,
|
||||
onClick: canActivate
|
||||
? () => {
|
||||
void handleActivateContentProvider(provider);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
})();
|
||||
|
||||
const contentButtonKey = `content-${provider.provider_type}-${provider.id}`;
|
||||
const isContentButtonHovered =
|
||||
hoveredButtonKey === contentButtonKey;
|
||||
const isContentCardClickable =
|
||||
buttonState.icon === "arrow" &&
|
||||
typeof buttonState.onClick === "function" &&
|
||||
!buttonState.disabled;
|
||||
|
||||
const handleContentCardClick = () => {
|
||||
if (isContentCardClickable) {
|
||||
buttonState.onClick?.();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Select
|
||||
<div
|
||||
key={`${provider.provider_type}-${provider.id}`}
|
||||
icon={() =>
|
||||
contentLogoSrc ? (
|
||||
<Image
|
||||
src={contentLogoSrc}
|
||||
alt={`${label} logo`}
|
||||
width={16}
|
||||
height={16}
|
||||
/>
|
||||
) : provider.provider_type === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo size={16} />
|
||||
onClick={
|
||||
isContentCardClickable
|
||||
? handleContentCardClick
|
||||
: undefined
|
||||
}
|
||||
className={cn(
|
||||
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
|
||||
isCurrentCrawler
|
||||
? "border-action-link-05"
|
||||
: "border-border-01",
|
||||
isContentCardClickable &&
|
||||
"cursor-pointer hover:bg-background-tint-01 transition-colors"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-1 items-start gap-1 px-2 py-1">
|
||||
{renderLogo({
|
||||
logoSrc:
|
||||
CONTENT_PROVIDER_DETAILS[provider.provider_type]
|
||||
?.logoSrc,
|
||||
alt: `${label} logo`,
|
||||
fallback:
|
||||
provider.provider_type === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo size={16} />
|
||||
) : undefined,
|
||||
size: 16,
|
||||
isHighlighted: isCurrentCrawler,
|
||||
})}
|
||||
<Content
|
||||
title={label}
|
||||
description={subtitle}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{provider.provider_type !== "onyx_web_crawler" &&
|
||||
isConfigured && (
|
||||
<OpalButton
|
||||
icon={SvgEdit}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
openContentModal(
|
||||
provider.provider_type,
|
||||
provider
|
||||
);
|
||||
}}
|
||||
aria-label={`Edit ${label}`}
|
||||
/>
|
||||
)}
|
||||
{buttonState.icon === "check" ? (
|
||||
<HoverIconButton
|
||||
isHovered={isContentButtonHovered}
|
||||
onMouseEnter={() =>
|
||||
setHoveredButtonKey(contentButtonKey)
|
||||
}
|
||||
onMouseLeave={() => setHoveredButtonKey(null)}
|
||||
action={true}
|
||||
tertiary
|
||||
disabled={buttonState.disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
>
|
||||
{buttonState.label}
|
||||
</HoverIconButton>
|
||||
) : (
|
||||
<SvgGlobe size={16} />
|
||||
)
|
||||
}
|
||||
title={label}
|
||||
description={subtitle}
|
||||
status={status}
|
||||
selectedLabel="Current Crawler"
|
||||
onConnect={() => {
|
||||
openContentModal(provider.provider_type, provider);
|
||||
setContentActivationError(null);
|
||||
}}
|
||||
onSelect={
|
||||
canActivate
|
||||
? () => {
|
||||
void handleActivateContentProvider(provider);
|
||||
<Disabled
|
||||
disabled={
|
||||
buttonState.disabled || !buttonState.onClick
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
onDeselect={() => {
|
||||
void handleDeactivateContentProvider(
|
||||
providerId,
|
||||
provider.provider_type
|
||||
);
|
||||
}}
|
||||
onEdit={
|
||||
provider.provider_type !== "onyx_web_crawler" &&
|
||||
isConfigured
|
||||
? () => {
|
||||
openContentModal(provider.provider_type, provider);
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
onDisconnect={
|
||||
provider.provider_type !== "onyx_web_crawler" &&
|
||||
isConfigured &&
|
||||
provider.id > 0
|
||||
? () =>
|
||||
setDisconnectTarget({
|
||||
id: provider.id,
|
||||
label,
|
||||
category: "content",
|
||||
providerType: provider.provider_type,
|
||||
})
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
>
|
||||
<OpalButton
|
||||
prominence="tertiary"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
rightIcon={
|
||||
buttonState.icon === "arrow"
|
||||
? SvgArrowExchange
|
||||
: buttonState.icon === "arrow-circle"
|
||||
? SvgArrowRightCircle
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{buttonState.label}
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
@@ -1273,21 +1239,6 @@ export default function Page() {
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
|
||||
{disconnectTarget && (
|
||||
<WebSearchDisconnectModal
|
||||
disconnectTarget={disconnectTarget}
|
||||
searchProviders={searchProviders}
|
||||
contentProviders={combinedContentProviders}
|
||||
replacementProviderId={replacementProviderId}
|
||||
onReplacementChange={setReplacementProviderId}
|
||||
onClose={() => {
|
||||
setDisconnectTarget(null);
|
||||
setReplacementProviderId(null);
|
||||
}}
|
||||
onDisconnect={() => void handleDisconnectProvider()}
|
||||
/>
|
||||
)}
|
||||
|
||||
<WebProviderSetupModal
|
||||
isOpen={selectedProviderType !== null}
|
||||
onClose={() => {
|
||||
|
||||
@@ -4,6 +4,7 @@ import { useState } from "react";
|
||||
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button } from "@opal/components";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
import { IndexAttemptError } from "./types";
|
||||
import { localizeAndPrettify } from "@/lib/time";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Text } from "@opal/components";
|
||||
import { PageSelector } from "@/components/PageSelector";
|
||||
import { useCallback, useEffect, useRef, useState, useMemo } from "react";
|
||||
import { SvgAlertTriangle } from "@opal/icons";
|
||||
@@ -113,11 +113,11 @@ export default function IndexAttemptErrorsModal({
|
||||
<Modal.Body height="full">
|
||||
{!isResolvingErrors && (
|
||||
<div className="flex flex-col gap-2 flex-shrink-0">
|
||||
<Text as="p">
|
||||
<Text as="p" color="text-05">
|
||||
Below are the errors encountered during indexing. Each row
|
||||
represents a failed document or entity.
|
||||
</Text>
|
||||
<Text as="p">
|
||||
<Text as="p" color="text-05">
|
||||
Click the button below to kick off a full re-index to try and
|
||||
resolve these errors. This full re-index may take much longer
|
||||
than a normal update.
|
||||
|
||||
@@ -21,6 +21,7 @@ import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import {
|
||||
SvgCheck,
|
||||
|
||||
@@ -6,6 +6,7 @@ import { useState } from "react";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { triggerIndexing } from "@/app/admin/connector/[ccPairId]/lib";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { SvgRefreshCw } from "@opal/icons";
|
||||
|
||||
@@ -7,6 +7,7 @@ import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { CCPairStatus, PermissionSyncStatus } from "@/components/Status";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import CredentialSection from "@/components/credentials/CredentialSection";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import {
|
||||
updateConnectorCredentialPairName,
|
||||
|
||||
@@ -59,6 +59,7 @@ import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { deleteConnector } from "@/lib/connector";
|
||||
import ConnectorDocsLink from "@/components/admin/connectors/ConnectorDocsLink";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgKey, SvgAlertCircle } from "@opal/icons";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
|
||||
@@ -19,7 +19,6 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
|
||||
import { Credential } from "@/lib/connectors/credentials";
|
||||
import { useFederatedConnectors } from "@/lib/hooks";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { useToastFromQuery } from "@/hooks/useToast";
|
||||
|
||||
export default function ConnectorWrapper({
|
||||
|
||||
@@ -15,7 +15,7 @@ import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { Content } from "@opal/layouts";
|
||||
import CheckboxField from "@/refresh-components/form/LabeledCheckboxField";
|
||||
import InputTextAreaField from "@/refresh-components/form/InputTextAreaField";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Text } from "@opal/components";
|
||||
|
||||
// Define a general type for form values
|
||||
type FormValues = Record<string, any>;
|
||||
@@ -57,7 +57,7 @@ const TabsField: FC<TabsFieldProps> = ({
|
||||
|
||||
{/* Ensure there's at least one tab before rendering */}
|
||||
{tabField.tabs.length === 0 ? (
|
||||
<Text text03 secondaryBody>
|
||||
<Text color="text-03" font="secondary-body">
|
||||
No tabs to display.
|
||||
</Text>
|
||||
) : (
|
||||
@@ -253,7 +253,7 @@ export const RenderField: FC<RenderFieldProps> = ({
|
||||
)
|
||||
) : field.type === "string_tab" ? (
|
||||
<GeneralLayouts.Section>
|
||||
<Text text03 secondaryBody>
|
||||
<Text color="text-03" font="secondary-body">
|
||||
{description}
|
||||
</Text>
|
||||
</GeneralLayouts.Section>
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import { useState } from "react";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import { Button } from "@opal/components";
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
import Switch from "@/refresh-components/inputs/Switch";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import EmptyMessage from "@/refresh-components/EmptyMessage";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Text } from "@opal/components";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import {
|
||||
DiscordChannelConfig,
|
||||
@@ -95,9 +95,7 @@ export function DiscordChannelsTable({
|
||||
width="fit"
|
||||
>
|
||||
<ChannelIcon width={16} height={16} />
|
||||
<Text text04 mainUiBody>
|
||||
{channel.channel_name}
|
||||
</Text>
|
||||
<Text>{channel.channel_name}</Text>
|
||||
</Section>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
|
||||
@@ -8,11 +8,10 @@ import { toast } from "@/hooks/useToast";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { SvgServer } from "@opal/icons";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
@@ -121,7 +120,7 @@ function GuildDetailContent({
|
||||
/>
|
||||
|
||||
{!isRegistered ? (
|
||||
<Text text03 secondaryBody>
|
||||
<Text color="text-03" font="secondary-body">
|
||||
Channel configuration will be available after the server is
|
||||
registered.
|
||||
</Text>
|
||||
|
||||
@@ -6,7 +6,7 @@ import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Text } from "@opal/components";
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
@@ -76,7 +76,7 @@ function DiscordBotContent() {
|
||||
description="This key will only be shown once!"
|
||||
/>
|
||||
<Modal.Body>
|
||||
<Text text04 mainUiBody>
|
||||
<Text>
|
||||
Copy the command and send it from any text channel in your server!
|
||||
</Text>
|
||||
<Card variant="secondary">
|
||||
@@ -85,8 +85,8 @@ function DiscordBotContent() {
|
||||
justifyContent="between"
|
||||
alignItems="center"
|
||||
>
|
||||
<Text text03 secondaryMono>
|
||||
!register {registrationKey}
|
||||
<Text color="text-03" font="secondary-mono">
|
||||
{`!register ${registrationKey}`}
|
||||
</Text>
|
||||
<CopyIconButton
|
||||
getCopyText={() => `!register ${registrationKey}`}
|
||||
@@ -103,7 +103,7 @@ function DiscordBotContent() {
|
||||
justifyContent="between"
|
||||
alignItems="center"
|
||||
>
|
||||
<Text mainContentEmphasis text05>
|
||||
<Text font="main-content-emphasis" color="text-05">
|
||||
Server Configurations
|
||||
</Text>
|
||||
<CreateButton
|
||||
|
||||
@@ -9,7 +9,7 @@ const route = ADMIN_ROUTES.INDEX_MIGRATION;
|
||||
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import { Content, ContentAction } from "@opal/layouts";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Text } from "@opal/components";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
@@ -38,10 +38,10 @@ function MigrationStatusSection() {
|
||||
if (isLoading) {
|
||||
return (
|
||||
<Card>
|
||||
<Text headingH3>Migration Status</Text>
|
||||
<Text mainUiBody text03>
|
||||
Loading...
|
||||
<Text font="heading-h3" color="text-05">
|
||||
Migration Status
|
||||
</Text>
|
||||
<Text color="text-03">Loading...</Text>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
@@ -49,10 +49,10 @@ function MigrationStatusSection() {
|
||||
if (error) {
|
||||
return (
|
||||
<Card>
|
||||
<Text headingH3>Migration Status</Text>
|
||||
<Text mainUiBody text03>
|
||||
Failed to load migration status.
|
||||
<Text font="heading-h3" color="text-05">
|
||||
Migration Status
|
||||
</Text>
|
||||
<Text color="text-03">Failed to load migration status.</Text>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
@@ -73,14 +73,16 @@ function MigrationStatusSection() {
|
||||
|
||||
return (
|
||||
<Card>
|
||||
<Text headingH3>Migration Status</Text>
|
||||
<Text font="heading-h3" color="text-05">
|
||||
Migration Status
|
||||
</Text>
|
||||
|
||||
<ContentAction
|
||||
title="Started"
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
rightChildren={
|
||||
<Text mainUiBody>
|
||||
<Text color="text-05">
|
||||
{hasStarted ? formatTimestamp(data.created_at!) : "Not started"}
|
||||
</Text>
|
||||
}
|
||||
@@ -91,7 +93,7 @@ function MigrationStatusSection() {
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
rightChildren={
|
||||
<Text mainUiBody>
|
||||
<Text color="text-05">
|
||||
{progressPercentage !== null
|
||||
? `${totalChunksMigrated} (approx. progress ${Math.round(
|
||||
progressPercentage
|
||||
@@ -106,7 +108,7 @@ function MigrationStatusSection() {
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
rightChildren={
|
||||
<Text mainUiBody>
|
||||
<Text color="text-05">
|
||||
{hasCompleted
|
||||
? formatTimestamp(data.migration_completed_at!)
|
||||
: hasStarted
|
||||
@@ -159,10 +161,10 @@ function RetrievalSourceSection() {
|
||||
if (isLoading) {
|
||||
return (
|
||||
<Card>
|
||||
<Text headingH3>Retrieval Source</Text>
|
||||
<Text mainUiBody text03>
|
||||
Loading...
|
||||
<Text font="heading-h3" color="text-05">
|
||||
Retrieval Source
|
||||
</Text>
|
||||
<Text color="text-03">Loading...</Text>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
@@ -170,10 +172,10 @@ function RetrievalSourceSection() {
|
||||
if (error) {
|
||||
return (
|
||||
<Card>
|
||||
<Text headingH3>Retrieval Source</Text>
|
||||
<Text mainUiBody text03>
|
||||
Failed to load retrieval settings.
|
||||
<Text font="heading-h3" color="text-05">
|
||||
Retrieval Source
|
||||
</Text>
|
||||
<Text color="text-03">Failed to load retrieval settings.</Text>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import React, { useRef, useState } from "react";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button } from "@opal/components";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { Button } from "@opal/components";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgAlertTriangle } from "@opal/icons";
|
||||
export interface InstantSwitchConfirmModalProps {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import { Button } from "@opal/components";
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import React, { useRef, useState } from "react";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import { Button } from "@opal/components";
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { Button } from "@opal/components";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { CloudEmbeddingModel } from "@/components/embedding/interfaces";
|
||||
import { SvgServer } from "@opal/icons";
|
||||
|
||||
@@ -4,6 +4,7 @@ import { toast } from "@/hooks/useToast";
|
||||
|
||||
import EmbeddingModelSelection from "../EmbeddingModelSelectionForm";
|
||||
import { useCallback, useEffect, useMemo, useState, useRef } from "react";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
|
||||
@@ -8,6 +8,7 @@ import { ValidSources } from "@/lib/types";
|
||||
import { FaCircleQuestion } from "react-icons/fa6";
|
||||
import { CheckmarkIcon } from "@/components/icons/icons";
|
||||
import { Button } from "@opal/components";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
@@ -28,6 +28,7 @@ import Title from "@/components/ui/title";
|
||||
import { redirect } from "next/navigation";
|
||||
import { useIsKGExposed } from "@/app/admin/kg/utils";
|
||||
import KGEntityTypes from "@/app/admin/kg/KGEntityTypes";
|
||||
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { SvgSettings } from "@opal/icons";
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user