Compare commits

..

3 Commits

Author SHA1 Message Date
Raunak Bhagat
c079cd867a refactor: migrate 57 more refresh-components/Text files to Opal Text (batch 2) 2026-03-26 11:21:42 -07:00
Raunak Bhagat
f1c1473903 refactor: migrate 29 refresh-components/Text files to Opal Text (batch 1)
Convert boolean-flag Text API to string-enum props. Files with JSX
children, conditional boolean props, or className left with TODOs.
2026-03-26 11:06:29 -07:00
Raunak Bhagat
4a67cc0a09 chore: add migration TODOs to all 235 refresh-components/Text imports 2026-03-26 10:34:29 -07:00
338 changed files with 3879 additions and 8107 deletions

View File

@@ -1,35 +0,0 @@
"""remove voice_provider deleted column
Revision ID: 1d78c0ca7853
Revises: a3f8b2c1d4e5
Create Date: 2026-03-26 11:30:53.883127
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "1d78c0ca7853"
down_revision = "a3f8b2c1d4e5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Hard-delete any soft-deleted rows before dropping the column
op.execute("DELETE FROM voice_provider WHERE deleted = true")
op.drop_column("voice_provider", "deleted")
def downgrade() -> None:
op.add_column(
"voice_provider",
sa.Column(
"deleted",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)

View File

@@ -28,7 +28,6 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ElementExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
@@ -188,6 +187,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
@@ -227,7 +227,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_permission_sync_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)

View File

@@ -29,7 +29,6 @@ from ee.onyx.external_permissions.sync_params import (
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.error_logging import emit_background_error
@@ -163,6 +162,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
@@ -221,7 +221,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_external_group_sync_fences(
tenant_id, self.app, r, r_replica, r_celery, lock_beat
)

View File

@@ -11,8 +11,6 @@ require a valid SCIM bearer token.
from __future__ import annotations
import hashlib
import struct
from uuid import UUID
from fastapi import APIRouter
@@ -24,7 +22,6 @@ from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -62,25 +59,9 @@ from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# Group names reserved for system default groups (seeded by migration).
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
# Namespace prefix for the seat-allocation advisory lock. Hashed together
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
# never block each other) and cannot collide with unrelated advisory locks.
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
return struct.unpack("q", digest[:8])[0]
class ScimJSONResponse(JSONResponse):
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
@@ -219,37 +200,12 @@ def _apply_exclusions(
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None.
Acquires a transaction-scoped advisory lock so that concurrent
SCIM requests are serialized. IdPs like Okta send provisioning
requests in parallel batches — without serialization the check is
vulnerable to a TOCTOU race where N concurrent requests each see
"seats available", all insert, and the tenant ends up over its
seat limit.
The lock is held until the caller's next COMMIT or ROLLBACK, which
means the seat count cannot change between the check here and the
subsequent INSERT/UPDATE. Each call site in this module follows
the pattern: _check_seat_availability → write → dal.commit()
(which releases the lock for the next waiting request).
"""
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)
if check_fn is None:
return None
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
# The lock id is derived from the tenant so unrelated tenants never block
# each other, and from a namespace string so it cannot collide with
# unrelated advisory locks elsewhere in the codebase.
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
dal.session.execute(
text("SELECT pg_advisory_xact_lock(:lock_id)"),
{"lock_id": lock_id},
)
result = check_fn(dal.session, seats_needed=1)
if not result.available:
return result.error_message or "Seat limit reached"

View File

@@ -1,6 +1,5 @@
# These are helper objects for tracking the keys we need to write in redis
import json
import threading
from typing import Any
from typing import cast
@@ -8,59 +7,7 @@ from celery import Celery
from redis import Redis
from onyx.background.celery.configs.base import CELERY_SEPARATOR
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
_broker_client: Redis | None = None
_broker_url: str | None = None
_broker_client_lock = threading.Lock()
def celery_get_broker_client(app: Celery) -> Redis:
"""Return a shared Redis client connected to the Celery broker DB.
Uses a module-level singleton so all tasks on a worker share one
connection instead of creating a new one per call. The client
connects directly to the broker Redis DB (parsed from the broker URL).
Thread-safe via lock — safe for use in Celery thread-pool workers.
Usage:
r_celery = celery_get_broker_client(self.app)
length = celery_get_queue_length(queue, r_celery)
"""
global _broker_client, _broker_url
with _broker_client_lock:
url = app.conf.broker_url
if _broker_client is not None and _broker_url == url:
try:
_broker_client.ping()
return _broker_client
except Exception:
try:
_broker_client.close()
except Exception:
pass
_broker_client = None
elif _broker_client is not None:
try:
_broker_client.close()
except Exception:
pass
_broker_client = None
_broker_url = url
_broker_client = Redis.from_url(
url,
decode_responses=False,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
retry_on_timeout=True,
)
return _broker_client
def celery_get_unacked_length(r: Redis) -> int:

View File

@@ -14,7 +14,6 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
@@ -133,6 +132,7 @@ def revoke_tasks_blocking_deletion(
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -149,7 +149,6 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
# clear fences that don't have associated celery tasks in progress
try:
r_celery = celery_get_broker_client(self.app)
validate_connector_deletion_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)

View File

@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
@@ -450,7 +449,7 @@ def check_indexing_completion(
):
# Check if the task exists in the celery queue
# This handles the case where Redis dies after task creation but before task execution
redis_celery = celery_get_broker_client(task.app)
redis_celery = task.app.broker_connection().channel().client # type: ignore
task_exists = celery_find_task(
attempt.celery_task_id,
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,

View File

@@ -1,5 +1,6 @@
import json
import time
from collections.abc import Callable
from datetime import timedelta
from itertools import islice
from typing import Any
@@ -18,7 +19,6 @@ from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.memory_monitoring import emit_process_memory
@@ -698,27 +698,31 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
return None
try:
# Get Redis client for Celery broker
redis_celery = self.app.broker_connection().channel().client # type: ignore
redis_std = get_redis_client()
# Collect queue metrics with broker connection
r_celery = celery_get_broker_client(self.app)
queue_metrics = _collect_queue_metrics(r_celery)
# Define metric collection functions and their dependencies
metric_functions: list[Callable[[], list[Metric]]] = [
lambda: _collect_queue_metrics(redis_celery),
lambda: _collect_connector_metrics(db_session, redis_std),
lambda: _collect_sync_metrics(db_session, redis_std),
]
# Collect remaining metrics (no broker connection needed)
# Collect and log each metric
with get_session_with_current_tenant() as db_session:
all_metrics: list[Metric] = queue_metrics
all_metrics.extend(_collect_connector_metrics(db_session, redis_std))
all_metrics.extend(_collect_sync_metrics(db_session, redis_std))
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
# double check to make sure we aren't double-emitting metrics
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
for metric in all_metrics:
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
if metric.key is not None:
_mark_metric_as_emitted(redis_std, metric.key)
if metric.key is not None:
_mark_metric_as_emitted(redis_std, metric.key)
task_logger.info("Successfully collected background metrics")
except SoftTimeLimitExceeded:
@@ -886,7 +890,7 @@ def monitor_celery_queues_helper(
) -> None:
"""A task to monitor all celery queue lengths."""
r_celery = celery_get_broker_client(task.app)
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
@@ -1076,7 +1080,7 @@ def cloud_monitor_celery_pidbox(
num_deleted = 0
MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds
r_celery = celery_get_broker_client(self.app)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
for key in r_celery.scan_iter("*.reply.celery.pidbox"):
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")

View File

@@ -17,7 +17,6 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
@@ -204,6 +203,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -261,7 +261,6 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
except Exception:
task_logger.exception("Exception while validating pruning fences")

View File

@@ -16,7 +16,6 @@ from sqlalchemy.orm import Session
from onyx.access.access import build_access_for_user_files
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
@@ -106,7 +105,7 @@ def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery = celery_get_broker_client(celery_app)
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
return celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
)
@@ -239,7 +238,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = celery_get_broker_client(self.app)
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
@@ -592,7 +591,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
# --- Protection 1: queue depth backpressure ---
# NOTE: must use the broker's Redis client (not redis_client) because
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
r_celery = celery_get_broker_client(self.app)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
task_logger.warning(

View File

@@ -816,29 +816,6 @@ MAX_FILE_SIZE_BYTES = int(
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
) # 2GB in bytes
# Maximum embedded images allowed in a single file. PDFs (and other formats)
# with thousands of embedded images can OOM the user-file-processing worker
# because every image is decoded with PIL and then sent to the vision LLM.
# Enforced both at upload time (rejects the file) and during extraction
# (defense-in-depth: caps the number of images materialized).
#
# Clamped to >= 0; a negative env value would turn upload validation into
# always-fail and extraction into always-stop, which is never desired. 0
# disables image extraction entirely, which is a valid (if aggressive) setting.
MAX_EMBEDDED_IMAGES_PER_FILE = max(
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_FILE") or 500)
)
# Maximum embedded images allowed across all files in a single upload batch.
# Protects against the scenario where a user uploads many files that each
# fall under MAX_EMBEDDED_IMAGES_PER_FILE but aggregate to enough work
# (serial-ish celery fan-out plus per-image vision-LLM calls) to OOM the
# worker under concurrency or run up surprise latency/cost. Also clamped
# to >= 0.
MAX_EMBEDDED_IMAGES_PER_UPLOAD = max(
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_UPLOAD") or 1000)
)
# Use document summary for contextual rag
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
# Use chunk summary for contextual rag

View File

@@ -12,11 +12,6 @@ SLACK_USER_TOKEN_PREFIX = "xoxp-"
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
# The mask_string() function in encryption.py uses "•" (U+2022 BULLET) to mask secrets.
MASK_CREDENTIAL_CHAR = "\u2022"
# Pattern produced by mask_string for strings >= 14 chars: "abcd...wxyz" (exactly 11 chars)
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.{3}.{4}$")
SOURCE_TYPE = "source_type"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed

View File

@@ -10,7 +10,6 @@ from datetime import timedelta
from datetime import timezone
from typing import Any
import requests
from jira import JIRA
from jira.exceptions import JIRAError
from jira.resources import Issue
@@ -240,53 +239,29 @@ def enhanced_search_ids(
)
def _bulk_fetch_request(
jira_client: JIRA, issue_ids: list[str], fields: str | None
) -> list[dict[str, Any]]:
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO: move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
# Prepare the payload according to Jira API v3 specification
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
# Only restrict fields if specified, might want to explicitly do this in the future
# to avoid reading unnecessary data
payload["fields"] = fields.split(",") if fields else ["*all"]
resp = jira_client._session.post(bulk_fetch_path, json=payload)
return resp.json()["issues"]
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
try:
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
except requests.exceptions.JSONDecodeError:
if len(issue_ids) <= 1:
logger.exception(
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
f"be decoded as JSON (response too large or truncated)."
)
raise
mid = len(issue_ids) // 2
logger.warning(
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
)
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
return left + right
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise
raise e
return [
Issue(jira_client._options, jira_client._session, raw=issue)
for issue in raw_issues
for issue in response["issues"]
]

View File

@@ -1,4 +1,3 @@
from dataclasses import dataclass
from datetime import datetime
from typing import TypedDict
@@ -7,14 +6,6 @@ from pydantic import BaseModel
from onyx.onyxbot.slack.models import ChannelType
@dataclass(frozen=True)
class DirectThreadFetch:
"""Request to fetch a Slack thread directly by channel and timestamp."""
channel_id: str
thread_ts: str
class ChannelMetadata(TypedDict):
"""Type definition for cached channel metadata."""

View File

@@ -19,7 +19,6 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import TextSection
from onyx.context.search.federated.models import ChannelMetadata
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.federated.models import SlackMessage
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
@@ -50,6 +49,7 @@ from onyx.server.federated.models import FederatedConnectorDetail
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
logger = setup_logger()
@@ -58,6 +58,7 @@ HIGHLIGHT_END_CHAR = "\ue001"
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
@@ -420,94 +421,6 @@ class SlackQueryResult(BaseModel):
filtered_channels: list[str] # Channels filtered out during this query
def _fetch_thread_from_url(
thread_fetch: DirectThreadFetch,
access_token: str,
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
"""Fetch a thread directly from a Slack URL via conversations.replies."""
channel_id = thread_fetch.channel_id
thread_ts = thread_fetch.thread_ts
slack_client = WebClient(token=access_token)
try:
response = slack_client.conversations_replies(
channel=channel_id,
ts=thread_ts,
)
response.validate()
messages: list[dict[str, Any]] = response.get("messages", [])
except SlackApiError as e:
logger.warning(
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
)
return SlackQueryResult(messages=[], filtered_channels=[])
if not messages:
logger.warning(
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
)
return SlackQueryResult(messages=[], filtered_channels=[])
# Build thread text from all messages
thread_text = _build_thread_text(messages, access_token, None, slack_client)
# Get channel name from metadata cache or API
channel_name = "unknown"
if channel_metadata_dict and channel_id in channel_metadata_dict:
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
else:
try:
ch_response = slack_client.conversations_info(channel=channel_id)
ch_response.validate()
channel_info: dict[str, Any] = ch_response.get("channel", {})
channel_name = channel_info.get("name", "unknown")
except SlackApiError:
pass
# Build the SlackMessage
parent_msg = messages[0]
message_ts = parent_msg.get("ts", thread_ts)
username = parent_msg.get("user", "unknown_user")
parent_text = parent_msg.get("text", "")
snippet = (
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
).replace("\n", " ")
doc_time = datetime.fromtimestamp(float(message_ts))
decay_factor = DOC_TIME_DECAY
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
permalink = (
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
)
slack_message = SlackMessage(
document_id=f"{channel_id}_{message_ts}",
channel_id=channel_id,
message_id=message_ts,
thread_id=None, # Prevent double-enrichment in thread context fetch
link=permalink,
metadata={
"channel": channel_name,
"time": doc_time.isoformat(),
},
timestamp=doc_time,
recency_bias=recency_bias,
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
text=thread_text,
highlighted_texts=set(),
slack_score=100000.0, # High priority — user explicitly asked for this thread
)
logger.info(
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
)
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
def query_slack(
query_string: str,
access_token: str,
@@ -519,6 +432,7 @@ def query_slack(
available_channels: list[str] | None = None,
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
# Check if query has channel override (user specified channels in query)
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
@@ -748,6 +662,7 @@ def _fetch_thread_context(
"""
channel_id = message.channel_id
thread_id = message.thread_id
message_id = message.message_id
# If not a thread, return original text as success
if thread_id is None:
@@ -780,37 +695,62 @@ def _fetch_thread_context(
if len(messages) <= 1:
return ThreadContextResult.success(message.text)
# Build thread text from thread starter + all replies
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
# Build thread text from thread starter + context window around matched message
thread_text = _build_thread_text(
messages, message_id, thread_id, access_token, team_id, slack_client
)
return ThreadContextResult.success(thread_text)
def _build_thread_text(
messages: list[dict[str, Any]],
message_id: str,
thread_id: str,
access_token: str,
team_id: str | None,
slack_client: WebClient,
) -> str:
"""Build thread text including all replies.
Includes the thread parent message followed by all replies in order.
"""
"""Build the thread text from messages."""
msg_text = messages[0].get("text", "")
msg_sender = messages[0].get("user", "")
thread_text = f"<@{msg_sender}>: {msg_text}"
# All messages after index 0 are replies
replies = messages[1:]
if not replies:
return thread_text
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
thread_text += "\n\nReplies:"
if thread_id == message_id:
message_id_idx = 0
else:
message_id_idx = next(
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
)
if not message_id_idx:
return thread_text
for msg in replies:
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
if start_idx > 1:
thread_text += "\n..."
for i in range(start_idx, message_id_idx):
msg_text = messages[i].get("text", "")
msg_sender = messages[i].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
msg_text = messages[message_id_idx].get("text", "")
msg_sender = messages[message_id_idx].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# Add following replies
len_replies = 0
for msg in messages[message_id_idx + 1 :]:
msg_text = msg.get("text", "")
msg_sender = msg.get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
reply = f"\n\n<@{msg_sender}>: {msg_text}"
thread_text += reply
len_replies += len(reply)
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
thread_text += "\n..."
break
# Replace user IDs with names using cached lookups
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
@@ -1036,16 +976,7 @@ def slack_retrieval(
# Query slack with entity filtering
llm = get_default_llm()
query_items = build_slack_queries(query, llm, entities, available_channels)
# Partition into direct thread fetches and search query strings
direct_fetches: list[DirectThreadFetch] = []
query_strings: list[str] = []
for item in query_items:
if isinstance(item, DirectThreadFetch):
direct_fetches.append(item)
else:
query_strings.append(item)
query_strings = build_slack_queries(query, llm, entities, available_channels)
# Determine filtering based on entities OR context (bot)
include_dm = False
@@ -1062,16 +993,8 @@ def slack_retrieval(
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
)
# Build search tasks — direct thread fetches + keyword searches
search_tasks: list[tuple] = [
(
_fetch_thread_from_url,
(fetch, access_token, channel_metadata_dict),
)
for fetch in direct_fetches
]
search_tasks.extend(
# Build search tasks
search_tasks = [
(
query_slack,
(
@@ -1087,7 +1010,7 @@ def slack_retrieval(
),
)
for query_string in query_strings
)
]
# If include_dm is True AND we're not already searching all channels,
# add additional searches without channel filters.

View File

@@ -10,7 +10,6 @@ from pydantic import ValidationError
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
from onyx.context.search.federated.models import ChannelMetadata
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.models import ChunkIndexRequest
from onyx.federated_connectors.slack.models import SlackEntities
from onyx.llm.interfaces import LLM
@@ -639,38 +638,12 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
return [query_text]
SLACK_URL_PATTERN = re.compile(
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
)
def extract_slack_message_urls(
query_text: str,
) -> list[tuple[str, str]]:
"""Extract Slack message URLs from query text.
Parses URLs like:
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
Returns list of (channel_id, thread_ts) tuples.
The 16-digit timestamp is converted to Slack ts format (with dot).
"""
results = []
for match in SLACK_URL_PATTERN.finditer(query_text):
channel_id = match.group(1)
raw_ts = match.group(2)
# Convert p1775491616524769 -> 1775491616.524769
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
results.append((channel_id, thread_ts))
return results
def build_slack_queries(
query: ChunkIndexRequest,
llm: LLM,
entities: dict[str, Any] | None = None,
available_channels: list[str] | None = None,
) -> list[str | DirectThreadFetch]:
) -> list[str]:
"""Build Slack query strings with date filtering and query expansion."""
default_search_days = 30
if entities:
@@ -695,15 +668,6 @@ def build_slack_queries(
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
# Check for Slack message URLs — if found, add direct fetch requests
url_fetches: list[DirectThreadFetch] = []
slack_urls = extract_slack_message_urls(query.query)
for channel_id, thread_ts in slack_urls:
url_fetches.append(
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
)
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
# ALWAYS extract channel references from the query (not just for recency queries)
channel_references = extract_channel_references_from_query(query.query)
@@ -720,9 +684,7 @@ def build_slack_queries(
# If valid channels detected, use ONLY those channels with NO keywords
# Return query with ONLY time filter + channel filter (no keywords)
return url_fetches + [
build_channel_override_query(channel_references, time_filter)
]
return [build_channel_override_query(channel_references, time_filter)]
except ValueError as e:
# If validation fails, log the error and continue with normal flow
logger.warning(f"Channel reference validation failed: {e}")
@@ -740,8 +702,7 @@ def build_slack_queries(
rephrased_queries = expand_query_with_llm(query.query, llm)
# Build final query strings with time filters
search_queries = [
return [
rephrased_query.strip() + time_filter
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
]
return url_fetches + search_queries

View File

@@ -4,6 +4,7 @@ from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.api_key import ApiKeyDescriptor
@@ -54,6 +55,7 @@ async def fetch_user_for_api_key(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
.options(selectinload(User.memories))
)

View File

@@ -13,6 +13,7 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@@ -97,6 +98,11 @@ async def get_user_count(only_admin_users: bool = False) -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def _get_user(self, statement: Select) -> UP | None:
statement = statement.options(selectinload(User.memories))
results = await self.session.execute(statement)
return results.unique().scalar_one_or_none()
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -8,6 +8,7 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
@@ -131,47 +132,32 @@ def get_chat_sessions_by_user(
if before is not None:
stmt = stmt.where(ChatSession.time_updated < before)
if limit:
stmt = stmt.limit(limit)
if project_id is not None:
stmt = stmt.where(ChatSession.project_id == project_id)
elif only_non_project_chats:
stmt = stmt.where(ChatSession.project_id.is_(None))
# When filtering out failed chats, we apply the limit in Python after
# filtering rather than in SQL, since the post-filter may remove rows.
if limit and include_failed_chats:
stmt = stmt.limit(limit)
if not include_failed_chats:
non_system_message_exists_subq = (
exists()
.where(ChatMessage.chat_session_id == ChatSession.id)
.where(ChatMessage.message_type != MessageType.SYSTEM)
.correlate(ChatSession)
)
# Leeway for newly created chats that don't have messages yet
time = datetime.now(timezone.utc) - timedelta(minutes=5)
recently_created = ChatSession.time_created >= time
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
result = db_session.execute(stmt)
chat_sessions = list(result.scalars().all())
chat_sessions = result.scalars().all()
if not include_failed_chats and chat_sessions:
# Filter out "failed" sessions (those with only SYSTEM messages)
# using a separate efficient query instead of a correlated EXISTS
# subquery, which causes full sequential scans of chat_message.
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
if session_ids:
valid_session_ids_stmt = (
select(ChatMessage.chat_session_id)
.where(ChatMessage.chat_session_id.in_(session_ids))
.where(ChatMessage.message_type != MessageType.SYSTEM)
.distinct()
)
valid_session_ids = set(
db_session.execute(valid_session_ids_stmt).scalars().all()
)
chat_sessions = [
cs
for cs in chat_sessions
if cs.time_created >= leeway or cs.id in valid_session_ids
]
if limit:
chat_sessions = chat_sessions[:limit]
return chat_sessions
return list(chat_sessions)
def delete_orphaned_search_docs(db_session: Session) -> None:

View File

@@ -8,8 +8,6 @@ from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.constants import FederatedConnectorSource
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
from onyx.configs.constants import MASK_CREDENTIAL_LONG_RE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector
@@ -47,23 +45,6 @@ def fetch_all_federated_connectors_parallel() -> list[FederatedConnector]:
return fetch_all_federated_connectors(db_session)
def _reject_masked_credentials(credentials: dict[str, Any]) -> None:
"""Raise if any credential string value contains mask placeholder characters.
mask_string() has two output formats:
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
Both must be rejected.
"""
for key, val in credentials.items():
if isinstance(val, str) and (
MASK_CREDENTIAL_CHAR in val or MASK_CREDENTIAL_LONG_RE.match(val)
):
raise ValueError(
f"Credential field '{key}' contains masked placeholder characters. Please provide the actual credential value."
)
def validate_federated_connector_credentials(
source: FederatedConnectorSource,
credentials: dict[str, Any],
@@ -85,8 +66,6 @@ def create_federated_connector(
config: dict[str, Any] | None = None,
) -> FederatedConnector:
"""Create a new federated connector with credential and config validation."""
_reject_masked_credentials(credentials)
# Validate credentials before creating
if not validate_federated_connector_credentials(source, credentials):
raise ValueError(
@@ -298,8 +277,6 @@ def update_federated_connector(
)
if credentials is not None:
_reject_masked_credentials(credentials)
# Validate credentials before updating
if not validate_federated_connector_credentials(
federated_connector.source, credentials

View File

@@ -3135,6 +3135,8 @@ class VoiceProvider(Base):
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

View File

@@ -8,6 +8,7 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.pat import build_displayable_pat
@@ -46,6 +47,7 @@ async def fetch_user_for_pat(
(PersonalAccessToken.expires_at.is_(None))
| (PersonalAccessToken.expires_at > now)
)
.options(selectinload(User.memories))
)
if not user:
return None

View File

@@ -229,9 +229,7 @@ def get_memories_for_user(
user_id: UUID,
db_session: Session,
) -> Sequence[Memory]:
return db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.desc())
).all()
return db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all()
def update_user_pinned_assistants(

View File

@@ -17,30 +17,39 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
"""Fetch all voice providers."""
return list(
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
db_session.scalars(
select(VoiceProvider)
.where(VoiceProvider.deleted.is_(False))
.order_by(VoiceProvider.name)
).all()
)
def fetch_voice_provider_by_id(
db_session: Session, provider_id: int
db_session: Session, provider_id: int, include_deleted: bool = False
) -> VoiceProvider | None:
"""Fetch a voice provider by ID."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.id == provider_id)
)
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
if not include_deleted:
stmt = stmt.where(VoiceProvider.deleted.is_(False))
return db_session.scalar(stmt)
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default STT provider."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
select(VoiceProvider)
.where(VoiceProvider.is_default_stt.is_(True))
.where(VoiceProvider.deleted.is_(False))
)
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default TTS provider."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
select(VoiceProvider)
.where(VoiceProvider.is_default_tts.is_(True))
.where(VoiceProvider.deleted.is_(False))
)
@@ -49,7 +58,9 @@ def fetch_voice_provider_by_type(
) -> VoiceProvider | None:
"""Fetch a voice provider by type."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
select(VoiceProvider)
.where(VoiceProvider.provider_type == provider_type)
.where(VoiceProvider.deleted.is_(False))
)
@@ -108,10 +119,10 @@ def upsert_voice_provider(
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
"""Delete a voice provider by ID."""
"""Soft-delete a voice provider by ID."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider:
db_session.delete(provider)
provider.deleted = True
db_session.flush()

View File

@@ -21,7 +21,6 @@ import chardet
import openpyxl
from PIL import Image
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.constants import ONYX_METADATA_FILENAME
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.file_processing.file_types import OnyxFileExtensions
@@ -50,21 +49,9 @@ KNOWN_OPENPYXL_BUGS = [
def get_markitdown_converter() -> "MarkItDown":
global _MARKITDOWN_CONVERTER
from markitdown import MarkItDown
if _MARKITDOWN_CONVERTER is None:
from markitdown import MarkItDown
# Patch this function to effectively no-op because we were seeing this
# module take an inordinate amount of time to convert charts to markdown,
# making some powerpoint files with many or complicated charts nearly
# unindexable.
from markitdown.converters._pptx_converter import PptxConverter
setattr(
PptxConverter,
"_convert_chart_to_markdown",
lambda self, chart: "\n\n[chart omitted]\n\n", # noqa: ARG005
)
_MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
return _MARKITDOWN_CONVERTER
@@ -189,56 +176,6 @@ def read_text_file(
return file_content_raw, metadata
def count_pdf_embedded_images(file: IO[Any], cap: int) -> int:
"""Return the number of embedded images in a PDF, short-circuiting at cap+1.
Used to reject PDFs whose image count would OOM the user-file-processing
worker during indexing. Returns a value > cap as a sentinel once the count
exceeds the cap, so callers do not iterate thousands of image objects just
to report a number. Returns 0 if the PDF cannot be parsed.
Owner-password-only PDFs (permission restrictions but no open password) are
counted normally — they decrypt with an empty string. Truly password-locked
PDFs are skipped (return 0) since we can't inspect them; the caller should
ensure the password-protected check runs first.
Always restores the file pointer to its original position before returning.
"""
from pypdf import PdfReader
try:
start_pos = file.tell()
except Exception:
start_pos = None
try:
if start_pos is not None:
file.seek(0)
reader = PdfReader(file)
if reader.is_encrypted:
# Try empty password first (owner-password-only PDFs); give up if that fails.
try:
if reader.decrypt("") == 0:
return 0
except Exception:
return 0
count = 0
for page in reader.pages:
for _ in page.images:
count += 1
if count > cap:
return count
return count
except Exception:
logger.warning("Failed to count embedded images in PDF", exc_info=True)
return 0
finally:
if start_pos is not None:
try:
file.seek(start_pos)
except Exception:
pass
def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
"""
Extract text from a PDF. For embedded images, a more complex approach is needed.
@@ -265,26 +202,18 @@ def read_pdf_file(
try:
pdf_reader = PdfReader(file)
if pdf_reader.is_encrypted:
# Try the explicit password first, then fall back to an empty
# string. Owner-password-only PDFs (permission restrictions but
# no open password) decrypt successfully with "".
# See https://github.com/onyx-dot-app/onyx/issues/9754
passwords = [p for p in [pdf_pass, ""] if p is not None]
if pdf_reader.is_encrypted and pdf_pass is not None:
decrypt_success = False
for pw in passwords:
try:
if pdf_reader.decrypt(pw) != 0:
decrypt_success = True
break
except Exception:
pass
try:
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
except Exception:
logger.error("Unable to decrypt pdf")
if not decrypt_success:
logger.error(
"Encrypted PDF could not be decrypted, returning empty text."
)
return "", metadata, []
elif pdf_reader.is_encrypted:
logger.warning("No Password for an encrypted PDF, returning empty text.")
return "", metadata, []
# Basic PDF metadata
if pdf_reader.metadata is not None:
@@ -302,27 +231,8 @@ def read_pdf_file(
)
if extract_images:
image_cap = MAX_EMBEDDED_IMAGES_PER_FILE
images_processed = 0
cap_reached = False
for page_num, page in enumerate(pdf_reader.pages):
if cap_reached:
break
for image_file_object in page.images:
if images_processed >= image_cap:
# Defense-in-depth backstop. Upload-time validation
# should have rejected files exceeding the cap, but
# we also break here so a single oversized file can
# never pin a worker.
logger.warning(
"PDF embedded image cap reached (%d). "
"Skipping remaining images on page %d and beyond.",
image_cap,
page_num + 1,
)
cap_reached = True
break
image = Image.open(io.BytesIO(image_file_object.data))
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format=image.format)
@@ -335,7 +245,6 @@ def read_pdf_file(
image_callback(img_bytes, image_name)
else:
extracted_images.append((img_bytes, image_name))
images_processed += 1
return text, metadata, extracted_images

View File

@@ -33,20 +33,8 @@ def is_pdf_protected(file: IO[Any]) -> bool:
with preserve_position(file):
reader = PdfReader(file)
if not reader.is_encrypted:
return False
# PDFs with only an owner password (permission restrictions like
# print/copy disabled) use an empty user password — any viewer can open
# them without prompting. decrypt("") returns 0 only when a real user
# password is required. See https://github.com/onyx-dot-app/onyx/issues/9754
try:
return reader.decrypt("") == 0
except Exception:
logger.exception(
"Failed to evaluate PDF encryption; treating as password protected"
)
return True
return bool(reader.is_encrypted)
def is_docx_protected(file: IO[Any]) -> bool:

View File

@@ -26,7 +26,6 @@ class LlmProviderNames(str, Enum):
MISTRAL = "mistral"
LITELLM_PROXY = "litellm_proxy"
BIFROST = "bifrost"
OPENAI_COMPATIBLE = "openai_compatible"
def __str__(self) -> str:
"""Needed so things like:
@@ -47,7 +46,6 @@ WELL_KNOWN_PROVIDER_NAMES = [
LlmProviderNames.LM_STUDIO,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
]
@@ -66,7 +64,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
LlmProviderNames.LM_STUDIO: "LM Studio",
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
LlmProviderNames.BIFROST: "Bifrost",
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI Compatible",
"groq": "Groq",
"anyscale": "Anyscale",
"deepseek": "DeepSeek",
@@ -119,7 +116,6 @@ AGGREGATOR_PROVIDERS: set[str] = {
LlmProviderNames.AZURE,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
}
# Model family name mappings for display name generation

View File

@@ -175,28 +175,6 @@ def _strip_tool_content_from_messages(
return result
def _fix_tool_user_message_ordering(
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Insert a synthetic assistant message between tool and user messages.
Some models (e.g. Mistral on Azure) require strict message ordering where
a user message cannot immediately follow a tool message. This function
inserts a minimal assistant message to bridge the gap.
"""
if len(messages) < 2:
return messages
result: list[dict[str, Any]] = [messages[0]]
for msg in messages[1:]:
prev_role = result[-1].get("role")
curr_role = msg.get("role")
if prev_role == "tool" and curr_role == "user":
result.append({"role": "assistant", "content": "Noted. Continuing."})
result.append(msg)
return result
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
"""Check if any messages contain tool-related content blocks."""
for msg in messages:
@@ -207,21 +185,6 @@ def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
return False
def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool:
"""Check if the prompt contains any assistant messages with tool_calls.
When Anthropic's extended thinking is enabled, the API requires every
assistant message to start with a thinking block before any tool_use
blocks. Since we don't preserve thinking_blocks (they carry
cryptographic signatures that can't be reconstructed), we must skip
the thinking param whenever history contains prior tool-calling turns.
"""
from onyx.llm.models import AssistantMessage
msgs = prompt if isinstance(prompt, list) else [prompt]
return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs)
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
normalized_model_name = model_name.lower()
return any(
@@ -327,19 +290,12 @@ class LitellmLLM(LLM):
):
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
# Bifrost and OpenAI-compatible: OpenAI-compatible proxies that send
# model names directly to the endpoint. We route through LiteLLM's
# openai provider with the server's base URL, and ensure /v1 is appended.
if model_provider in (
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
):
# Bifrost: OpenAI-compatible proxy that expects model names in
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
# We route through LiteLLM's openai provider with the Bifrost base URL,
# and ensure /v1 is appended.
if model_provider == LlmProviderNames.BIFROST:
self._custom_llm_provider = "openai"
# LiteLLM's OpenAI client requires an api_key to be set.
# Many OpenAI-compatible servers don't need auth, so supply a
# placeholder to prevent LiteLLM from raising AuthenticationError.
if not self._api_key:
model_kwargs.setdefault("api_key", "not-needed")
if self._api_base is not None:
base = self._api_base.rstrip("/")
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
@@ -456,20 +412,17 @@ class LitellmLLM(LLM):
optional_kwargs: dict[str, Any] = {}
# Model name
is_openai_compatible_proxy = self._model_provider in (
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
)
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
model_provider = (
f"{self.config.model_provider}/responses"
if is_openai_model # Uses litellm's completions -> responses bridge
else self.config.model_provider
)
if is_openai_compatible_proxy:
# OpenAI-compatible proxies (Bifrost, generic OpenAI-compatible
# servers) expect model names sent directly to their endpoint.
# We use custom_llm_provider="openai" so LiteLLM doesn't try
# to route based on the provider prefix.
if is_bifrost:
# Bifrost expects model names in provider/model format
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
# so LiteLLM doesn't try to route based on the provider prefix.
model = self.config.deployment_name or self.config.model_name
else:
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
@@ -513,20 +466,7 @@ class LitellmLLM(LLM):
reasoning_effort
)
# Anthropic requires every assistant message with tool_use
# blocks to start with a thinking block that carries a
# cryptographic signature. We don't preserve those blocks
# across turns, so skip thinking when the history already
# contains tool-calling assistant messages. LiteLLM's
# modify_params workaround doesn't cover all providers
# (notably Bedrock).
can_enable_thinking = (
budget_tokens is not None
and not _prompt_contains_tool_call_history(prompt)
)
if can_enable_thinking:
assert budget_tokens is not None # mypy
if budget_tokens is not None:
if max_tokens is not None:
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
# and the minimum budget tokens is 1024
@@ -560,10 +500,7 @@ class LitellmLLM(LLM):
if structured_response_format:
optional_kwargs["response_format"] = structured_response_format
if (
not (is_claude_model or is_ollama or is_mistral)
or is_openai_compatible_proxy
):
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
# However, this param breaks Anthropic and Mistral models,
# so it must be conditionally included unless the request is
@@ -611,18 +548,6 @@ class LitellmLLM(LLM):
):
messages = _strip_tool_content_from_messages(messages)
# Some models (e.g. Mistral) reject a user message
# immediately after a tool message. Insert a synthetic
# assistant bridge message to satisfy the ordering
# constraint. Check both the provider and the deployment/
# model name to catch Mistral hosted on Azure.
model_or_deployment = (
self._deployment_name or self._model_version or ""
).lower()
is_mistral_model = is_mistral or "mistral" in model_or_deployment
if is_mistral_model:
messages = _fix_tool_user_message_ordering(messages)
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
# reject requests where tool_choice is explicitly null.
if tools and tool_choice is not None:

View File

@@ -15,8 +15,6 @@ LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
BIFROST_PROVIDER_NAME = "bifrost"
OPENAI_COMPATIBLE_PROVIDER_NAME = "openai_compatible"
# Providers that use optional Bearer auth from custom_config
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,

View File

@@ -19,7 +19,6 @@ from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_COMPATIBLE_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
@@ -52,7 +51,6 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
OPENAI_COMPATIBLE_PROVIDER_NAME: [], # Dynamic - fetched from OpenAI-compatible API
}
@@ -338,7 +336,6 @@ def get_provider_display_name(provider_name: str) -> str:
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
OPENROUTER_PROVIDER_NAME: "OpenRouter",
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI Compatible",
}
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:

View File

@@ -6,7 +6,6 @@ from onyx.configs.app_configs import MCP_SERVER_ENABLED
from onyx.configs.app_configs import MCP_SERVER_HOST
from onyx.configs.app_configs import MCP_SERVER_PORT
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
logger = setup_logger()
@@ -17,7 +16,6 @@ def main() -> None:
logger.info("MCP server is disabled (MCP_SERVER_ENABLED=false)")
return
set_is_ee_based_on_env_variable()
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
from onyx.mcp_server.api import mcp_app

View File

@@ -40,8 +40,6 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.client import app as celery_app
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -52,9 +50,6 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentMetadata
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILE_SIZE_BYTES
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILES_PER_UPLOAD
from onyx.server.features.build.configs import USER_LIBRARY_MAX_TOTAL_SIZE_BYTES
@@ -132,49 +127,6 @@ class DeleteFileResponse(BaseModel):
# =============================================================================
def _looks_like_pdf(filename: str, content_type: str | None) -> bool:
"""True if either the filename or the content-type indicates a PDF.
Client-supplied ``content_type`` can be spoofed (e.g. a PDF uploaded with
``Content-Type: application/octet-stream``), so we also fall back to
extension-based detection via ``mimetypes.guess_type`` on the filename.
"""
if content_type == "application/pdf":
return True
guessed, _ = mimetypes.guess_type(filename)
return guessed == "application/pdf"
def _check_pdf_image_caps(
filename: str, content: bytes, content_type: str | None, batch_total: int
) -> int:
"""Enforce per-file and per-batch embedded-image caps for PDFs.
Returns the number of embedded images in this file (0 for non-PDFs) so
callers can update their running batch total. Raises OnyxError(INVALID_INPUT)
if either cap is exceeded.
"""
if not _looks_like_pdf(filename, content_type):
return 0
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
# Short-circuit at the larger cap so we get a useful count for both checks.
count = count_pdf_embedded_images(BytesIO(content), max(file_cap, batch_cap))
if count > file_cap:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"PDF '{filename}' contains too many embedded images "
f"(more than {file_cap}). Try splitting the document into smaller files.",
)
if batch_total + count > batch_cap:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"Upload would exceed the {batch_cap}-image limit across all "
f"files in this batch. Try uploading fewer image-heavy files at once.",
)
return count
def _sanitize_path(path: str) -> str:
"""Sanitize a file path, removing traversal attempts and normalizing.
@@ -403,7 +355,6 @@ async def upload_files(
uploaded_entries: list[LibraryEntryResponse] = []
total_size = 0
batch_image_total = 0
now = datetime.now(timezone.utc)
# Sanitize the base path
@@ -423,14 +374,6 @@ async def upload_files(
detail=f"File '{file.filename}' exceeds maximum size of {USER_LIBRARY_MAX_FILE_SIZE_BYTES // (1024 * 1024)}MB",
)
# Reject PDFs with an unreasonable per-file or per-batch image count
batch_image_total += _check_pdf_image_caps(
filename=file.filename or "unnamed",
content=content,
content_type=file.content_type,
batch_total=batch_image_total,
)
# Validate cumulative storage (existing + this upload batch)
total_size += file_size
if existing_usage + total_size > USER_LIBRARY_MAX_TOTAL_SIZE_BYTES:
@@ -529,7 +472,6 @@ async def upload_zip(
uploaded_entries: list[LibraryEntryResponse] = []
total_size = 0
batch_image_total = 0
# Extract zip contents into a subfolder named after the zip file
zip_name = api_sanitize_filename(file.filename or "upload")
@@ -568,36 +510,6 @@ async def upload_zip(
logger.warning(f"Skipping '{zip_info.filename}' - exceeds max size")
continue
# Skip PDFs that would trip the per-file or per-batch image
# cap (would OOM the user-file-processing worker). Matches
# /upload behavior but uses skip-and-warn to stay consistent
# with the zip path's handling of oversized files.
zip_file_name = zip_info.filename.split("/")[-1]
zip_content_type, _ = mimetypes.guess_type(zip_file_name)
if zip_content_type == "application/pdf":
image_count = count_pdf_embedded_images(
BytesIO(file_content),
max(
MAX_EMBEDDED_IMAGES_PER_FILE,
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
),
)
if image_count > MAX_EMBEDDED_IMAGES_PER_FILE:
logger.warning(
"Skipping '%s' - exceeds %d per-file embedded-image cap",
zip_info.filename,
MAX_EMBEDDED_IMAGES_PER_FILE,
)
continue
if batch_image_total + image_count > MAX_EMBEDDED_IMAGES_PER_UPLOAD:
logger.warning(
"Skipping '%s' - would exceed %d per-batch embedded-image cap",
zip_info.filename,
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
)
continue
batch_image_total += image_count
total_size += file_size
# Validate cumulative storage

View File

@@ -44,12 +44,11 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
"""Raise OnyxError if endpoint_url could be used for SSRF.
Delegates to validate_outbound_http_url with https_only=True.
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
"""
try:
validate_outbound_http_url(endpoint_url, https_only=True)
except (SSRFException, ValueError) as e:
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
# ---------------------------------------------------------------------------
@@ -142,11 +141,19 @@ def _validate_endpoint(
)
return HookValidateResponse(status=HookValidateStatus.passed)
except httpx.TimeoutException as exc:
# Any timeout (connect, read, or write) means the configured timeout_seconds
# is too low for this endpoint. Report as timeout so the UI directs the user
# to increase the timeout setting.
# ConnectTimeout: TCP handshake never completed → cannot_connect.
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
if isinstance(exc, httpx.ConnectTimeout):
logger.warning(
"Hook endpoint validation: connect timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
logger.warning(
"Hook endpoint validation: timeout for %s",
"Hook endpoint validation: read/write timeout for %s",
endpoint_url,
exc_info=exc,
)

View File

@@ -10,12 +10,9 @@ from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
@@ -195,11 +192,6 @@ def categorize_uploaded_files(
except RuntimeError as e:
logger.warning(f"Failed to get current tenant ID: {str(e)}")
# Running total of embedded images across PDFs in this batch. Once the
# aggregate cap is reached, subsequent PDFs in the same upload are
# rejected even if they'd individually fit under MAX_EMBEDDED_IMAGES_PER_FILE.
batch_image_total = 0
for upload in files:
try:
filename = get_safe_filename(upload)
@@ -261,47 +253,6 @@ def categorize_uploaded_files(
)
continue
# Reject PDFs with an unreasonable number of embedded images
# (either per-file or accumulated across this upload batch).
# A PDF with thousands of embedded images can OOM the
# user-file-processing celery worker because every image is
# decoded with PIL and then sent to the vision LLM.
if extension == ".pdf":
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
# Use the larger of the two caps as the short-circuit
# threshold so we get a useful count for both checks.
# count_pdf_embedded_images restores the stream position.
count = count_pdf_embedded_images(
upload.file, max(file_cap, batch_cap)
)
if count > file_cap:
results.rejected.append(
RejectedFile(
filename=filename,
reason=(
f"PDF contains too many embedded images "
f"(more than {file_cap}). Try splitting "
f"the document into smaller files."
),
)
)
continue
if batch_image_total + count > batch_cap:
results.rejected.append(
RejectedFile(
filename=filename,
reason=(
f"Upload would exceed the "
f"{batch_cap}-image limit across all "
f"files in this batch. Try uploading "
f"fewer image-heavy files at once."
),
)
)
continue
batch_image_total += count
text_content = extract_file_text(
file=upload.file,
file_name=filename,

View File

@@ -74,8 +74,6 @@ from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
from onyx.server.manage.llm.models import OpenAICompatibleFinalModelResponse
from onyx.server.manage.llm.models import OpenAICompatibleModelsRequest
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
from onyx.server.manage.llm.models import OpenRouterModelDetails
from onyx.server.manage.llm.models import OpenRouterModelsRequest
@@ -1526,7 +1524,6 @@ def get_bifrost_available_models(
display_name=model_name,
max_input_tokens=model.get("context_length"),
supports_image_input=infer_vision_support(model_id),
supports_reasoning=is_reasoning_model(model_id, model_name),
)
)
except Exception as e:
@@ -1577,95 +1574,3 @@ def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> d
source_name="Bifrost",
api_key=api_key,
)
@admin_router.post("/openai-compatible/available-models")
def get_openai_compatible_server_available_models(
request: OpenAICompatibleModelsRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[OpenAICompatibleFinalModelResponse]:
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
response_json = _get_openai_compatible_server_response(
api_base=request.api_base, api_key=request.api_key
)
models = response_json.get("data", [])
if not isinstance(models, list) or len(models) == 0:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your OpenAI-compatible endpoint",
)
results: list[OpenAICompatibleFinalModelResponse] = []
for model in models:
try:
model_id = model.get("id", "")
model_name = model.get("name", model_id)
if not model_id:
continue
# Skip embedding models
if is_embedding_model(model_id):
continue
results.append(
OpenAICompatibleFinalModelResponse(
name=model_id,
display_name=model_name,
max_input_tokens=model.get("context_length"),
supports_image_input=infer_vision_support(model_id),
supports_reasoning=is_reasoning_model(model_id, model_name),
)
)
except Exception as e:
logger.warning(
"Failed to parse OpenAI-compatible model entry",
extra={"error": str(e), "item": str(model)[:1000]},
)
if not results:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No compatible models found from OpenAI-compatible endpoint",
)
sorted_results = sorted(results, key=lambda m: m.name.lower())
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
for r in sorted_results
],
source_label="OpenAI Compatible",
)
return sorted_results
def _get_openai_compatible_server_response(
api_base: str, api_key: str | None = None
) -> dict:
"""Perform GET to an OpenAI-compatible /v1/models and return parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
# Ensure we hit /v1/models
if cleaned_api_base.endswith("/v1"):
url = f"{cleaned_api_base}/models"
else:
url = f"{cleaned_api_base}/v1/models"
return _get_openai_compatible_models_response(
url=url,
source_name="OpenAI Compatible",
api_key=api_key,
)

View File

@@ -463,19 +463,3 @@ class BifrostFinalModelResponse(BaseModel):
display_name: str # Human-readable name from Bifrost API
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool
# OpenAI Compatible dynamic models fetch
class OpenAICompatibleModelsRequest(BaseModel):
api_base: str
api_key: str | None = None
provider_name: str | None = None # Optional: to save models to existing provider
class OpenAICompatibleFinalModelResponse(BaseModel):
name: str # Model ID (e.g. "meta-llama/Llama-3-8B-Instruct")
display_name: str # Human-readable name from API
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool

View File

@@ -26,7 +26,6 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
}
)

View File

@@ -147,7 +147,6 @@ class UserInfo(BaseModel):
is_anonymous_user: bool | None = None,
tenant_info: TenantInfo | None = None,
assistant_specific_configs: UserSpecificAssistantPreferences | None = None,
memories: list[MemoryItem] | None = None,
) -> "UserInfo":
return cls(
id=str(user.id),
@@ -192,7 +191,10 @@ class UserInfo(BaseModel):
role=user.personal_role or "",
use_memories=user.use_memories,
enable_memory_tool=user.enable_memory_tool,
memories=memories or [],
memories=[
MemoryItem(id=memory.id, content=memory.memory_text)
for memory in (user.memories or [])
],
user_preferences=user.user_preferences or "",
),
)

View File

@@ -57,7 +57,6 @@ from onyx.db.user_preferences import activate_user
from onyx.db.user_preferences import deactivate_user
from onyx.db.user_preferences import get_all_user_assistant_specific_configs
from onyx.db.user_preferences import get_latest_access_token_for_user
from onyx.db.user_preferences import get_memories_for_user
from onyx.db.user_preferences import update_assistant_preferences
from onyx.db.user_preferences import update_user_assistant_visibility
from onyx.db.user_preferences import update_user_auto_scroll
@@ -824,11 +823,6 @@ def verify_user_logged_in(
[],
),
)
memories = [
MemoryItem(id=memory.id, content=memory.memory_text)
for memory in get_memories_for_user(user.id, db_session)
]
user_info = UserInfo.from_model(
user,
current_token_created_at=token_created_at,
@@ -839,7 +833,6 @@ def verify_user_logged_in(
new_tenant=new_tenant,
invitation=tenant_invitation,
),
memories=memories,
)
return user_info
@@ -937,8 +930,7 @@ def update_user_personalization_api(
else user.enable_memory_tool
)
existing_memories = [
MemoryItem(id=memory.id, content=memory.memory_text)
for memory in get_memories_for_user(user.id, db_session)
MemoryItem(id=memory.id, content=memory.memory_text) for memory in user.memories
]
new_memories = (
request.memories if request.memories is not None else existing_memories

View File

@@ -12,6 +12,7 @@ stale, which is fine for monitoring dashboards.
import json
import threading
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -103,23 +104,25 @@ class _CachedCollector(Collector):
class QueueDepthCollector(_CachedCollector):
"""Reads Celery queue lengths from the broker Redis on each scrape."""
"""Reads Celery queue lengths from the broker Redis on each scrape.
Uses a Redis client factory (callable) rather than a stored client
reference so the connection is always fresh from Celery's pool.
"""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._celery_app: Any | None = None
self._get_redis: Callable[[], Redis] | None = None
def set_celery_app(self, app: Any) -> None:
"""Set the Celery app for broker Redis access."""
self._celery_app = app
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
"""Set a callable that returns a broker Redis client on demand."""
self._get_redis = factory
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._celery_app is None:
if self._get_redis is None:
return []
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
redis_client = self._get_redis()
depth = GaugeMetricFamily(
"onyx_queue_depth",
@@ -401,19 +404,17 @@ class RedisHealthCollector(_CachedCollector):
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._celery_app: Any | None = None
self._get_redis: Callable[[], Redis] | None = None
def set_celery_app(self, app: Any) -> None:
"""Set the Celery app for broker Redis access."""
self._celery_app = app
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
"""Set a callable that returns a broker Redis client on demand."""
self._get_redis = factory
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._celery_app is None:
if self._get_redis is None:
return []
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
redis_client = self._get_redis()
memory_used = GaugeMetricFamily(
"onyx_redis_memory_used_bytes",

View File

@@ -3,8 +3,12 @@
Called once by the monitoring celery worker after Redis and DB are ready.
"""
from collections.abc import Callable
from typing import Any
from celery import Celery
from prometheus_client.registry import REGISTRY
from redis import Redis
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
@@ -17,7 +21,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# Module-level singletons — these are lightweight objects (no connections or DB
# state) until configure() / set_celery_app() is called. Keeping them at
# state) until configure() / set_redis_factory() is called. Keeping them at
# module level ensures they survive the lifetime of the worker process and are
# only registered with the Prometheus registry once.
_queue_collector = QueueDepthCollector()
@@ -28,15 +32,72 @@ _worker_health_collector = WorkerHealthCollector()
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
"""Create a factory that returns a cached broker Redis client.
Reuses a single connection across scrapes to avoid leaking connections.
Reconnects automatically if the cached connection becomes stale.
"""
_cached_client: list[Redis | None] = [None]
# Keep a reference to the Kombu Connection so we can close it on
# reconnect (the raw Redis client outlives the Kombu wrapper).
_cached_kombu_conn: list[Any] = [None]
def _close_client(client: Redis) -> None:
"""Best-effort close of a Redis client."""
try:
client.close()
except Exception:
logger.debug("Failed to close stale Redis client", exc_info=True)
def _close_kombu_conn() -> None:
"""Best-effort close of the cached Kombu Connection."""
conn = _cached_kombu_conn[0]
if conn is not None:
try:
conn.close()
except Exception:
logger.debug("Failed to close Kombu connection", exc_info=True)
_cached_kombu_conn[0] = None
def _get_broker_redis() -> Redis:
client = _cached_client[0]
if client is not None:
try:
client.ping()
return client
except Exception:
logger.debug("Cached Redis client stale, reconnecting")
_close_client(client)
_cached_client[0] = None
_close_kombu_conn()
# Get a fresh Redis client from the broker connection.
# We hold this client long-term (cached above) rather than using a
# context manager, because we need it to persist across scrapes.
# The caching logic above ensures we only ever hold one connection,
# and we close it explicitly on reconnect.
conn = celery_app.broker_connection()
# kombu's Channel exposes .client at runtime (the underlying Redis
# client) but the type stubs don't declare it.
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
_cached_client[0] = new_client
_cached_kombu_conn[0] = conn
return new_client
return _get_broker_redis
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
"""Register all indexing pipeline collectors with the default registry.
Args:
celery_app: The Celery application instance. Used to obtain a
celery_app: The Celery application instance. Used to obtain a fresh
broker Redis client on each scrape for queue depth metrics.
"""
_queue_collector.set_celery_app(celery_app)
_redis_health_collector.set_celery_app(celery_app)
redis_factory = _make_broker_redis_factory(celery_app)
_queue_collector.set_redis_factory(redis_factory)
_redis_health_collector.set_redis_factory(redis_factory)
# Start the heartbeat monitor daemon thread — uses a single persistent
# connection to receive worker-heartbeat events.

View File

@@ -129,10 +129,6 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
return_value=mock_app,
),
patch(_PATCH_QUEUE_DEPTH, return_value=0),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
):
yield

View File

@@ -88,22 +88,10 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
Also patches ``celery_get_broker_client`` so the mock app doesn't need
a real broker URL.
"""
task_instance = task.run.__self__
with (
patch.object(
type(task_instance),
"app",
new_callable=PropertyMock,
return_value=mock_app,
),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield

View File

@@ -90,17 +90,8 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
task only.
"""
task_instance = task.run.__self__
with (
patch.object(
type(task_instance),
"app",
new_callable=PropertyMock,
return_value=mock_app,
),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield

View File

@@ -103,11 +103,6 @@ _EXPECTED_CONFLUENCE_GROUPS = [
user_emails={"oauth@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="no yuhong allowed",
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
gives_anyone_access=False,
),
]

View File

@@ -1,58 +0,0 @@
import pytest
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
from onyx.db.federated import _reject_masked_credentials
class TestRejectMaskedCredentials:
"""Verify that masked credential values are never accepted for DB writes.
mask_string() has two output formats:
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
_reject_masked_credentials must catch both.
"""
def test_rejects_fully_masked_value(self) -> None:
masked = MASK_CREDENTIAL_CHAR * 12 # "••••••••••••"
with pytest.raises(ValueError, match="masked placeholder"):
_reject_masked_credentials({"client_id": masked})
def test_rejects_long_string_masked_value(self) -> None:
"""mask_string returns 'first4...last4' for long strings — the real
format used for OAuth credentials like client_id and client_secret."""
with pytest.raises(ValueError, match="masked placeholder"):
_reject_masked_credentials({"client_id": "1234...7890"})
def test_rejects_when_any_field_is_masked(self) -> None:
"""Even if client_id is real, a masked client_secret must be caught."""
with pytest.raises(ValueError, match="client_secret"):
_reject_masked_credentials(
{
"client_id": "1234567890.1234567890",
"client_secret": MASK_CREDENTIAL_CHAR * 12,
}
)
def test_accepts_real_credentials(self) -> None:
# Should not raise
_reject_masked_credentials(
{
"client_id": "1234567890.1234567890",
"client_secret": "test_client_secret_value",
}
)
def test_accepts_empty_dict(self) -> None:
# Should not raise — empty credentials are handled elsewhere
_reject_masked_credentials({})
def test_ignores_non_string_values(self) -> None:
# Non-string values (None, bool, int) should pass through
_reject_masked_credentials(
{
"client_id": "real_value",
"redirect_uri": None,
"some_flag": True,
}
)

View File

@@ -1,87 +0,0 @@
"""Tests for celery_get_broker_client singleton."""
from collections.abc import Iterator
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.background.celery import celery_redis
@pytest.fixture(autouse=True)
def reset_singleton() -> Iterator[None]:
"""Reset the module-level singleton between tests."""
celery_redis._broker_client = None
celery_redis._broker_url = None
yield
celery_redis._broker_client = None
celery_redis._broker_url = None
def _make_mock_app(broker_url: str = "redis://localhost:6379/15") -> MagicMock:
app = MagicMock()
app.conf.broker_url = broker_url
return app
class TestCeleryGetBrokerClient:
@patch("onyx.background.celery.celery_redis.Redis")
def test_creates_client_on_first_call(self, mock_redis_cls: MagicMock) -> None:
mock_client = MagicMock()
mock_redis_cls.from_url.return_value = mock_client
app = _make_mock_app()
result = celery_redis.celery_get_broker_client(app)
assert result is mock_client
call_args = mock_redis_cls.from_url.call_args
assert call_args[0][0] == "redis://localhost:6379/15"
assert call_args[1]["decode_responses"] is False
assert call_args[1]["socket_keepalive"] is True
assert call_args[1]["retry_on_timeout"] is True
@patch("onyx.background.celery.celery_redis.Redis")
def test_reuses_cached_client(self, mock_redis_cls: MagicMock) -> None:
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_cls.from_url.return_value = mock_client
app = _make_mock_app()
client1 = celery_redis.celery_get_broker_client(app)
client2 = celery_redis.celery_get_broker_client(app)
assert client1 is client2
# from_url called only once
assert mock_redis_cls.from_url.call_count == 1
@patch("onyx.background.celery.celery_redis.Redis")
def test_reconnects_on_ping_failure(self, mock_redis_cls: MagicMock) -> None:
stale_client = MagicMock()
stale_client.ping.side_effect = ConnectionError("disconnected")
fresh_client = MagicMock()
fresh_client.ping.return_value = True
mock_redis_cls.from_url.side_effect = [stale_client, fresh_client]
app = _make_mock_app()
# First call creates stale_client
client1 = celery_redis.celery_get_broker_client(app)
assert client1 is stale_client
# Second call: ping fails, creates fresh_client
client2 = celery_redis.celery_get_broker_client(app)
assert client2 is fresh_client
assert mock_redis_cls.from_url.call_count == 2
@patch("onyx.background.celery.celery_redis.Redis")
def test_uses_broker_url_from_app_config(self, mock_redis_cls: MagicMock) -> None:
mock_redis_cls.from_url.return_value = MagicMock()
app = _make_mock_app("redis://custom-host:6380/3")
celery_redis.celery_get_broker_client(app)
call_args = mock_redis_cls.from_url.call_args
assert call_args[0][0] == "redis://custom-host:6380/3"

View File

@@ -1,147 +0,0 @@
from typing import Any
from unittest.mock import MagicMock
import pytest
import requests
from jira import JIRA
from jira.resources import Issue
from onyx.connectors.jira.connector import bulk_fetch_issues
def _make_raw_issue(issue_id: str) -> dict[str, Any]:
return {
"id": issue_id,
"key": f"TEST-{issue_id}",
"fields": {"summary": f"Issue {issue_id}"},
}
def _mock_jira_client() -> MagicMock:
mock = MagicMock(spec=JIRA)
mock._options = {"server": "https://jira.example.com"}
mock._session = MagicMock()
mock._get_url = MagicMock(
return_value="https://jira.example.com/rest/api/3/issue/bulkfetch"
)
return mock
def test_bulk_fetch_success() -> None:
"""Happy path: all issues fetched in one request."""
client = _mock_jira_client()
raw = [_make_raw_issue("1"), _make_raw_issue("2"), _make_raw_issue("3")]
resp = MagicMock()
resp.json.return_value = {"issues": raw}
client._session.post.return_value = resp
result = bulk_fetch_issues(client, ["1", "2", "3"])
assert len(result) == 3
assert all(isinstance(r, Issue) for r in result)
client._session.post.assert_called_once()
def test_bulk_fetch_splits_on_json_error() -> None:
"""When the full batch fails with JSONDecodeError, sub-batches succeed."""
client = _mock_jira_client()
call_count = 0
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
nonlocal call_count
call_count += 1
ids = json["issueIdsOrKeys"]
if len(ids) > 2:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"Expecting ',' delimiter", "doc", 2294125
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
result = bulk_fetch_issues(client, ["1", "2", "3", "4"])
assert len(result) == 4
returned_ids = {r.raw["id"] for r in result}
assert returned_ids == {"1", "2", "3", "4"}
assert call_count > 1
def test_bulk_fetch_raises_on_single_unfetchable_issue() -> None:
"""A single issue that always fails JSON decode raises after splitting."""
client = _mock_jira_client()
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
ids = json["issueIdsOrKeys"]
if "bad" in ids:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"Expecting ',' delimiter", "doc", 100
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
with pytest.raises(requests.exceptions.JSONDecodeError):
bulk_fetch_issues(client, ["1", "bad", "2"])
def test_bulk_fetch_non_json_error_propagates() -> None:
"""Non-JSONDecodeError exceptions still propagate."""
client = _mock_jira_client()
resp = MagicMock()
resp.json.side_effect = ValueError("something else broke")
client._session.post.return_value = resp
try:
bulk_fetch_issues(client, ["1"])
assert False, "Expected ValueError to propagate"
except ValueError:
pass
def test_bulk_fetch_with_fields() -> None:
"""Fields parameter is forwarded correctly."""
client = _mock_jira_client()
raw = [_make_raw_issue("1")]
resp = MagicMock()
resp.json.return_value = {"issues": raw}
client._session.post.return_value = resp
bulk_fetch_issues(client, ["1"], fields="summary,description")
call_payload = client._session.post.call_args[1]["json"]
assert call_payload["fields"] == ["summary", "description"]
def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
"""With a 6-issue batch where one is bad, recursion isolates it and raises."""
client = _mock_jira_client()
bad_id = "BAD"
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
ids = json["issueIdsOrKeys"]
if bad_id in ids:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"truncated", "doc", 999
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
with pytest.raises(requests.exceptions.JSONDecodeError):
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])

View File

@@ -1,67 +0,0 @@
"""Tests for _build_thread_text function."""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.context.search.federated.slack_search import _build_thread_text
def _make_msg(user: str, text: str, ts: str) -> dict[str, str]:
return {"user": user, "text": text, "ts": ts}
class TestBuildThreadText:
"""Verify _build_thread_text includes full thread replies up to cap."""
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
def test_includes_all_replies(self, mock_profiles: MagicMock) -> None:
"""All replies within cap are included in output."""
mock_profiles.return_value = {}
messages = [
_make_msg("U1", "parent msg", "1000.0"),
_make_msg("U2", "reply 1", "1001.0"),
_make_msg("U3", "reply 2", "1002.0"),
_make_msg("U4", "reply 3", "1003.0"),
]
result = _build_thread_text(messages, "token", "T123", MagicMock())
assert "parent msg" in result
assert "reply 1" in result
assert "reply 2" in result
assert "reply 3" in result
assert "..." not in result
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
def test_non_thread_returns_parent_only(self, mock_profiles: MagicMock) -> None:
"""Single message (no replies) returns just the parent text."""
mock_profiles.return_value = {}
messages = [_make_msg("U1", "just a message", "1000.0")]
result = _build_thread_text(messages, "token", "T123", MagicMock())
assert "just a message" in result
assert "Replies:" not in result
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
def test_parent_always_first(self, mock_profiles: MagicMock) -> None:
"""Thread parent message is always the first line of output."""
mock_profiles.return_value = {}
messages = [
_make_msg("U1", "I am the parent", "1000.0"),
_make_msg("U2", "I am a reply", "1001.0"),
]
result = _build_thread_text(messages, "token", "T123", MagicMock())
parent_pos = result.index("I am the parent")
reply_pos = result.index("I am a reply")
assert parent_pos < reply_pos
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
def test_user_profiles_resolved(self, mock_profiles: MagicMock) -> None:
"""User IDs in thread text are replaced with display names."""
mock_profiles.return_value = {"U1": "Alice", "U2": "Bob"}
messages = [
_make_msg("U1", "hello", "1000.0"),
_make_msg("U2", "world", "1001.0"),
]
result = _build_thread_text(messages, "token", "T123", MagicMock())
assert "Alice" in result
assert "Bob" in result
assert "<@U1>" not in result
assert "<@U2>" not in result

View File

@@ -1,108 +0,0 @@
"""Tests for Slack URL parsing and direct thread fetch via URL override."""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.federated.slack_search import _fetch_thread_from_url
from onyx.context.search.federated.slack_search_utils import extract_slack_message_urls
class TestExtractSlackMessageUrls:
"""Verify URL parsing extracts channel_id and timestamp correctly."""
def test_standard_url(self) -> None:
query = "summarize https://mycompany.slack.com/archives/C097NBWMY8Y/p1775491616524769"
results = extract_slack_message_urls(query)
assert len(results) == 1
assert results[0] == ("C097NBWMY8Y", "1775491616.524769")
def test_multiple_urls(self) -> None:
query = (
"compare https://co.slack.com/archives/C111/p1234567890123456 "
"and https://co.slack.com/archives/C222/p9876543210987654"
)
results = extract_slack_message_urls(query)
assert len(results) == 2
assert results[0] == ("C111", "1234567890.123456")
assert results[1] == ("C222", "9876543210.987654")
def test_no_urls(self) -> None:
query = "what happened in #general last week?"
results = extract_slack_message_urls(query)
assert len(results) == 0
def test_non_slack_url_ignored(self) -> None:
query = "check https://google.com/archives/C111/p1234567890123456"
results = extract_slack_message_urls(query)
assert len(results) == 0
def test_timestamp_conversion(self) -> None:
"""p prefix removed, dot inserted after 10th digit."""
query = "https://x.slack.com/archives/CABC123/p1775491616524769"
results = extract_slack_message_urls(query)
channel_id, ts = results[0]
assert channel_id == "CABC123"
assert ts == "1775491616.524769"
assert not ts.startswith("p")
assert "." in ts
class TestFetchThreadFromUrl:
"""Verify _fetch_thread_from_url calls conversations.replies and returns SlackMessage."""
@patch("onyx.context.search.federated.slack_search._build_thread_text")
@patch("onyx.context.search.federated.slack_search.WebClient")
def test_successful_fetch(
self, mock_webclient_cls: MagicMock, mock_build_thread: MagicMock
) -> None:
mock_client = MagicMock()
mock_webclient_cls.return_value = mock_client
# Mock conversations_replies
mock_response = MagicMock()
mock_response.get.return_value = [
{"user": "U1", "text": "parent", "ts": "1775491616.524769"},
{"user": "U2", "text": "reply 1", "ts": "1775491617.000000"},
{"user": "U3", "text": "reply 2", "ts": "1775491618.000000"},
]
mock_client.conversations_replies.return_value = mock_response
# Mock channel info
mock_ch_response = MagicMock()
mock_ch_response.get.return_value = {"name": "general"}
mock_client.conversations_info.return_value = mock_ch_response
mock_build_thread.return_value = (
"U1: parent\n\nReplies:\n\nU2: reply 1\n\nU3: reply 2"
)
fetch = DirectThreadFetch(
channel_id="C097NBWMY8Y", thread_ts="1775491616.524769"
)
result = _fetch_thread_from_url(fetch, "xoxp-token")
assert len(result.messages) == 1
msg = result.messages[0]
assert msg.channel_id == "C097NBWMY8Y"
assert msg.thread_id is None # Prevents double-enrichment
assert msg.slack_score == 100000.0
assert "parent" in msg.text
mock_client.conversations_replies.assert_called_once_with(
channel="C097NBWMY8Y", ts="1775491616.524769"
)
@patch("onyx.context.search.federated.slack_search.WebClient")
def test_api_error_returns_empty(self, mock_webclient_cls: MagicMock) -> None:
from slack_sdk.errors import SlackApiError
mock_client = MagicMock()
mock_webclient_cls.return_value = mock_client
mock_client.conversations_replies.side_effect = SlackApiError(
message="channel_not_found",
response=MagicMock(status_code=404),
)
fetch = DirectThreadFetch(channel_id="CBAD", thread_ts="1234567890.123456")
result = _fetch_thread_from_url(fetch, "xoxp-token")
assert len(result.messages) == 0

View File

@@ -1,225 +0,0 @@
"""Tests for get_chat_sessions_by_user filtering behavior.
Verifies that failed chat sessions (those with only SYSTEM messages) are
correctly filtered out while preserving recently created sessions, matching
the behavior specified in PR #7233.
"""
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from unittest.mock import MagicMock
from uuid import UUID
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.models import ChatSession
def _make_session(
user_id: UUID,
time_created: datetime | None = None,
time_updated: datetime | None = None,
description: str = "",
) -> MagicMock:
"""Create a mock ChatSession with the given attributes."""
session = MagicMock(spec=ChatSession)
session.id = uuid4()
session.user_id = user_id
session.time_created = time_created or datetime.now(timezone.utc)
session.time_updated = time_updated or session.time_created
session.description = description
session.deleted = False
session.onyxbot_flow = False
session.project_id = None
return session
@pytest.fixture
def user_id() -> UUID:
return uuid4()
@pytest.fixture
def old_time() -> datetime:
"""A timestamp well outside the 5-minute leeway window."""
return datetime.now(timezone.utc) - timedelta(hours=1)
@pytest.fixture
def recent_time() -> datetime:
"""A timestamp within the 5-minute leeway window."""
return datetime.now(timezone.utc) - timedelta(minutes=2)
class TestGetChatSessionsByUser:
"""Tests for the failed chat filtering logic in get_chat_sessions_by_user."""
def test_filters_out_failed_sessions(
self, user_id: UUID, old_time: datetime
) -> None:
"""Sessions with only SYSTEM messages should be excluded."""
valid_session = _make_session(user_id, time_created=old_time)
failed_session = _make_session(user_id, time_created=old_time)
db_session = MagicMock(spec=Session)
# First execute: returns all sessions
# Second execute: returns only the valid session's ID (has non-system msgs)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [
valid_session,
failed_session,
]
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = [valid_session.id]
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert len(result) == 1
assert result[0].id == valid_session.id
def test_keeps_recent_sessions_without_messages(
self, user_id: UUID, recent_time: datetime
) -> None:
"""Recently created sessions should be kept even without messages."""
recent_session = _make_session(user_id, time_created=recent_time)
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [recent_session]
db_session.execute.side_effect = [mock_result_1]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert len(result) == 1
assert result[0].id == recent_session.id
# Should only have been called once — no second query needed
# because the recent session is within the leeway window
assert db_session.execute.call_count == 1
def test_include_failed_chats_skips_filtering(
self, user_id: UUID, old_time: datetime
) -> None:
"""When include_failed_chats=True, no filtering should occur."""
session_a = _make_session(user_id, time_created=old_time)
session_b = _make_session(user_id, time_created=old_time)
db_session = MagicMock(spec=Session)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [session_a, session_b]
db_session.execute.side_effect = [mock_result]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=True,
)
assert len(result) == 2
# Only one DB call — no second query for message validation
assert db_session.execute.call_count == 1
def test_limit_applied_after_filtering(
self, user_id: UUID, old_time: datetime
) -> None:
"""Limit should be applied after filtering, not before."""
sessions = [_make_session(user_id, time_created=old_time) for _ in range(5)]
valid_ids = [s.id for s in sessions[:3]]
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = sessions
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = valid_ids
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
limit=2,
)
assert len(result) == 2
# Should be the first 2 valid sessions (order preserved)
assert result[0].id == sessions[0].id
assert result[1].id == sessions[1].id
def test_mixed_recent_and_old_sessions(
self, user_id: UUID, old_time: datetime, recent_time: datetime
) -> None:
"""Mix of recent and old sessions should filter correctly."""
old_valid = _make_session(user_id, time_created=old_time)
old_failed = _make_session(user_id, time_created=old_time)
recent_no_msgs = _make_session(user_id, time_created=recent_time)
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [
old_valid,
old_failed,
recent_no_msgs,
]
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = [old_valid.id]
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
result_ids = {cs.id for cs in result}
assert old_valid.id in result_ids
assert recent_no_msgs.id in result_ids
assert old_failed.id not in result_ids
def test_empty_result(self, user_id: UUID) -> None:
"""No sessions should return empty list without errors."""
db_session = MagicMock(spec=Session)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
db_session.execute.side_effect = [mock_result]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert result == []
assert db_session.execute.call_count == 1

View File

@@ -272,13 +272,13 @@ class TestUpsertVoiceProvider:
class TestDeleteVoiceProvider:
"""Tests for delete_voice_provider."""
def test_hard_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = provider
delete_voice_provider(mock_db_session, 1)
mock_db_session.delete.assert_called_once_with(provider)
assert provider.deleted is True
mock_db_session.flush.assert_called_once()
def test_does_nothing_when_provider_not_found(

View File

@@ -1,76 +0,0 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer <1083d595b1>
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 1
/Kids [ 4 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 5 0 R
/Parent 2 0 R
>>
endobj
5 0 obj
<<
/Length 42
>>
stream
,N<><6~<7E>)<29><><EFBFBD><EFBFBD><EFBFBD>u<EFBFBD> <0C><><EFBFBD>Zc'<27><>>8g<38><67><EFBFBD>n<EFBFBD><6E><EFBFBD><EFBFBD><EFBFBD>9"
endstream
endobj
6 0 obj
<<
/V 2
/R 3
/Length 128
/P 4294967292
/Filter /Standard
/O <6a340a292629053da84a6d8b19a5d505953b8b3fdac3d2d389fde0e354528d44>
/U <d6f0dc91c7b9de264a8d708515468e6528bf4e5e4e758a4164004e56fffa0108>
>>
endobj
xref
0 7
0000000000 65535 f
0000000015 00000 n
0000000059 00000 n
0000000118 00000 n
0000000167 00000 n
0000000348 00000 n
0000000440 00000 n
trailer
<<
/Size 7
/Root 3 0 R
/Info 1 0 R
/ID [ <6364336635356135633239323638353039306635656133623165313637366430> <6364336635356135633239323638353039306635656133623165313637366430> ]
/Encrypt 6 0 R
>>
startxref
655
%%EOF

View File

@@ -12,10 +12,6 @@ dependency on pypdf internals (pypdf.generic).
from io import BytesIO
from pathlib import Path
import pytest
from onyx.file_processing import extract_file_text
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
from onyx.file_processing.extract_file_text import pdf_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.password_validation import is_pdf_protected
@@ -58,12 +54,6 @@ class TestReadPdfFile:
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="wrong")
assert text == ""
def test_owner_password_only_pdf_extracts_text(self) -> None:
"""A PDF encrypted with only an owner password (no user password)
should still yield its text content. Regression for #9754."""
text, _, _ = read_pdf_file(_load("owner_protected.pdf"))
assert "Hello World" in text
def test_empty_pdf(self) -> None:
text, _, _ = read_pdf_file(_load("empty.pdf"))
assert text.strip() == ""
@@ -100,80 +90,6 @@ class TestReadPdfFile:
# Returned list is empty when callback is used
assert images == []
def test_image_cap_skips_images_above_limit(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""When the embedded-image cap is exceeded, remaining images are skipped.
The cap protects the user-file-processing worker from OOMing on PDFs
with thousands of embedded images. Setting the cap to 0 should yield
zero extracted images even though the fixture has one.
"""
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
assert images == []
def test_image_cap_at_limit_extracts_up_to_cap(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""A cap >= image count behaves identically to the uncapped path."""
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 100)
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
assert len(images) == 1
def test_image_cap_with_callback_stops_streaming_at_limit(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""The cap also short-circuits the streaming callback path."""
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
collected: list[tuple[bytes, str]] = []
def callback(data: bytes, name: str) -> None:
collected.append((data, name))
read_pdf_file(
_load("with_image.pdf"), extract_images=True, image_callback=callback
)
assert collected == []
# ── count_pdf_embedded_images ────────────────────────────────────────────
class TestCountPdfEmbeddedImages:
def test_returns_count_for_normal_pdf(self) -> None:
assert count_pdf_embedded_images(_load("with_image.pdf"), cap=10) == 1
def test_short_circuits_above_cap(self) -> None:
# with_image.pdf has 1 image. cap=0 means "anything > 0 is over cap" —
# function returns on first increment as the over-cap sentinel.
assert count_pdf_embedded_images(_load("with_image.pdf"), cap=0) == 1
def test_returns_zero_for_pdf_without_images(self) -> None:
assert count_pdf_embedded_images(_load("simple.pdf"), cap=10) == 0
def test_returns_zero_for_invalid_pdf(self) -> None:
assert count_pdf_embedded_images(BytesIO(b"not a pdf"), cap=10) == 0
def test_returns_zero_for_password_locked_pdf(self) -> None:
# encrypted.pdf has an open password; we can't inspect without it, so
# the helper returns 0 — callers rely on the password-protected check
# that runs earlier in the upload pipeline.
assert count_pdf_embedded_images(_load("encrypted.pdf"), cap=10) == 0
def test_inspects_owner_password_only_pdf(self) -> None:
# owner_protected.pdf is encrypted but has no open password. It should
# decrypt with an empty string and count images normally. The fixture
# has zero images, so 0 is a real count (not the "bail on encrypted"
# path).
assert count_pdf_embedded_images(_load("owner_protected.pdf"), cap=10) == 0
def test_preserves_file_position(self) -> None:
pdf = _load("with_image.pdf")
pdf.seek(42)
count_pdf_embedded_images(pdf, cap=10)
assert pdf.tell() == 42
# ── pdf_to_text ──────────────────────────────────────────────────────────
@@ -201,12 +117,6 @@ class TestIsPdfProtected:
def test_protected_pdf(self) -> None:
assert is_pdf_protected(_load("encrypted.pdf")) is True
def test_owner_password_only_is_not_protected(self) -> None:
"""A PDF with only an owner password (permission restrictions) but no
user password should NOT be considered protected — any viewer can open
it without prompting for a password."""
assert is_pdf_protected(_load("owner_protected.pdf")) is False
def test_preserves_file_position(self) -> None:
pdf = _load("simple.pdf")
pdf.seek(42)

View File

@@ -1,79 +0,0 @@
import io
from pptx import Presentation # type: ignore[import-untyped]
from pptx.chart.data import CategoryChartData # type: ignore[import-untyped]
from pptx.enum.chart import XL_CHART_TYPE # type: ignore[import-untyped]
from pptx.util import Inches # type: ignore[import-untyped]
from onyx.file_processing.extract_file_text import pptx_to_text
def _make_pptx_with_chart() -> io.BytesIO:
"""Create an in-memory pptx with one text slide and one chart slide."""
prs = Presentation()
# Slide 1: text only
slide1 = prs.slides.add_slide(prs.slide_layouts[1])
slide1.shapes.title.text = "Introduction"
slide1.placeholders[1].text = "This is the first slide."
# Slide 2: chart
slide2 = prs.slides.add_slide(prs.slide_layouts[5]) # Blank layout
chart_data = CategoryChartData()
chart_data.categories = ["Q1", "Q2", "Q3"]
chart_data.add_series("Revenue", (100, 200, 300))
slide2.shapes.add_chart(
XL_CHART_TYPE.COLUMN_CLUSTERED,
Inches(1),
Inches(1),
Inches(6),
Inches(4),
chart_data,
)
buf = io.BytesIO()
prs.save(buf)
buf.seek(0)
return buf
def _make_pptx_without_chart() -> io.BytesIO:
"""Create an in-memory pptx with a single text-only slide."""
prs = Presentation()
slide = prs.slides.add_slide(prs.slide_layouts[1])
slide.shapes.title.text = "Hello World"
slide.placeholders[1].text = "Some content here."
buf = io.BytesIO()
prs.save(buf)
buf.seek(0)
return buf
class TestPptxToText:
def test_chart_is_omitted(self) -> None:
# Precondition
pptx_file = _make_pptx_with_chart()
# Under test
result = pptx_to_text(pptx_file)
# Postcondition
assert "Introduction" in result
assert "first slide" in result
assert "[chart omitted]" in result
# The actual chart data should NOT appear in the output.
assert "Revenue" not in result
assert "Q1" not in result
def test_text_only_pptx(self) -> None:
# Precondition
pptx_file = _make_pptx_without_chart()
# Under test
result = pptx_to_text(pptx_file)
# Postcondition
assert "Hello World" in result
assert "Some content" in result
assert "[chart omitted]" not in result

View File

@@ -11,7 +11,6 @@ from litellm.types.utils import ChatCompletionDeltaToolCall
from litellm.types.utils import Delta
from litellm.types.utils import Function as LiteLLMFunction
import onyx.llm.models
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLMUserIdentity
@@ -1480,147 +1479,6 @@ def test_bifrost_normalizes_api_base_in_model_kwargs() -> None:
assert llm._model_kwargs["api_base"] == "https://bifrost.example.com/v1"
def test_prompt_contains_tool_call_history_true() -> None:
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
messages: LanguageModelInput = [
UserMessage(content="What's the weather?"),
AssistantMessage(
content=None,
tool_calls=[
ToolCall(
id="tc_1",
function=FunctionCall(name="get_weather", arguments="{}"),
)
],
),
]
assert _prompt_contains_tool_call_history(messages) is True
def test_prompt_contains_tool_call_history_false_no_tools() -> None:
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
messages: LanguageModelInput = [
UserMessage(content="Hello"),
AssistantMessage(content="Hi there!"),
]
assert _prompt_contains_tool_call_history(messages) is False
def test_prompt_contains_tool_call_history_false_user_only() -> None:
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
messages: LanguageModelInput = [UserMessage(content="Hello")]
assert _prompt_contains_tool_call_history(messages) is False
def test_bedrock_claude_drops_thinking_when_thinking_blocks_missing() -> None:
"""When thinking is enabled but assistant messages with tool_calls lack
thinking_blocks, the thinking param must be dropped to avoid the Bedrock
BadRequestError about missing thinking blocks."""
llm = LitellmLLM(
api_key=None,
timeout=30,
model_provider=LlmProviderNames.BEDROCK,
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
max_input_tokens=200000,
)
messages: LanguageModelInput = [
UserMessage(content="What's the weather?"),
AssistantMessage(
content=None,
tool_calls=[
ToolCall(
id="tc_1",
function=FunctionCall(
name="get_weather",
arguments='{"city": "Paris"}',
),
)
],
),
onyx.llm.models.ToolMessage(
content="22°C sunny",
tool_call_id="tc_1",
),
]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
},
},
}
]
with (
patch("litellm.completion") as mock_completion,
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
):
mock_completion.return_value = []
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
kwargs = mock_completion.call_args.kwargs
assert "thinking" not in kwargs, (
"thinking param should be dropped when thinking_blocks are missing "
"from assistant messages with tool_calls"
)
def test_bedrock_claude_keeps_thinking_when_no_tool_history() -> None:
"""When thinking is enabled and there are no historical assistant messages
with tool_calls, the thinking param should be preserved."""
llm = LitellmLLM(
api_key=None,
timeout=30,
model_provider=LlmProviderNames.BEDROCK,
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
max_input_tokens=200000,
)
messages: LanguageModelInput = [
UserMessage(content="What's the weather?"),
]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
},
},
}
]
with (
patch("litellm.completion") as mock_completion,
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
):
mock_completion.return_value = []
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
kwargs = mock_completion.call_args.kwargs
assert "thinking" in kwargs, (
"thinking param should be preserved when no assistant messages "
"with tool_calls exist in history"
)
assert kwargs["thinking"]["type"] == "enabled"
def test_bifrost_claude_includes_allowed_openai_params() -> None:
llm = LitellmLLM(
api_key="test_key",

View File

@@ -3,7 +3,7 @@
Covers:
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
- _validate_endpoint: httpx exception → HookValidateStatus mapping
ConnectTimeout → timeout (any timeout directs user to increase timeout_seconds)
ConnectTimeout → cannot_connect (TCP handshake never completed)
ConnectError → cannot_connect (DNS / TLS failure)
ReadTimeout et al. → timeout (TCP connected, server slow)
Any other exc → cannot_connect
@@ -61,7 +61,7 @@ class TestCheckSsrfSafety:
def test_non_https_scheme_rejected(self, url: str) -> None:
with pytest.raises(OnyxError) as exc_info:
self._call(url)
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert "https" in (exc_info.value.detail or "").lower()
# --- private IP blocklist ---
@@ -87,7 +87,7 @@ class TestCheckSsrfSafety:
):
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
self._call("https://internal.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert ip in (exc_info.value.detail or "")
def test_public_ip_is_allowed(self) -> None:
@@ -106,7 +106,7 @@ class TestCheckSsrfSafety:
pytest.raises(OnyxError) as exc_info,
):
self._call("https://no-such-host.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
# ---------------------------------------------------------------------------
@@ -158,11 +158,13 @@ class TestValidateEndpoint:
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
def test_connect_timeout_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectTimeout("timed out")
)
assert self._call().status == HookValidateStatus.timeout
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize(

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
@@ -10,9 +9,7 @@ from uuid import uuid4
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.server.scim.api import _check_seat_availability
from ee.onyx.server.scim.api import _scim_name_to_str
from ee.onyx.server.scim.api import _seat_lock_id_for_tenant
from ee.onyx.server.scim.api import create_user
from ee.onyx.server.scim.api import delete_user
from ee.onyx.server.scim.api import get_user
@@ -744,80 +741,3 @@ class TestEmailCasePreservation:
resource = parse_scim_user(result)
assert resource.userName == "Alice@Example.COM"
assert resource.emails[0].value == "Alice@Example.COM"
class TestSeatLock:
"""Tests for the advisory lock in _check_seat_availability."""
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_abc")
def test_acquires_advisory_lock_before_checking(
self,
_mock_tenant: MagicMock,
mock_dal: MagicMock,
) -> None:
"""The advisory lock must be acquired before the seat check runs."""
call_order: list[str] = []
def track_execute(stmt: Any, _params: Any = None) -> None:
if "pg_advisory_xact_lock" in str(stmt):
call_order.append("lock")
mock_dal.session.execute.side_effect = track_execute
with patch(
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop"
) as mock_fetch:
mock_result = MagicMock()
mock_result.available = True
mock_fn = MagicMock(return_value=mock_result)
mock_fetch.return_value = mock_fn
def track_check(*_args: Any, **_kwargs: Any) -> Any:
call_order.append("check")
return mock_result
mock_fn.side_effect = track_check
_check_seat_availability(mock_dal)
assert call_order == ["lock", "check"]
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_xyz")
def test_lock_uses_tenant_scoped_key(
self,
_mock_tenant: MagicMock,
mock_dal: MagicMock,
) -> None:
"""The lock id must be derived from the tenant via _seat_lock_id_for_tenant."""
mock_result = MagicMock()
mock_result.available = True
mock_check = MagicMock(return_value=mock_result)
with patch(
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
return_value=mock_check,
):
_check_seat_availability(mock_dal)
mock_dal.session.execute.assert_called_once()
params = mock_dal.session.execute.call_args[0][1]
assert params["lock_id"] == _seat_lock_id_for_tenant("tenant_xyz")
def test_seat_lock_id_is_stable_and_tenant_scoped(self) -> None:
"""Lock id must be deterministic and differ across tenants."""
assert _seat_lock_id_for_tenant("t1") == _seat_lock_id_for_tenant("t1")
assert _seat_lock_id_for_tenant("t1") != _seat_lock_id_for_tenant("t2")
def test_no_lock_when_ee_absent(
self,
mock_dal: MagicMock,
) -> None:
"""No advisory lock should be acquired when the EE check is absent."""
with patch(
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
return_value=None,
):
result = _check_seat_availability(mock_dal)
assert result is None
mock_dal.session.execute.assert_not_called()

View File

@@ -1,6 +1,5 @@
"""Tests for indexing pipeline Prometheus collectors."""
from collections.abc import Iterator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -14,16 +13,6 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
@pytest.fixture(autouse=True)
def _mock_broker_client() -> Iterator[None]:
"""Patch celery_get_broker_client for all collector tests."""
with patch(
"onyx.background.celery.celery_redis.celery_get_broker_client",
return_value=MagicMock(),
):
yield
class TestQueueDepthCollector:
def test_returns_empty_when_factory_not_set(self) -> None:
collector = QueueDepthCollector()
@@ -35,7 +24,8 @@ class TestQueueDepthCollector:
def test_collects_queue_depths(self) -> None:
collector = QueueDepthCollector(cache_ttl=0)
collector.set_celery_app(MagicMock())
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
with (
patch(
@@ -70,8 +60,8 @@ class TestQueueDepthCollector:
def test_handles_redis_error_gracefully(self) -> None:
collector = QueueDepthCollector(cache_ttl=0)
MagicMock()
collector.set_celery_app(MagicMock())
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
with patch(
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
@@ -84,8 +74,8 @@ class TestQueueDepthCollector:
def test_caching_returns_stale_within_ttl(self) -> None:
collector = QueueDepthCollector(cache_ttl=60)
MagicMock()
collector.set_celery_app(MagicMock())
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
with (
patch(
@@ -108,10 +98,31 @@ class TestQueueDepthCollector:
assert first is second # Same object, from cache
def test_factory_called_each_scrape(self) -> None:
"""Verify the Redis factory is called on each fresh collect, not cached."""
collector = QueueDepthCollector(cache_ttl=0)
factory = MagicMock(return_value=MagicMock())
collector.set_redis_factory(factory)
with (
patch(
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
return_value=0,
),
patch(
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
return_value=set(),
),
):
collector.collect()
collector.collect()
assert factory.call_count == 2
def test_error_returns_stale_cache(self) -> None:
collector = QueueDepthCollector(cache_ttl=0)
MagicMock()
collector.set_celery_app(MagicMock())
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
# First call succeeds
with (

View File

@@ -1,22 +1,96 @@
"""Tests for indexing pipeline setup."""
"""Tests for indexing pipeline setup (Redis factory caching)."""
from unittest.mock import MagicMock
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
class TestCollectorCeleryAppSetup:
def test_queue_depth_collector_uses_celery_app(self) -> None:
"""QueueDepthCollector.set_celery_app stores the app for broker access."""
collector = QueueDepthCollector()
mock_app = MagicMock()
collector.set_celery_app(mock_app)
assert collector._celery_app is mock_app
def _make_mock_app(client: MagicMock) -> MagicMock:
"""Create a mock Celery app whose broker_connection().channel().client
returns the given client."""
mock_app = MagicMock()
mock_conn = MagicMock()
mock_conn.channel.return_value.client = client
def test_redis_health_collector_uses_celery_app(self) -> None:
"""RedisHealthCollector.set_celery_app stores the app for broker access."""
collector = RedisHealthCollector()
mock_app = MagicMock()
collector.set_celery_app(mock_app)
assert collector._celery_app is mock_app
mock_app.broker_connection.return_value = mock_conn
return mock_app
class TestMakeBrokerRedisFactory:
def test_caches_redis_client_across_calls(self) -> None:
"""Factory should reuse the same client on subsequent calls."""
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_app = _make_mock_app(mock_client)
factory = _make_broker_redis_factory(mock_app)
client1 = factory()
client2 = factory()
assert client1 is client2
# broker_connection should only be called once
assert mock_app.broker_connection.call_count == 1
def test_reconnects_when_ping_fails(self) -> None:
"""Factory should create a new client if ping fails (stale connection)."""
mock_client_stale = MagicMock()
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
mock_client_fresh = MagicMock()
mock_client_fresh.ping.return_value = True
mock_app = _make_mock_app(mock_client_stale)
factory = _make_broker_redis_factory(mock_app)
# First call — creates and caches
client1 = factory()
assert client1 is mock_client_stale
assert mock_app.broker_connection.call_count == 1
# Switch to fresh client for next connection
mock_conn_fresh = MagicMock()
mock_conn_fresh.channel.return_value.client = mock_client_fresh
mock_app.broker_connection.return_value = mock_conn_fresh
# Second call — ping fails on stale, reconnects
client2 = factory()
assert client2 is mock_client_fresh
assert mock_app.broker_connection.call_count == 2
def test_reconnect_closes_stale_client(self) -> None:
"""When ping fails, the old client should be closed before reconnecting."""
mock_client_stale = MagicMock()
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
mock_client_fresh = MagicMock()
mock_client_fresh.ping.return_value = True
mock_app = _make_mock_app(mock_client_stale)
factory = _make_broker_redis_factory(mock_app)
# First call — creates and caches
factory()
# Switch to fresh client
mock_conn_fresh = MagicMock()
mock_conn_fresh.channel.return_value.client = mock_client_fresh
mock_app.broker_connection.return_value = mock_conn_fresh
# Second call — ping fails, should close stale client
factory()
mock_client_stale.close.assert_called_once()
def test_first_call_creates_connection(self) -> None:
"""First call should always create a new connection."""
mock_client = MagicMock()
mock_app = _make_mock_app(mock_client)
factory = _make_broker_redis_factory(mock_app)
client = factory()
assert client is mock_client
mock_app.broker_connection.assert_called_once()

View File

@@ -70,10 +70,6 @@ backend = [
"lazy_imports==1.0.1",
"lxml==5.3.0",
"Mako==1.2.4",
# NOTE: Do not update without understanding the patching behavior in
# get_markitdown_converter in
# backend/onyx/file_processing/extract_file_text.py and what impacts
# updating might have on this behavior.
"markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2",
"mcp[cli]==1.26.0",
"msal==1.34.0",

View File

@@ -1,22 +0,0 @@
import { cn } from "@opal/utils";
import type { IconProps } from "@opal/types";
const SvgBifrost = ({ size, className, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 37 46"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className={cn(className, "text-[#33C19E] dark:text-white")}
{...props}
>
<title>Bifrost</title>
<path
d="M27.6219 46H0V36.8H27.6219V46ZM36.8268 36.8H27.6219V27.6H36.8268V36.8ZM18.4146 27.6H9.2073V18.4H18.4146V27.6ZM36.8268 18.4H27.6219V9.2H36.8268V18.4ZM27.6219 9.2H0V0H27.6219V9.2Z"
fill="currentColor"
/>
</svg>
);
export default SvgBifrost;

View File

@@ -24,7 +24,6 @@ export { default as SvgAzure } from "@opal/icons/azure";
export { default as SvgBarChart } from "@opal/icons/bar-chart";
export { default as SvgBarChartSmall } from "@opal/icons/bar-chart-small";
export { default as SvgBell } from "@opal/icons/bell";
export { default as SvgBifrost } from "@opal/icons/bifrost";
export { default as SvgBlocks } from "@opal/icons/blocks";
export { default as SvgBookOpen } from "@opal/icons/book-open";
export { default as SvgBookmark } from "@opal/icons/bookmark";

View File

@@ -31,6 +31,7 @@ import { Credential } from "@/lib/connectors/credentials";
import { SettingsContext } from "@/providers/SettingsProvider";
import SourceTile from "@/components/SourceTile";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { ADMIN_ROUTES } from "@/lib/admin-routes";

View File

@@ -4,7 +4,7 @@ import { createApiKey, updateApiKey } from "./lib";
import Modal from "@/refresh-components/Modal";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { FormikField } from "@/refresh-components/form/FormikField";
@@ -79,7 +79,7 @@ export default function OnyxApiKeyForm({
{({ isSubmitting }) => (
<Form className="w-full overflow-visible">
<Modal.Body>
<Text as="p">
<Text as="p" color="text-05">
Choose a memorable name for your API key. This is optional and
can be added or changed later!
</Text>

View File

@@ -29,12 +29,10 @@ import {
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { Button } from "@opal/components";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
import Message from "@/refresh-components/messages/Message";
import { SvgEdit, SvgInfo, SvgKey, SvgRefreshCw } from "@opal/icons";
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
import { useBillingInformation } from "@/hooks/useBillingInformation";
import { BillingStatus, hasActiveSubscription } from "@/lib/billing/interfaces";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
const route = ADMIN_ROUTES.API_KEYS;
@@ -47,11 +45,6 @@ function Main() {
} = useSWR<APIKey[]>("/api/admin/api-key", errorHandlingFetcher);
const canCreateKeys = useCloudSubscription();
const { data: billingData } = useBillingInformation();
const isTrialing =
billingData !== undefined &&
hasActiveSubscription(billingData) &&
billingData.status === BillingStatus.TRIALING;
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
const [keyIsGenerating, setKeyIsGenerating] = useState(false);
@@ -83,16 +76,6 @@ function Main() {
const introSection = (
<div className="flex flex-col items-start gap-4">
{isTrialing && (
<Message
static
warning
close={false}
className="w-full"
text="Upgrade to a paid plan to create API keys."
description="Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
/>
)}
<Text as="p">
API Keys allow you to access Onyx APIs programmatically.
{canCreateKeys
@@ -103,9 +86,23 @@ function Main() {
<CreateButton onClick={() => setShowCreateUpdateForm(true)}>
Create API Key
</CreateButton>
) : isTrialing ? (
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
) : null}
) : (
<div className="flex flex-col gap-2 rounded-lg bg-background-tint-02 p-4">
<div className="flex items-center gap-1.5">
<Text as="p" text04>
Upgrade to a paid plan to create API keys.
</Text>
<Button
variant="none"
prominence="tertiary"
size="2xs"
icon={SvgInfo}
tooltip="API keys enable programmatic access to Onyx for service accounts and integrations. Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
/>
</div>
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
</div>
)}
</div>
);

View File

@@ -9,6 +9,7 @@ import Card from "@/refresh-components/cards/Card";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
import { Disabled } from "@opal/core";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Message from "@/refresh-components/messages/Message";
import InfoBlock from "@/refresh-components/messages/InfoBlock";

View File

@@ -5,6 +5,7 @@ import { Section } from "@/layouts/general-layouts";
import * as InputLayouts from "@/layouts/input-layouts";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Card from "@/refresh-components/cards/Card";
import Separator from "@/refresh-components/Separator";

View File

@@ -4,6 +4,7 @@ import { useState } from "react";
import Card from "@/refresh-components/cards/Card";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import InputFile from "@/refresh-components/inputs/InputFile";
import { Section } from "@/layouts/general-layouts";

View File

@@ -22,6 +22,7 @@ import type { IconProps } from "@opal/types";
import Card from "@/refresh-components/cards/Card";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { Section } from "@/layouts/general-layouts";

View File

@@ -1,387 +0,0 @@
/**
* Tests for BillingPage handleBillingReturn retry logic.
*
* The retry logic retries claimLicense up to 3 times with 2s backoff
* when returning from a Stripe checkout session. This prevents the user
* from getting stranded when the Stripe webhook fires concurrently with
* the browser redirect and the license isn't ready yet.
*/
import React from "react";
import { render, screen, waitFor } from "@tests/setup/test-utils";
import { act } from "@testing-library/react";
// ---- Stable mock objects (must be named with mock* prefix for jest hoisting) ----
// useRouter and useSearchParams must return the SAME reference each call, otherwise
// React's useEffect sees them as changed and re-runs the effect on every render.
const mockRouter = {
replace: jest.fn() as jest.Mock,
refresh: jest.fn() as jest.Mock,
};
const mockSearchParams = {
get: jest.fn() as jest.Mock,
};
const mockClaimLicense = jest.fn() as jest.Mock;
const mockRefreshBilling = jest.fn() as jest.Mock;
const mockRefreshLicense = jest.fn() as jest.Mock;
// ---- Mocks ----
jest.mock("next/navigation", () => ({
useRouter: () => mockRouter,
useSearchParams: () => mockSearchParams,
}));
jest.mock("@/layouts/settings-layouts", () => ({
Root: ({ children }: { children: React.ReactNode }) => (
<div data-testid="settings-root">{children}</div>
),
Header: () => <div data-testid="settings-header" />,
Body: ({ children }: { children: React.ReactNode }) => (
<div data-testid="settings-body">{children}</div>
),
}));
jest.mock("@/layouts/general-layouts", () => ({
Section: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
jest.mock("@opal/icons", () => ({
SvgArrowUpCircle: () => <svg />,
SvgWallet: () => <svg />,
}));
jest.mock("./PlansView", () => ({
__esModule: true,
default: () => <div data-testid="plans-view" />,
}));
jest.mock("./CheckoutView", () => ({
__esModule: true,
default: () => <div data-testid="checkout-view" />,
}));
jest.mock("./BillingDetailsView", () => ({
__esModule: true,
default: () => <div data-testid="billing-details-view" />,
}));
jest.mock("./LicenseActivationCard", () => ({
__esModule: true,
default: () => <div data-testid="license-activation-card" />,
}));
jest.mock("@/refresh-components/messages/Message", () => ({
__esModule: true,
default: ({
text,
description,
onClose,
}: {
text: string;
description?: string;
onClose?: () => void;
}) => (
<div data-testid="activating-banner">
<span data-testid="activating-banner-text">{text}</span>
{description && (
<span data-testid="activating-banner-description">{description}</span>
)}
{onClose && (
<button data-testid="activating-banner-close" onClick={onClose}>
Close
</button>
)}
</div>
),
}));
jest.mock("@/lib/billing", () => ({
useBillingInformation: jest.fn(),
useLicense: jest.fn(),
hasActiveSubscription: jest.fn().mockReturnValue(false),
claimLicense: (...args: unknown[]) => mockClaimLicense(...args),
}));
jest.mock("@/lib/constants", () => ({
NEXT_PUBLIC_CLOUD_ENABLED: false,
}));
// ---- Import after mocks ----
import BillingPage from "./page";
import { useBillingInformation, useLicense } from "@/lib/billing";
// ---- Test helpers ----
function setupHooks() {
(useBillingInformation as jest.Mock).mockReturnValue({
data: null,
isLoading: false,
error: null,
refresh: mockRefreshBilling,
});
(useLicense as jest.Mock).mockReturnValue({
data: null,
isLoading: false,
refresh: mockRefreshLicense,
});
}
// ---- Tests ----
describe("BillingPage — handleBillingReturn retry logic", () => {
beforeEach(() => {
jest.clearAllMocks();
jest.useFakeTimers();
setupHooks();
// Default: no billing-return params
mockSearchParams.get.mockReturnValue(null);
// Clear any activating state from prior tests
sessionStorage.clear();
});
afterEach(() => {
jest.useRealTimers();
jest.restoreAllMocks();
});
test("calls claimLicense once and refreshes on first-attempt success", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_test_123" : null
);
mockClaimLicense.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
expect(mockClaimLicense).toHaveBeenCalledWith("cs_test_123");
});
expect(mockRouter.refresh).toHaveBeenCalled();
expect(mockRefreshBilling).toHaveBeenCalled();
// URL cleaned up after checkout return
expect(mockRouter.replace).toHaveBeenCalledWith("/admin/billing", {
scroll: false,
});
});
test("retries after first failure and succeeds on second attempt", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_retry_test" : null
);
mockClaimLicense
.mockRejectedValueOnce(new Error("License not ready yet"))
.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(2);
});
// On eventual success, router and billing should be refreshed
expect(mockRouter.refresh).toHaveBeenCalled();
expect(mockRefreshBilling).toHaveBeenCalled();
});
test("retries all 3 times then navigates to details even on total failure", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_all_fail" : null
);
// All 3 attempts fail
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(3);
});
// User stays on plans view with the activating banner
await waitFor(() => {
expect(screen.getByTestId("plans-view")).toBeInTheDocument();
});
// refreshBilling still fires so billing state is up to date
expect(mockRefreshBilling).toHaveBeenCalled();
// Failure is logged
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining("Failed to sync license after billing return"),
expect.any(Error)
);
consoleSpy.mockRestore();
});
test("calls claimLicense without session_id on portal_return", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "portal_return" ? "true" : null
);
mockClaimLicense.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
// No session_id for portal returns — called with undefined
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
});
expect(mockRefreshBilling).toHaveBeenCalled();
});
test("does not call claimLicense when no billing-return params present", async () => {
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
expect(mockClaimLicense).not.toHaveBeenCalled();
});
test("shows activating banner and sets sessionStorage on 3x retry failure", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_all_fail" : null
);
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
});
expect(screen.getByTestId("activating-banner-text")).toHaveTextContent(
"Your license is still activating"
);
expect(
sessionStorage.getItem("billing_license_activating_until")
).not.toBeNull();
consoleSpy.mockRestore();
});
test("banner not rendered when no activating state", async () => {
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
});
test("banner shown on mount when sessionStorage key is set and not expired", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() + 120_000)
);
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
// Flush React effects — banner is visible from lazy state init, no timer advancement needed
await act(async () => {});
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
});
test("banner not shown on mount when sessionStorage key is expired", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() - 1000)
);
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
expect(
sessionStorage.getItem("billing_license_activating_until")
).toBeNull();
});
test("poll calls claimLicense after 15s and clears banner on success", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() + 120_000)
);
mockSearchParams.get.mockReturnValue(null);
// Poll attempt succeeds
mockClaimLicense.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
// Flush effects — banner visible from lazy state init
await act(async () => {});
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
// Advance past one poll interval (15s)
await act(async () => {
await jest.advanceTimersByTimeAsync(15_000);
});
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
expect(
sessionStorage.getItem("billing_license_activating_until")
).toBeNull();
expect(mockRefreshBilling).toHaveBeenCalled();
expect(mockRefreshLicense).toHaveBeenCalled();
expect(mockRouter.refresh).toHaveBeenCalled();
});
test("close button removes banner and clears sessionStorage", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() + 120_000)
);
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
// Flush effects — banner visible from lazy state init
await act(async () => {});
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
const closeButton = screen.getByTestId("activating-banner-close");
await act(async () => {
closeButton.click();
});
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
expect(
sessionStorage.getItem("billing_license_activating_until")
).toBeNull();
});
});

View File

@@ -5,6 +5,7 @@ import { useSearchParams, useRouter } from "next/navigation";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { Section } from "@/layouts/general-layouts";
import Button from "@/refresh-components/buttons/Button";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { SvgArrowUpCircle, SvgWallet } from "@opal/icons";
import type { IconProps } from "@opal/types";
@@ -17,7 +18,6 @@ import {
} from "@/lib/billing";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { useUser } from "@/providers/UserProvider";
import Message from "@/refresh-components/messages/Message";
import PlansView from "./PlansView";
import CheckoutView from "./CheckoutView";
@@ -25,9 +25,6 @@ import BillingDetailsView from "./BillingDetailsView";
import LicenseActivationCard from "./LicenseActivationCard";
import "./billing.css";
// sessionStorage key: value is a unix-ms expiry timestamp
const BILLING_ACTIVATING_KEY = "billing_license_activating_until";
// ----------------------------------------------------------------------------
// Types
// ----------------------------------------------------------------------------
@@ -109,7 +106,6 @@ export default function BillingPage() {
const [transitionType, setTransitionType] = useState<
"expand" | "collapse" | "fade"
>("fade");
const [isActivating, setIsActivating] = useState<boolean>(false);
const {
data: billingData,
@@ -160,17 +156,6 @@ export default function BillingPage() {
view,
]);
// Read activating state from sessionStorage after mount (avoids SSR hydration mismatch)
useEffect(() => {
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
if (!raw) return;
if (Number(raw) > Date.now()) {
setIsActivating(true);
} else {
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
}
}, []);
// Show license activation card when there's a Stripe error
useEffect(() => {
if (hasStripeError && !showLicenseActivationInput) {
@@ -188,96 +173,24 @@ export default function BillingPage() {
router.replace("/admin/billing", { scroll: false });
let cancelled = false;
const handleBillingReturn = async () => {
if (!NEXT_PUBLIC_CLOUD_ENABLED) {
// Retry up to 3 times with 2s backoff. The license may not be available
// immediately if the Stripe webhook hasn't finished processing yet
// (redirect and webhook fire nearly simultaneously).
let lastError: Error | null = null;
for (let attempt = 0; attempt < 3; attempt++) {
if (cancelled) return;
try {
// After checkout, exchange session_id for license; after portal, re-sync license
await claimLicense(sessionId ?? undefined);
if (cancelled) return;
refreshLicense();
// Refresh the page to update settings (including ee_features_enabled)
router.refresh();
// Navigate to billing details now that the license is active
changeView("details");
lastError = null;
break;
} catch (err) {
lastError = err instanceof Error ? err : new Error("Unknown error");
if (attempt < 2) {
await new Promise((resolve) => setTimeout(resolve, 2000));
}
}
}
if (cancelled) return;
if (lastError) {
console.error(
"Failed to sync license after billing return:",
lastError
);
// Show an activating banner on the plans view and keep retrying in the background.
sessionStorage.setItem(
BILLING_ACTIVATING_KEY,
String(Date.now() + 120_000)
);
setIsActivating(true);
changeView("plans");
try {
// After checkout, exchange session_id for license; after portal, re-sync license
await claimLicense(sessionId ?? undefined);
refreshLicense();
// Refresh the page to update settings (including ee_features_enabled)
router.refresh();
// Navigate to billing details now that the license is active
changeView("details");
} catch (error) {
console.error("Failed to sync license after billing return:", error);
}
}
if (!cancelled) refreshBilling();
refreshBilling();
};
handleBillingReturn();
return () => {
cancelled = true;
};
// changeView intentionally omitted: it only calls stable state setters and the
// effect runs at most once (when session_id/portal_return params are present).
}, [searchParams, router, refreshBilling, refreshLicense]); // eslint-disable-line react-hooks/exhaustive-deps
// Poll every 15s while activating, up to 2 minutes, to detect when the license arrives.
useEffect(() => {
if (!isActivating) return;
let requestInFlight = false;
const intervalId = setInterval(async () => {
if (requestInFlight) return;
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
if (!raw || Number(raw) <= Date.now()) {
// Expired — stop immediately without waiting for React cleanup
clearInterval(intervalId);
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
setIsActivating(false);
return;
}
requestInFlight = true;
try {
await claimLicense(undefined);
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
setIsActivating(false);
refreshLicense();
refreshBilling();
router.refresh();
changeView("details");
} catch (err) {
// License not ready yet — keep polling. Log so unexpected failures
// (network errors, 500s) are distinguishable from expected 404s.
console.debug("License activation poll: will retry", err);
} finally {
requestInFlight = false;
}
}, 15_000);
return () => clearInterval(intervalId);
}, [isActivating]); // eslint-disable-line react-hooks/exhaustive-deps
}, [searchParams, router, refreshBilling, refreshLicense]);
const handleRefresh = async () => {
await Promise.all([
@@ -474,22 +387,6 @@ export default function BillingPage() {
/>
<SettingsLayouts.Body>
<div className="flex flex-col items-center gap-6">
{isActivating && (
<Message
static
warning
large
text="Your license is still activating"
description="Your license is being processed. You'll be taken to billing details automatically once confirmed."
icon
close
onClose={() => {
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
setIsActivating(false);
}}
className="w-full"
/>
)}
{renderContent()}
{renderFooter()}
</div>

View File

@@ -7,6 +7,7 @@ import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import useSWR from "swr";
import { ThreeDotsLoader } from "@/components/Loading";
import * as SettingsLayouts from "@/layouts/settings-layouts";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { cn } from "@/lib/utils";
import { SvgLock } from "@opal/icons";

View File

@@ -1,11 +1,11 @@
"use client";
import { useState, useMemo, useEffect } from "react";
import { useState, useMemo } from "react";
import useSWR from "swr";
import { Text } from "@opal/components";
import { Select } from "@/refresh-components/cards";
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
import { toast } from "@/hooks/useToast";
import { Section } from "@/layouts/general-layouts";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
import {
@@ -17,17 +17,9 @@ import {
ImageGenerationConfigView,
setDefaultImageGenerationConfig,
unsetDefaultImageGenerationConfig,
deleteImageGenerationConfig,
} from "@/lib/configuration/imageConfigurationService";
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
import Message from "@/refresh-components/messages/Message";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { Button, Text } from "@opal/components";
import { SvgSlash, SvgUnplug } from "@opal/icons";
import { markdown } from "@opal/utils";
const NO_DEFAULT_VALUE = "__none__";
export default function ImageGenerationContent() {
const {
@@ -55,11 +47,6 @@ export default function ImageGenerationContent() {
);
const [editConfig, setEditConfig] =
useState<ImageGenerationConfigView | null>(null);
const [disconnectProvider, setDisconnectProvider] =
useState<ImageProvider | null>(null);
const [replacementProviderId, setReplacementProviderId] = useState<
string | null
>(null);
const connectedProviderIds = useMemo(() => {
return new Set(configs.map((c) => c.image_provider_id));
@@ -128,29 +115,6 @@ export default function ImageGenerationContent() {
modal.toggle(true);
};
const handleDisconnect = async () => {
if (!disconnectProvider) return;
try {
// If a replacement was selected (not "No Default"), activate it first
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
await setDefaultImageGenerationConfig(replacementProviderId);
}
await deleteImageGenerationConfig(disconnectProvider.image_provider_id);
toast.success(`${disconnectProvider.title} disconnected`);
refetchConfigs();
refetchProviders();
} catch (error) {
console.error("Failed to disconnect image generation provider:", error);
toast.error(
error instanceof Error ? error.message : "Failed to disconnect"
);
} finally {
setDisconnectProvider(null);
setReplacementProviderId(null);
}
};
const handleModalSuccess = () => {
toast.success("Provider configured successfully");
setEditConfig(null);
@@ -166,44 +130,12 @@ export default function ImageGenerationContent() {
);
}
// Compute replacement options when disconnecting an active provider
const isDisconnectingDefault =
disconnectProvider &&
defaultConfig?.image_provider_id === disconnectProvider.image_provider_id;
// Group connected replacement models by provider (excluding the model being disconnected)
const replacementGroups = useMemo(() => {
if (!disconnectProvider) return [];
return IMAGE_PROVIDER_GROUPS.map((group) => ({
...group,
providers: group.providers.filter(
(p) =>
p.image_provider_id !== disconnectProvider.image_provider_id &&
connectedProviderIds.has(p.image_provider_id)
),
})).filter((g) => g.providers.length > 0);
}, [disconnectProvider, connectedProviderIds]);
const needsReplacement = !!isDisconnectingDefault;
const hasReplacements = replacementGroups.length > 0;
// Auto-select first replacement when modal opens
useEffect(() => {
if (needsReplacement && !replacementProviderId && hasReplacements) {
const firstGroup = replacementGroups[0];
const firstModel = firstGroup?.providers[0];
if (firstModel) setReplacementProviderId(firstModel.image_provider_id);
}
}, [disconnectProvider]); // eslint-disable-line react-hooks/exhaustive-deps
return (
<>
<div className="flex flex-col gap-6">
{/* Section Header */}
<div className="flex flex-col gap-0.5">
<Text font="main-content-emphasis" color="text-05">
Image Generation Model
</Text>
<Text font="main-content-emphasis">Image Generation Model</Text>
<Text font="secondary-body" color="text-03">
Select a model to generate images in chat.
</Text>
@@ -241,11 +173,6 @@ export default function ImageGenerationContent() {
onSelect={() => handleSelect(provider)}
onDeselect={() => handleDeselect(provider)}
onEdit={() => handleEdit(provider)}
onDisconnect={
getStatus(provider) !== "disconnected"
? () => setDisconnectProvider(provider)
: undefined
}
/>
))}
</div>
@@ -253,108 +180,6 @@ export default function ImageGenerationContent() {
))}
</div>
{disconnectProvider && (
<ConfirmationModalLayout
icon={SvgUnplug}
title={`Disconnect ${disconnectProvider.title}`}
description="This will remove the stored credentials for this provider."
onClose={() => {
setDisconnectProvider(null);
setReplacementProviderId(null);
}}
submit={
<Button
variant="danger"
onClick={() => void handleDisconnect()}
disabled={
needsReplacement && hasReplacements && !replacementProviderId
}
>
Disconnect
</Button>
}
>
{needsReplacement ? (
hasReplacements ? (
<Section alignItems="start">
<Text as="p" color="text-03">
{markdown(
`**${disconnectProvider.title}** is currently the default image generation model. Session history will be preserved.`
)}
</Text>
<Section alignItems="start" gap={0.25}>
<Text as="p" color="text-04">
Set New Default
</Text>
<InputSelect
value={replacementProviderId ?? undefined}
onValueChange={(v) => setReplacementProviderId(v)}
>
<InputSelect.Trigger placeholder="Select a replacement model" />
<InputSelect.Content>
{replacementGroups.map((group) => (
<InputSelect.Group key={group.name}>
<InputSelect.Label>{group.name}</InputSelect.Label>
{group.providers.map((p) => (
<InputSelect.Item
key={p.image_provider_id}
value={p.image_provider_id}
icon={() => (
<ProviderIcon
provider={p.provider_name}
size={16}
/>
)}
>
{p.title}
</InputSelect.Item>
))}
</InputSelect.Group>
))}
<InputSelect.Separator />
<InputSelect.Item
value={NO_DEFAULT_VALUE}
icon={SvgSlash}
>
<span>
<b>No Default</b>
<span className="text-text-03">
{" "}
(Disable Image Generation)
</span>
</span>
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</Section>
</Section>
) : (
<>
<Text as="p" color="text-03">
{markdown(
`**${disconnectProvider.title}** is currently the default image generation model.`
)}
</Text>
<Text as="p" color="text-03">
Connect another provider to continue using image generation.
</Text>
</>
)
) : (
<>
<Text as="p" color="text-03">
{markdown(
`**${disconnectProvider.title}** models will no longer be used to generate images.`
)}
</Text>
<Text as="p" color="text-03">
Session history will be preserved.
</Text>
</>
)}
</ConfirmationModalLayout>
)}
{activeProvider && (
<modal.Provider>
<ImageGenerationConnectionModal

View File

@@ -8,6 +8,7 @@ import CreateButton from "@/refresh-components/buttons/CreateButton";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import { SvgX } from "@opal/icons";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
function ModelConfigurationRow({

View File

@@ -1 +1,7 @@
export { default } from "@/refresh-pages/admin/LLMProviderConfigurationPage";
"use client";
import LLMConfigurationPage from "@/refresh-pages/admin/LLMConfigurationPage";
export default function Page() {
return <LLMConfigurationPage />;
}

View File

@@ -23,7 +23,6 @@ import {
BedrockModelResponse,
LMStudioModelResponse,
LiteLLMProxyModelResponse,
BifrostModelResponse,
ModelConfiguration,
LLMProviderName,
BedrockFetchParams,
@@ -31,11 +30,8 @@ import {
LMStudioFetchParams,
OpenRouterFetchParams,
LiteLLMProxyFetchParams,
BifrostFetchParams,
OpenAICompatibleFetchParams,
OpenAICompatibleModelResponse,
} from "@/interfaces/llm";
import { SvgAws, SvgBifrost, SvgOpenrouter, SvgPlug } from "@opal/icons";
import { SvgAws, SvgOpenrouter } from "@opal/icons";
// Aggregator providers that host models from multiple vendors
export const AGGREGATOR_PROVIDERS = new Set([
@@ -45,8 +41,6 @@ export const AGGREGATOR_PROVIDERS = new Set([
"ollama_chat",
"lm_studio",
"litellm_proxy",
"bifrost",
"openai_compatible",
"vertex_ai",
]);
@@ -84,8 +78,6 @@ export const getProviderIcon = (
bedrock_converse: SvgAws,
openrouter: SvgOpenrouter,
litellm_proxy: LiteLLMIcon,
bifrost: SvgBifrost,
openai_compatible: SvgPlug,
vertex_ai: GeminiIcon,
};
@@ -271,11 +263,8 @@ export const fetchOpenRouterModels = async (
try {
const errorData = await response.json();
errorMessage = errorData.detail || errorData.message || errorMessage;
} catch (jsonError) {
console.warn(
"Failed to parse OpenRouter model fetch error response",
jsonError
);
} catch {
// ignore JSON parsing errors
}
return { models: [], error: errorMessage };
}
@@ -325,125 +314,6 @@ export const fetchLMStudioModels = async (
signal: params.signal,
});
if (!response.ok) {
let errorMessage = "Failed to fetch models";
try {
const errorData = await response.json();
errorMessage = errorData.detail || errorData.message || errorMessage;
} catch (jsonError) {
console.warn(
"Failed to parse LM Studio model fetch error response",
jsonError
);
}
return { models: [], error: errorMessage };
}
const data: LMStudioModelResponse[] = await response.json();
const models: ModelConfiguration[] = data.map((modelData) => ({
name: modelData.name,
display_name: modelData.display_name,
is_visible: true,
max_input_tokens: modelData.max_input_tokens,
supports_image_input: modelData.supports_image_input,
supports_reasoning: modelData.supports_reasoning,
}));
return { models };
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : "Unknown error";
return { models: [], error: errorMessage };
}
};
/**
* Fetches Bifrost models directly without any form state dependencies.
* Uses snake_case params to match API structure.
*/
export const fetchBifrostModels = async (
params: BifrostFetchParams
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
const apiBase = params.api_base;
if (!apiBase) {
return { models: [], error: "API Base is required" };
}
try {
const response = await fetch("/api/admin/llm/bifrost/available-models", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
api_base: apiBase,
api_key: params.api_key,
provider_name: params.provider_name,
}),
signal: params.signal,
});
if (!response.ok) {
let errorMessage = "Failed to fetch models";
try {
const errorData = await response.json();
errorMessage = errorData.detail || errorData.message || errorMessage;
} catch (jsonError) {
console.warn(
"Failed to parse Bifrost model fetch error response",
jsonError
);
}
return { models: [], error: errorMessage };
}
const data: BifrostModelResponse[] = await response.json();
const models: ModelConfiguration[] = data.map((modelData) => ({
name: modelData.name,
display_name: modelData.display_name,
is_visible: true,
max_input_tokens: modelData.max_input_tokens,
supports_image_input: modelData.supports_image_input,
supports_reasoning: modelData.supports_reasoning,
}));
return { models };
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : "Unknown error";
return { models: [], error: errorMessage };
}
};
/**
* Fetches models from a generic OpenAI-compatible server.
* Uses snake_case params to match API structure.
*/
export const fetchOpenAICompatibleModels = async (
params: OpenAICompatibleFetchParams
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
const apiBase = params.api_base;
if (!apiBase) {
return { models: [], error: "API Base is required" };
}
try {
const response = await fetch(
"/api/admin/llm/openai-compatible/available-models",
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
api_base: apiBase,
api_key: params.api_key,
provider_name: params.provider_name,
}),
signal: params.signal,
}
);
if (!response.ok) {
let errorMessage = "Failed to fetch models";
try {
@@ -455,7 +325,7 @@ export const fetchOpenAICompatibleModels = async (
return { models: [], error: errorMessage };
}
const data: OpenAICompatibleModelResponse[] = await response.json();
const data: LMStudioModelResponse[] = await response.json();
const models: ModelConfiguration[] = data.map((modelData) => ({
name: modelData.name,
display_name: modelData.display_name,
@@ -586,20 +456,6 @@ export const fetchModels = async (
provider_name: formValues.name,
signal,
});
case LLMProviderName.BIFROST:
return fetchBifrostModels({
api_base: formValues.api_base,
api_key: formValues.api_key,
provider_name: formValues.name,
signal,
});
case LLMProviderName.OPENAI_COMPATIBLE:
return fetchOpenAICompatibleModels({
api_base: formValues.api_base,
api_key: formValues.api_key,
provider_name: formValues.name,
signal,
});
default:
return { models: [], error: `Unknown provider: ${providerName}` };
}
@@ -613,8 +469,6 @@ export function canProviderFetchModels(providerName?: string) {
case LLMProviderName.LM_STUDIO:
case LLMProviderName.OPENROUTER:
case LLMProviderName.LITELLM_PROXY:
case LLMProviderName.BIFROST:
case LLMProviderName.OPENAI_COMPATIBLE:
return true;
default:
return false;

View File

@@ -1,25 +1,33 @@
"use client";
import Image from "next/image";
import { useEffect, useMemo, useState, useReducer } from "react";
import { useMemo, useState, useReducer } from "react";
import { InfoIcon } from "@/components/icons/icons";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { Select } from "@/refresh-components/cards";
import { Section } from "@/layouts/general-layouts";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { Content } from "@opal/layouts";
import useSWR from "swr";
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
import { ThreeDotsLoader } from "@/components/Loading";
import { Callout } from "@/components/ui/callout";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
import { Disabled } from "@opal/core";
import { cn } from "@/lib/utils";
import { toast } from "@/hooks/useToast";
import { SvgGlobe, SvgOnyxLogo, SvgSlash, SvgUnplug } from "@opal/icons";
import { Button } from "@opal/components";
import {
SvgArrowExchange,
SvgArrowRightCircle,
SvgCheckSquare,
SvgEdit,
SvgGlobe,
SvgOnyxLogo,
SvgX,
} from "@opal/icons";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import InputSelect from "@/refresh-components/inputs/InputSelect";
const route = ADMIN_ROUTES.WEB_SEARCH;
import {
SEARCH_PROVIDERS_URL,
SEARCH_PROVIDER_DETAILS,
@@ -51,10 +59,6 @@ import {
} from "@/app/admin/configuration/web-search/WebProviderModalReducer";
import { connectProviderFlow } from "@/app/admin/configuration/web-search/connectProviderFlow";
const NO_DEFAULT_VALUE = "__none__";
const route = ADMIN_ROUTES.WEB_SEARCH;
interface WebSearchProviderView {
id: number;
name: string;
@@ -73,151 +77,27 @@ interface WebContentProviderView {
has_api_key: boolean;
}
interface DisconnectTargetState {
id: number;
label: string;
category: "search" | "content";
providerType: string;
interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
isHovered: boolean;
onMouseEnter: () => void;
onMouseLeave: () => void;
children: React.ReactNode;
}
function WebSearchDisconnectModal({
disconnectTarget,
searchProviders,
contentProviders,
replacementProviderId,
onReplacementChange,
onClose,
onDisconnect,
}: {
disconnectTarget: DisconnectTargetState;
searchProviders: WebSearchProviderView[];
contentProviders: WebContentProviderView[];
replacementProviderId: string | null;
onReplacementChange: (id: string | null) => void;
onClose: () => void;
onDisconnect: () => void;
}) {
const isSearch = disconnectTarget.category === "search";
// Determine if the target is currently the active/selected provider
const isActive = isSearch
? searchProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
false
: contentProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
false;
// Find other configured providers as replacements
const replacementOptions = isSearch
? searchProviders.filter(
(p) => p.id !== disconnectTarget.id && p.id > 0 && p.has_api_key
)
: contentProviders.filter(
(p) =>
p.id !== disconnectTarget.id &&
p.provider_type !== "onyx_web_crawler" &&
p.id > 0 &&
p.has_api_key
);
const needsReplacement = isActive;
const hasReplacements = replacementOptions.length > 0;
const getLabel = (p: { name: string; provider_type: string }) => {
if (isSearch) {
const details =
SEARCH_PROVIDER_DETAILS[p.provider_type as WebSearchProviderType];
return details?.label ?? p.name ?? p.provider_type;
}
const details = CONTENT_PROVIDER_DETAILS[p.provider_type];
return details?.label ?? p.name ?? p.provider_type;
};
const categoryLabel = isSearch ? "search engine" : "web crawler";
const featureLabel = isSearch ? "web search" : "web crawling";
const disableLabel = isSearch ? "Disable Web Search" : "Disable Web Crawling";
// Auto-select first replacement when modal opens
useEffect(() => {
if (needsReplacement && hasReplacements && !replacementProviderId) {
const first = replacementOptions[0];
if (first) onReplacementChange(String(first.id));
}
}, []); // eslint-disable-line react-hooks/exhaustive-deps
function HoverIconButton({
isHovered,
onMouseEnter,
onMouseLeave,
children,
...buttonProps
}: HoverIconButtonProps) {
return (
<ConfirmationModalLayout
icon={SvgUnplug}
title={`Disconnect ${disconnectTarget.label}`}
description="This will remove the stored credentials for this provider."
onClose={onClose}
submit={
<Button
variant="danger"
onClick={onDisconnect}
disabled={
needsReplacement && hasReplacements && !replacementProviderId
}
>
Disconnect
</Button>
}
>
{needsReplacement ? (
hasReplacements ? (
<Section alignItems="start">
<Text as="p" text03>
<b>{disconnectTarget.label}</b> is currently the active{" "}
{categoryLabel}. Search history will be preserved.
</Text>
<Section alignItems="start" gap={0.25}>
<Text as="p" secondaryBody text03>
Set New Default
</Text>
<InputSelect
value={replacementProviderId ?? undefined}
onValueChange={(v) => onReplacementChange(v)}
>
<InputSelect.Trigger placeholder="Select a replacement provider" />
<InputSelect.Content>
{replacementOptions.map((p) => (
<InputSelect.Item key={p.id} value={String(p.id)}>
{getLabel(p)}
</InputSelect.Item>
))}
<InputSelect.Separator />
<InputSelect.Item value={NO_DEFAULT_VALUE} icon={SvgSlash}>
<span>
<b>No Default</b>
<span className="text-text-03"> ({disableLabel})</span>
</span>
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</Section>
</Section>
) : (
<>
<Text as="p" text03>
<b>{disconnectTarget.label}</b> is currently the active{" "}
{categoryLabel}.
</Text>
<Text as="p" text03>
Connect another provider to continue using {featureLabel}.
</Text>
</>
)
) : (
<>
<Text as="p" text03>
{isSearch ? "Web search" : "Web crawling"} will no longer be routed
through <b>{disconnectTarget.label}</b>.
</Text>
<Text as="p" text03>
Search history will be preserved.
</Text>
</>
)}
</ConfirmationModalLayout>
<div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
{/* TODO(@raunakab): migrate to opal Button once HoverIconButtonProps typing is resolved */}
<Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
{children}
</Button>
</div>
);
}
@@ -226,11 +106,6 @@ export default function Page() {
WebProviderModalReducer,
initialWebProviderModalState
);
const [disconnectTarget, setDisconnectTarget] =
useState<DisconnectTargetState | null>(null);
const [replacementProviderId, setReplacementProviderId] = useState<
string | null
>(null);
const [contentModal, dispatchContentModal] = useReducer(
WebProviderModalReducer,
initialWebProviderModalState
@@ -239,6 +114,8 @@ export default function Page() {
const [contentActivationError, setContentActivationError] = useState<
string | null
>(null);
const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);
const {
data: searchProvidersData,
error: searchProvidersError,
@@ -957,67 +834,6 @@ export default function Page() {
});
};
const handleDisconnectProvider = async () => {
if (!disconnectTarget) return;
const { id, category } = disconnectTarget;
try {
// If a replacement was selected (not "No Default"), activate it first
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
const repId = Number(replacementProviderId);
const activateEndpoint =
category === "search"
? `/api/admin/web-search/search-providers/${repId}/activate`
: `/api/admin/web-search/content-providers/${repId}/activate`;
const activateResp = await fetch(activateEndpoint, {
method: "POST",
headers: { "Content-Type": "application/json" },
});
if (!activateResp.ok) {
const errorBody = await activateResp.json().catch(() => ({}));
throw new Error(
typeof errorBody?.detail === "string"
? errorBody.detail
: "Failed to activate replacement provider."
);
}
}
const response = await fetch(
`/api/admin/web-search/${category}-providers/${id}`,
{ method: "DELETE" }
);
if (!response.ok) {
const errorBody = await response.json().catch((parseErr) => {
console.error("Failed to parse disconnect error response:", parseErr);
return {};
});
throw new Error(
typeof errorBody?.detail === "string"
? errorBody.detail
: "Failed to disconnect provider."
);
}
toast.success(`${disconnectTarget.label} disconnected`);
await mutateSearchProviders();
await mutateContentProviders();
} catch (error) {
console.error("Failed to disconnect web search provider:", error);
const message =
error instanceof Error ? error.message : "Unexpected error occurred.";
if (category === "search") {
setActivationError(message);
} else {
setContentActivationError(message);
}
} finally {
setDisconnectTarget(null);
setReplacementProviderId(null);
}
};
return (
<>
<SettingsLayouts.Root>
@@ -1079,79 +895,149 @@ export default function Page() {
provider
);
const isActive = provider?.is_active ?? false;
const isHighlighted = isActive;
const providerId = provider?.id;
const canOpenModal =
isBuiltInSearchProviderType(providerType);
const status: "disconnected" | "connected" | "selected" =
!isConfigured
? "disconnected"
: isActive
? "selected"
: "connected";
return (
<Select
key={`${key}-${providerType}`}
icon={() =>
logoSrc ? (
<Image
src={logoSrc}
alt={`${label} logo`}
width={16}
height={16}
/>
) : (
<SvgGlobe size={16} />
)
}
title={label}
description={subtitle}
status={status}
onConnect={
canOpenModal
const buttonState = (() => {
if (!provider || !isConfigured) {
return {
label: "Connect",
disabled: false,
icon: "arrow" as const,
onClick: canOpenModal
? () => {
openSearchModal(providerType, provider);
setActivationError(null);
}
: undefined
}
onSelect={
providerId
? () => {
void handleActivateSearchProvider(providerId);
}
: undefined
}
onDeselect={
providerId
: undefined,
};
}
if (isActive) {
return {
label: "Current Default",
disabled: false,
icon: "check" as const,
onClick: providerId
? () => {
void handleDeactivateSearchProvider(providerId);
}
: undefined
}
onEdit={
isConfigured && canOpenModal
? () => {
: undefined,
};
}
return {
label: "Set as Default",
disabled: false,
icon: "arrow-circle" as const,
onClick: providerId
? () => {
void handleActivateSearchProvider(providerId);
}
: undefined,
};
})();
const buttonKey = `search-${key}-${providerType}`;
const isButtonHovered = hoveredButtonKey === buttonKey;
const isCardClickable =
buttonState.icon === "arrow" &&
typeof buttonState.onClick === "function" &&
!buttonState.disabled;
const handleCardClick = () => {
if (isCardClickable) {
buttonState.onClick?.();
}
};
return (
<div
key={`${key}-${providerType}`}
onClick={isCardClickable ? handleCardClick : undefined}
className={cn(
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
isHighlighted
? "border-action-link-05"
: "border-border-01",
isCardClickable &&
"cursor-pointer hover:bg-background-tint-01 transition-colors"
)}
>
<div className="flex flex-1 items-start gap-1 px-2 py-1">
{renderLogo({
logoSrc,
alt: `${label} logo`,
size: 16,
isHighlighted,
})}
<Content
title={label}
description={subtitle}
sizePreset="main-ui"
variant="section"
/>
</div>
<div className="flex items-center justify-end gap-2">
{isConfigured && (
<OpalButton
icon={SvgEdit}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={() => {
if (!canOpenModal) return;
openSearchModal(
providerType as WebSearchProviderType,
provider
);
}}
aria-label={`Edit ${label}`}
/>
)}
{buttonState.icon === "check" ? (
<HoverIconButton
isHovered={isButtonHovered}
onMouseEnter={() => setHoveredButtonKey(buttonKey)}
onMouseLeave={() => setHoveredButtonKey(null)}
action={true}
tertiary
disabled={buttonState.disabled}
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
>
{buttonState.label}
</HoverIconButton>
) : (
<Disabled
disabled={
buttonState.disabled || !buttonState.onClick
}
: undefined
}
onDisconnect={
isConfigured && provider && provider.id > 0
? () =>
setDisconnectTarget({
id: provider.id,
label,
category: "search",
providerType,
})
: undefined
}
/>
>
<OpalButton
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
rightIcon={
buttonState.icon === "arrow"
? SvgArrowExchange
: buttonState.icon === "arrow-circle"
? SvgArrowRightCircle
: undefined
}
>
{buttonState.label}
</OpalButton>
</Disabled>
)}
</div>
</div>
);
}
)}
@@ -1191,81 +1077,161 @@ export default function Page() {
const isCurrentCrawler =
provider.provider_type === currentContentProviderType;
const status: "disconnected" | "connected" | "selected" =
!isConfigured
? "disconnected"
: isCurrentCrawler
? "selected"
: "connected";
const buttonState = (() => {
if (!isConfigured) {
return {
label: "Connect",
icon: "arrow" as const,
disabled: false,
onClick: () => {
openContentModal(provider.provider_type, provider);
setContentActivationError(null);
},
};
}
const canActivate =
providerId > 0 ||
provider.provider_type === "onyx_web_crawler" ||
isConfigured;
if (isCurrentCrawler) {
return {
label: "Current Crawler",
icon: "check" as const,
disabled: false,
onClick: () => {
void handleDeactivateContentProvider(
providerId,
provider.provider_type
);
},
};
}
const contentLogoSrc =
CONTENT_PROVIDER_DETAILS[provider.provider_type]?.logoSrc;
const canActivate =
providerId > 0 ||
provider.provider_type === "onyx_web_crawler" ||
isConfigured;
return {
label: "Set as Default",
icon: "arrow-circle" as const,
disabled: !canActivate,
onClick: canActivate
? () => {
void handleActivateContentProvider(provider);
}
: undefined,
};
})();
const contentButtonKey = `content-${provider.provider_type}-${provider.id}`;
const isContentButtonHovered =
hoveredButtonKey === contentButtonKey;
const isContentCardClickable =
buttonState.icon === "arrow" &&
typeof buttonState.onClick === "function" &&
!buttonState.disabled;
const handleContentCardClick = () => {
if (isContentCardClickable) {
buttonState.onClick?.();
}
};
return (
<Select
<div
key={`${provider.provider_type}-${provider.id}`}
icon={() =>
contentLogoSrc ? (
<Image
src={contentLogoSrc}
alt={`${label} logo`}
width={16}
height={16}
/>
) : provider.provider_type === "onyx_web_crawler" ? (
<SvgOnyxLogo size={16} />
onClick={
isContentCardClickable
? handleContentCardClick
: undefined
}
className={cn(
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
isCurrentCrawler
? "border-action-link-05"
: "border-border-01",
isContentCardClickable &&
"cursor-pointer hover:bg-background-tint-01 transition-colors"
)}
>
<div className="flex flex-1 items-start gap-1 px-2 py-1">
{renderLogo({
logoSrc:
CONTENT_PROVIDER_DETAILS[provider.provider_type]
?.logoSrc,
alt: `${label} logo`,
fallback:
provider.provider_type === "onyx_web_crawler" ? (
<SvgOnyxLogo size={16} />
) : undefined,
size: 16,
isHighlighted: isCurrentCrawler,
})}
<Content
title={label}
description={subtitle}
sizePreset="main-ui"
variant="section"
/>
</div>
<div className="flex items-center justify-end gap-2">
{provider.provider_type !== "onyx_web_crawler" &&
isConfigured && (
<OpalButton
icon={SvgEdit}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={() => {
openContentModal(
provider.provider_type,
provider
);
}}
aria-label={`Edit ${label}`}
/>
)}
{buttonState.icon === "check" ? (
<HoverIconButton
isHovered={isContentButtonHovered}
onMouseEnter={() =>
setHoveredButtonKey(contentButtonKey)
}
onMouseLeave={() => setHoveredButtonKey(null)}
action={true}
tertiary
disabled={buttonState.disabled}
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
>
{buttonState.label}
</HoverIconButton>
) : (
<SvgGlobe size={16} />
)
}
title={label}
description={subtitle}
status={status}
selectedLabel="Current Crawler"
onConnect={() => {
openContentModal(provider.provider_type, provider);
setContentActivationError(null);
}}
onSelect={
canActivate
? () => {
void handleActivateContentProvider(provider);
<Disabled
disabled={
buttonState.disabled || !buttonState.onClick
}
: undefined
}
onDeselect={() => {
void handleDeactivateContentProvider(
providerId,
provider.provider_type
);
}}
onEdit={
provider.provider_type !== "onyx_web_crawler" &&
isConfigured
? () => {
openContentModal(provider.provider_type, provider);
}
: undefined
}
onDisconnect={
provider.provider_type !== "onyx_web_crawler" &&
isConfigured &&
provider.id > 0
? () =>
setDisconnectTarget({
id: provider.id,
label,
category: "content",
providerType: provider.provider_type,
})
: undefined
}
/>
>
<OpalButton
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
rightIcon={
buttonState.icon === "arrow"
? SvgArrowExchange
: buttonState.icon === "arrow-circle"
? SvgArrowRightCircle
: undefined
}
>
{buttonState.label}
</OpalButton>
</Disabled>
)}
</div>
</div>
);
})}
</div>
@@ -1273,21 +1239,6 @@ export default function Page() {
</SettingsLayouts.Body>
</SettingsLayouts.Root>
{disconnectTarget && (
<WebSearchDisconnectModal
disconnectTarget={disconnectTarget}
searchProviders={searchProviders}
contentProviders={combinedContentProviders}
replacementProviderId={replacementProviderId}
onReplacementChange={setReplacementProviderId}
onClose={() => {
setDisconnectTarget(null);
setReplacementProviderId(null);
}}
onDisconnect={() => void handleDisconnectProvider()}
/>
)}
<WebProviderSetupModal
isOpen={selectedProviderType !== null}
onClose={() => {

View File

@@ -4,6 +4,7 @@ import { useState } from "react";
import { ValidSources } from "@/lib/types";
import { Section } from "@/layouts/general-layouts";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { Button } from "@opal/components";
import Separator from "@/refresh-components/Separator";

View File

@@ -10,7 +10,7 @@ import {
import { IndexAttemptError } from "./types";
import { localizeAndPrettify } from "@/lib/time";
import Button from "@/refresh-components/buttons/Button";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import { PageSelector } from "@/components/PageSelector";
import { useCallback, useEffect, useRef, useState, useMemo } from "react";
import { SvgAlertTriangle } from "@opal/icons";
@@ -113,11 +113,11 @@ export default function IndexAttemptErrorsModal({
<Modal.Body height="full">
{!isResolvingErrors && (
<div className="flex flex-col gap-2 flex-shrink-0">
<Text as="p">
<Text as="p" color="text-05">
Below are the errors encountered during indexing. Each row
represents a failed document or entity.
</Text>
<Text as="p">
<Text as="p" color="text-05">
Click the button below to kick off a full re-index to try and
resolve these errors. This full re-index may take much longer
than a normal update.

View File

@@ -21,6 +21,7 @@ import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { ThreeDotsLoader } from "@/components/Loading";
import Modal from "@/refresh-components/Modal";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import {
SvgCheck,

View File

@@ -6,6 +6,7 @@ import { useState } from "react";
import { toast } from "@/hooks/useToast";
import { triggerIndexing } from "@/app/admin/connector/[ccPairId]/lib";
import Modal from "@/refresh-components/Modal";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Separator from "@/refresh-components/Separator";
import { SvgRefreshCw } from "@opal/icons";

View File

@@ -7,6 +7,7 @@ import { SourceIcon } from "@/components/SourceIcon";
import { CCPairStatus, PermissionSyncStatus } from "@/components/Status";
import { toast } from "@/hooks/useToast";
import CredentialSection from "@/components/credentials/CredentialSection";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import {
updateConnectorCredentialPairName,

View File

@@ -59,6 +59,7 @@ import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import { deleteConnector } from "@/lib/connector";
import ConnectorDocsLink from "@/components/admin/connectors/ConnectorDocsLink";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { SvgKey, SvgAlertCircle } from "@opal/icons";
import SimpleTooltip from "@/refresh-components/SimpleTooltip";

View File

@@ -19,7 +19,6 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
import { Credential } from "@/lib/connectors/credentials";
import { useFederatedConnectors } from "@/lib/hooks";
import Text from "@/refresh-components/texts/Text";
import { useToastFromQuery } from "@/hooks/useToast";
export default function ConnectorWrapper({

View File

@@ -15,7 +15,7 @@ import * as InputLayouts from "@/layouts/input-layouts";
import { Content } from "@opal/layouts";
import CheckboxField from "@/refresh-components/form/LabeledCheckboxField";
import InputTextAreaField from "@/refresh-components/form/InputTextAreaField";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
// Define a general type for form values
type FormValues = Record<string, any>;
@@ -57,7 +57,7 @@ const TabsField: FC<TabsFieldProps> = ({
{/* Ensure there's at least one tab before rendering */}
{tabField.tabs.length === 0 ? (
<Text text03 secondaryBody>
<Text color="text-03" font="secondary-body">
No tabs to display.
</Text>
) : (
@@ -253,7 +253,7 @@ export const RenderField: FC<RenderFieldProps> = ({
)
) : field.type === "string_tab" ? (
<GeneralLayouts.Section>
<Text text03 secondaryBody>
<Text color="text-03" font="secondary-body">
{description}
</Text>
</GeneralLayouts.Section>

View File

@@ -2,6 +2,7 @@
import { useState } from "react";
import { Section } from "@/layouts/general-layouts";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Card from "@/refresh-components/cards/Card";
import { Button } from "@opal/components";

View File

@@ -11,7 +11,7 @@ import {
import Switch from "@/refresh-components/inputs/Switch";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import EmptyMessage from "@/refresh-components/EmptyMessage";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import { Section } from "@/layouts/general-layouts";
import {
DiscordChannelConfig,
@@ -95,9 +95,7 @@ export function DiscordChannelsTable({
width="fit"
>
<ChannelIcon width={16} height={16} />
<Text text04 mainUiBody>
{channel.channel_name}
</Text>
<Text>{channel.channel_name}</Text>
</Section>
</TableCell>
<TableCell>

View File

@@ -8,11 +8,10 @@ import { toast } from "@/hooks/useToast";
import { Section } from "@/layouts/general-layouts";
import { ContentAction } from "@opal/layouts";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Text from "@/refresh-components/texts/Text";
import { Button, Text } from "@opal/components";
import Card from "@/refresh-components/cards/Card";
import { Callout } from "@/components/ui/callout";
import Message from "@/refresh-components/messages/Message";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import { SvgServer } from "@opal/icons";
import InputSelect from "@/refresh-components/inputs/InputSelect";
@@ -121,7 +120,7 @@ function GuildDetailContent({
/>
{!isRegistered ? (
<Text text03 secondaryBody>
<Text color="text-03" font="secondary-body">
Channel configuration will be available after the server is
registered.
</Text>

View File

@@ -6,7 +6,7 @@ import { ErrorCallout } from "@/components/ErrorCallout";
import { toast } from "@/hooks/useToast";
import { Section } from "@/layouts/general-layouts";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import Modal from "@/refresh-components/Modal";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
@@ -76,7 +76,7 @@ function DiscordBotContent() {
description="This key will only be shown once!"
/>
<Modal.Body>
<Text text04 mainUiBody>
<Text>
Copy the command and send it from any text channel in your server!
</Text>
<Card variant="secondary">
@@ -85,8 +85,8 @@ function DiscordBotContent() {
justifyContent="between"
alignItems="center"
>
<Text text03 secondaryMono>
!register {registrationKey}
<Text color="text-03" font="secondary-mono">
{`!register ${registrationKey}`}
</Text>
<CopyIconButton
getCopyText={() => `!register ${registrationKey}`}
@@ -103,7 +103,7 @@ function DiscordBotContent() {
justifyContent="between"
alignItems="center"
>
<Text mainContentEmphasis text05>
<Text font="main-content-emphasis" color="text-05">
Server Configurations
</Text>
<CreateButton

View File

@@ -9,7 +9,7 @@ const route = ADMIN_ROUTES.INDEX_MIGRATION;
import Card from "@/refresh-components/cards/Card";
import { Content, ContentAction } from "@opal/layouts";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import Button from "@/refresh-components/buttons/Button";
import { errorHandlingFetcher } from "@/lib/fetcher";
@@ -38,10 +38,10 @@ function MigrationStatusSection() {
if (isLoading) {
return (
<Card>
<Text headingH3>Migration Status</Text>
<Text mainUiBody text03>
Loading...
<Text font="heading-h3" color="text-05">
Migration Status
</Text>
<Text color="text-03">Loading...</Text>
</Card>
);
}
@@ -49,10 +49,10 @@ function MigrationStatusSection() {
if (error) {
return (
<Card>
<Text headingH3>Migration Status</Text>
<Text mainUiBody text03>
Failed to load migration status.
<Text font="heading-h3" color="text-05">
Migration Status
</Text>
<Text color="text-03">Failed to load migration status.</Text>
</Card>
);
}
@@ -73,14 +73,16 @@ function MigrationStatusSection() {
return (
<Card>
<Text headingH3>Migration Status</Text>
<Text font="heading-h3" color="text-05">
Migration Status
</Text>
<ContentAction
title="Started"
sizePreset="main-ui"
variant="section"
rightChildren={
<Text mainUiBody>
<Text color="text-05">
{hasStarted ? formatTimestamp(data.created_at!) : "Not started"}
</Text>
}
@@ -91,7 +93,7 @@ function MigrationStatusSection() {
sizePreset="main-ui"
variant="section"
rightChildren={
<Text mainUiBody>
<Text color="text-05">
{progressPercentage !== null
? `${totalChunksMigrated} (approx. progress ${Math.round(
progressPercentage
@@ -106,7 +108,7 @@ function MigrationStatusSection() {
sizePreset="main-ui"
variant="section"
rightChildren={
<Text mainUiBody>
<Text color="text-05">
{hasCompleted
? formatTimestamp(data.migration_completed_at!)
: hasStarted
@@ -159,10 +161,10 @@ function RetrievalSourceSection() {
if (isLoading) {
return (
<Card>
<Text headingH3>Retrieval Source</Text>
<Text mainUiBody text03>
Loading...
<Text font="heading-h3" color="text-05">
Retrieval Source
</Text>
<Text color="text-03">Loading...</Text>
</Card>
);
}
@@ -170,10 +172,10 @@ function RetrievalSourceSection() {
if (error) {
return (
<Card>
<Text headingH3>Retrieval Source</Text>
<Text mainUiBody text03>
Failed to load retrieval settings.
<Text font="heading-h3" color="text-05">
Retrieval Source
</Text>
<Text color="text-03">Failed to load retrieval settings.</Text>
</Card>
);
}

View File

@@ -3,6 +3,7 @@
import React, { useRef, useState } from "react";
import Modal from "@/refresh-components/Modal";
import { Callout } from "@/components/ui/callout";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Separator from "@/refresh-components/Separator";
import Button from "@/refresh-components/buttons/Button";

View File

@@ -1,4 +1,5 @@
import Modal from "@/refresh-components/Modal";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { Button } from "@opal/components";
import { Callout } from "@/components/ui/callout";

View File

@@ -1,5 +1,6 @@
import Modal from "@/refresh-components/Modal";
import { Button } from "@opal/components";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { SvgAlertTriangle } from "@opal/icons";
export interface InstantSwitchConfirmModalProps {

View File

@@ -1,4 +1,5 @@
import Modal from "@/refresh-components/Modal";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { Callout } from "@/components/ui/callout";
import { Button } from "@opal/components";

View File

@@ -1,4 +1,5 @@
import React, { useRef, useState } from "react";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { Callout } from "@/components/ui/callout";
import { Button } from "@opal/components";

View File

@@ -1,5 +1,6 @@
import Modal from "@/refresh-components/Modal";
import { Button } from "@opal/components";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { CloudEmbeddingModel } from "@/components/embedding/interfaces";
import { SvgServer } from "@opal/icons";

View File

@@ -4,6 +4,7 @@ import { toast } from "@/hooks/useToast";
import EmbeddingModelSelection from "../EmbeddingModelSelectionForm";
import { useCallback, useEffect, useMemo, useState, useRef } from "react";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";

View File

@@ -8,6 +8,7 @@ import { ValidSources } from "@/lib/types";
import { FaCircleQuestion } from "react-icons/fa6";
import { CheckmarkIcon } from "@/components/icons/icons";
import { Button } from "@opal/components";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import { cn } from "@/lib/utils";

View File

@@ -28,6 +28,7 @@ import Title from "@/components/ui/title";
import { redirect } from "next/navigation";
import { useIsKGExposed } from "@/app/admin/kg/utils";
import KGEntityTypes from "@/app/admin/kg/KGEntityTypes";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { cn } from "@/lib/utils";
import { SvgSettings } from "@opal/icons";

Some files were not shown because too many files have changed in this diff Show More