Compare commits

..

3 Commits

Author SHA1 Message Date
Raunak Bhagat
c079cd867a refactor: migrate 57 more refresh-components/Text files to Opal Text (batch 2) 2026-03-26 11:21:42 -07:00
Raunak Bhagat
f1c1473903 refactor: migrate 29 refresh-components/Text files to Opal Text (batch 1)
Convert boolean-flag Text API to string-enum props. Files with JSX
children, conditional boolean props, or className left with TODOs.
2026-03-26 11:06:29 -07:00
Raunak Bhagat
4a67cc0a09 chore: add migration TODOs to all 235 refresh-components/Text imports 2026-03-26 10:34:29 -07:00
307 changed files with 1301 additions and 5240 deletions

View File

@@ -1,35 +0,0 @@
"""remove voice_provider deleted column
Revision ID: 1d78c0ca7853
Revises: a3f8b2c1d4e5
Create Date: 2026-03-26 11:30:53.883127
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "1d78c0ca7853"
down_revision = "a3f8b2c1d4e5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Hard-delete any soft-deleted rows before dropping the column
op.execute("DELETE FROM voice_provider WHERE deleted = true")
op.drop_column("voice_provider", "deleted")
def downgrade() -> None:
op.add_column(
"voice_provider",
sa.Column(
"deleted",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)

View File

@@ -28,7 +28,6 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ElementExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
@@ -188,6 +187,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
@@ -227,7 +227,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_permission_sync_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)

View File

@@ -29,7 +29,6 @@ from ee.onyx.external_permissions.sync_params import (
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.error_logging import emit_background_error
@@ -163,6 +162,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
@@ -221,7 +221,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_external_group_sync_fences(
tenant_id, self.app, r, r_replica, r_celery, lock_beat
)

View File

@@ -1,6 +1,5 @@
# These are helper objects for tracking the keys we need to write in redis
import json
import threading
from typing import Any
from typing import cast
@@ -8,59 +7,7 @@ from celery import Celery
from redis import Redis
from onyx.background.celery.configs.base import CELERY_SEPARATOR
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
_broker_client: Redis | None = None
_broker_url: str | None = None
_broker_client_lock = threading.Lock()
def celery_get_broker_client(app: Celery) -> Redis:
"""Return a shared Redis client connected to the Celery broker DB.
Uses a module-level singleton so all tasks on a worker share one
connection instead of creating a new one per call. The client
connects directly to the broker Redis DB (parsed from the broker URL).
Thread-safe via lock — safe for use in Celery thread-pool workers.
Usage:
r_celery = celery_get_broker_client(self.app)
length = celery_get_queue_length(queue, r_celery)
"""
global _broker_client, _broker_url
with _broker_client_lock:
url = app.conf.broker_url
if _broker_client is not None and _broker_url == url:
try:
_broker_client.ping()
return _broker_client
except Exception:
try:
_broker_client.close()
except Exception:
pass
_broker_client = None
elif _broker_client is not None:
try:
_broker_client.close()
except Exception:
pass
_broker_client = None
_broker_url = url
_broker_client = Redis.from_url(
url,
decode_responses=False,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
retry_on_timeout=True,
)
return _broker_client
def celery_get_unacked_length(r: Redis) -> int:

View File

@@ -14,7 +14,6 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
@@ -133,6 +132,7 @@ def revoke_tasks_blocking_deletion(
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -149,7 +149,6 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
# clear fences that don't have associated celery tasks in progress
try:
r_celery = celery_get_broker_client(self.app)
validate_connector_deletion_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)

View File

@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
@@ -450,7 +449,7 @@ def check_indexing_completion(
):
# Check if the task exists in the celery queue
# This handles the case where Redis dies after task creation but before task execution
redis_celery = celery_get_broker_client(task.app)
redis_celery = task.app.broker_connection().channel().client # type: ignore
task_exists = celery_find_task(
attempt.celery_task_id,
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,

View File

@@ -1,5 +1,6 @@
import json
import time
from collections.abc import Callable
from datetime import timedelta
from itertools import islice
from typing import Any
@@ -18,7 +19,6 @@ from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.memory_monitoring import emit_process_memory
@@ -698,27 +698,31 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
return None
try:
# Get Redis client for Celery broker
redis_celery = self.app.broker_connection().channel().client # type: ignore
redis_std = get_redis_client()
# Collect queue metrics with broker connection
r_celery = celery_get_broker_client(self.app)
queue_metrics = _collect_queue_metrics(r_celery)
# Define metric collection functions and their dependencies
metric_functions: list[Callable[[], list[Metric]]] = [
lambda: _collect_queue_metrics(redis_celery),
lambda: _collect_connector_metrics(db_session, redis_std),
lambda: _collect_sync_metrics(db_session, redis_std),
]
# Collect remaining metrics (no broker connection needed)
# Collect and log each metric
with get_session_with_current_tenant() as db_session:
all_metrics: list[Metric] = queue_metrics
all_metrics.extend(_collect_connector_metrics(db_session, redis_std))
all_metrics.extend(_collect_sync_metrics(db_session, redis_std))
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
# double check to make sure we aren't double-emitting metrics
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
for metric in all_metrics:
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
if metric.key is not None:
_mark_metric_as_emitted(redis_std, metric.key)
if metric.key is not None:
_mark_metric_as_emitted(redis_std, metric.key)
task_logger.info("Successfully collected background metrics")
except SoftTimeLimitExceeded:
@@ -886,7 +890,7 @@ def monitor_celery_queues_helper(
) -> None:
"""A task to monitor all celery queue lengths."""
r_celery = celery_get_broker_client(task.app)
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
@@ -1076,7 +1080,7 @@ def cloud_monitor_celery_pidbox(
num_deleted = 0
MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds
r_celery = celery_get_broker_client(self.app)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
for key in r_celery.scan_iter("*.reply.celery.pidbox"):
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")

View File

@@ -17,7 +17,6 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
@@ -204,6 +203,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -261,7 +261,6 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
r_celery = celery_get_broker_client(self.app)
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
except Exception:
task_logger.exception("Exception while validating pruning fences")

View File

@@ -16,7 +16,6 @@ from sqlalchemy.orm import Session
from onyx.access.access import build_access_for_user_files
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
@@ -106,7 +105,7 @@ def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery = celery_get_broker_client(celery_app)
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
return celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
)
@@ -239,7 +238,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = celery_get_broker_client(self.app)
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
@@ -592,7 +591,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
# --- Protection 1: queue depth backpressure ---
# NOTE: must use the broker's Redis client (not redis_client) because
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
r_celery = celery_get_broker_client(self.app)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
task_logger.warning(

View File

@@ -12,11 +12,6 @@ SLACK_USER_TOKEN_PREFIX = "xoxp-"
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
# The mask_string() function in encryption.py uses "•" (U+2022 BULLET) to mask secrets.
MASK_CREDENTIAL_CHAR = "\u2022"
# Pattern produced by mask_string for strings >= 14 chars: "abcd...wxyz" (exactly 11 chars)
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.{3}.{4}$")
SOURCE_TYPE = "source_type"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed

View File

@@ -10,7 +10,6 @@ from datetime import timedelta
from datetime import timezone
from typing import Any
import requests
from jira import JIRA
from jira.exceptions import JIRAError
from jira.resources import Issue
@@ -240,53 +239,29 @@ def enhanced_search_ids(
)
def _bulk_fetch_request(
jira_client: JIRA, issue_ids: list[str], fields: str | None
) -> list[dict[str, Any]]:
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO: move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
# Prepare the payload according to Jira API v3 specification
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
# Only restrict fields if specified, might want to explicitly do this in the future
# to avoid reading unnecessary data
payload["fields"] = fields.split(",") if fields else ["*all"]
resp = jira_client._session.post(bulk_fetch_path, json=payload)
return resp.json()["issues"]
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
try:
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
except requests.exceptions.JSONDecodeError:
if len(issue_ids) <= 1:
logger.exception(
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
f"be decoded as JSON (response too large or truncated)."
)
raise
mid = len(issue_ids) // 2
logger.warning(
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
)
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
return left + right
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise
raise e
return [
Issue(jira_client._options, jira_client._session, raw=issue)
for issue in raw_issues
for issue in response["issues"]
]

View File

@@ -4,6 +4,7 @@ from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.api_key import ApiKeyDescriptor
@@ -54,6 +55,7 @@ async def fetch_user_for_api_key(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
.options(selectinload(User.memories))
)

View File

@@ -13,6 +13,7 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@@ -97,6 +98,11 @@ async def get_user_count(only_admin_users: bool = False) -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def _get_user(self, statement: Select) -> UP | None:
statement = statement.options(selectinload(User.memories))
results = await self.session.execute(statement)
return results.unique().scalar_one_or_none()
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -8,6 +8,7 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
@@ -131,47 +132,32 @@ def get_chat_sessions_by_user(
if before is not None:
stmt = stmt.where(ChatSession.time_updated < before)
if limit:
stmt = stmt.limit(limit)
if project_id is not None:
stmt = stmt.where(ChatSession.project_id == project_id)
elif only_non_project_chats:
stmt = stmt.where(ChatSession.project_id.is_(None))
# When filtering out failed chats, we apply the limit in Python after
# filtering rather than in SQL, since the post-filter may remove rows.
if limit and include_failed_chats:
stmt = stmt.limit(limit)
if not include_failed_chats:
non_system_message_exists_subq = (
exists()
.where(ChatMessage.chat_session_id == ChatSession.id)
.where(ChatMessage.message_type != MessageType.SYSTEM)
.correlate(ChatSession)
)
# Leeway for newly created chats that don't have messages yet
time = datetime.now(timezone.utc) - timedelta(minutes=5)
recently_created = ChatSession.time_created >= time
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
result = db_session.execute(stmt)
chat_sessions = list(result.scalars().all())
chat_sessions = result.scalars().all()
if not include_failed_chats and chat_sessions:
# Filter out "failed" sessions (those with only SYSTEM messages)
# using a separate efficient query instead of a correlated EXISTS
# subquery, which causes full sequential scans of chat_message.
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
if session_ids:
valid_session_ids_stmt = (
select(ChatMessage.chat_session_id)
.where(ChatMessage.chat_session_id.in_(session_ids))
.where(ChatMessage.message_type != MessageType.SYSTEM)
.distinct()
)
valid_session_ids = set(
db_session.execute(valid_session_ids_stmt).scalars().all()
)
chat_sessions = [
cs
for cs in chat_sessions
if cs.time_created >= leeway or cs.id in valid_session_ids
]
if limit:
chat_sessions = chat_sessions[:limit]
return chat_sessions
return list(chat_sessions)
def delete_orphaned_search_docs(db_session: Session) -> None:

View File

@@ -8,8 +8,6 @@ from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.constants import FederatedConnectorSource
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
from onyx.configs.constants import MASK_CREDENTIAL_LONG_RE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector
@@ -47,23 +45,6 @@ def fetch_all_federated_connectors_parallel() -> list[FederatedConnector]:
return fetch_all_federated_connectors(db_session)
def _reject_masked_credentials(credentials: dict[str, Any]) -> None:
"""Raise if any credential string value contains mask placeholder characters.
mask_string() has two output formats:
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
Both must be rejected.
"""
for key, val in credentials.items():
if isinstance(val, str) and (
MASK_CREDENTIAL_CHAR in val or MASK_CREDENTIAL_LONG_RE.match(val)
):
raise ValueError(
f"Credential field '{key}' contains masked placeholder characters. Please provide the actual credential value."
)
def validate_federated_connector_credentials(
source: FederatedConnectorSource,
credentials: dict[str, Any],
@@ -85,8 +66,6 @@ def create_federated_connector(
config: dict[str, Any] | None = None,
) -> FederatedConnector:
"""Create a new federated connector with credential and config validation."""
_reject_masked_credentials(credentials)
# Validate credentials before creating
if not validate_federated_connector_credentials(source, credentials):
raise ValueError(
@@ -298,8 +277,6 @@ def update_federated_connector(
)
if credentials is not None:
_reject_masked_credentials(credentials)
# Validate credentials before updating
if not validate_federated_connector_credentials(
federated_connector.source, credentials

View File

@@ -3135,6 +3135,8 @@ class VoiceProvider(Base):
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

View File

@@ -8,6 +8,7 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.pat import build_displayable_pat
@@ -46,6 +47,7 @@ async def fetch_user_for_pat(
(PersonalAccessToken.expires_at.is_(None))
| (PersonalAccessToken.expires_at > now)
)
.options(selectinload(User.memories))
)
if not user:
return None

View File

@@ -229,9 +229,7 @@ def get_memories_for_user(
user_id: UUID,
db_session: Session,
) -> Sequence[Memory]:
return db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.desc())
).all()
return db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all()
def update_user_pinned_assistants(

View File

@@ -17,30 +17,39 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
"""Fetch all voice providers."""
return list(
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
db_session.scalars(
select(VoiceProvider)
.where(VoiceProvider.deleted.is_(False))
.order_by(VoiceProvider.name)
).all()
)
def fetch_voice_provider_by_id(
db_session: Session, provider_id: int
db_session: Session, provider_id: int, include_deleted: bool = False
) -> VoiceProvider | None:
"""Fetch a voice provider by ID."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.id == provider_id)
)
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
if not include_deleted:
stmt = stmt.where(VoiceProvider.deleted.is_(False))
return db_session.scalar(stmt)
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default STT provider."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
select(VoiceProvider)
.where(VoiceProvider.is_default_stt.is_(True))
.where(VoiceProvider.deleted.is_(False))
)
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default TTS provider."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
select(VoiceProvider)
.where(VoiceProvider.is_default_tts.is_(True))
.where(VoiceProvider.deleted.is_(False))
)
@@ -49,7 +58,9 @@ def fetch_voice_provider_by_type(
) -> VoiceProvider | None:
"""Fetch a voice provider by type."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
select(VoiceProvider)
.where(VoiceProvider.provider_type == provider_type)
.where(VoiceProvider.deleted.is_(False))
)
@@ -108,10 +119,10 @@ def upsert_voice_provider(
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
"""Delete a voice provider by ID."""
"""Soft-delete a voice provider by ID."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider:
db_session.delete(provider)
provider.deleted = True
db_session.flush()

View File

@@ -49,21 +49,9 @@ KNOWN_OPENPYXL_BUGS = [
def get_markitdown_converter() -> "MarkItDown":
global _MARKITDOWN_CONVERTER
from markitdown import MarkItDown
if _MARKITDOWN_CONVERTER is None:
from markitdown import MarkItDown
# Patch this function to effectively no-op because we were seeing this
# module take an inordinate amount of time to convert charts to markdown,
# making some powerpoint files with many or complicated charts nearly
# unindexable.
from markitdown.converters._pptx_converter import PptxConverter
setattr(
PptxConverter,
"_convert_chart_to_markdown",
lambda self, chart: "\n\n[chart omitted]\n\n", # noqa: ARG005
)
_MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
return _MARKITDOWN_CONVERTER
@@ -214,26 +202,18 @@ def read_pdf_file(
try:
pdf_reader = PdfReader(file)
if pdf_reader.is_encrypted:
# Try the explicit password first, then fall back to an empty
# string. Owner-password-only PDFs (permission restrictions but
# no open password) decrypt successfully with "".
# See https://github.com/onyx-dot-app/onyx/issues/9754
passwords = [p for p in [pdf_pass, ""] if p is not None]
if pdf_reader.is_encrypted and pdf_pass is not None:
decrypt_success = False
for pw in passwords:
try:
if pdf_reader.decrypt(pw) != 0:
decrypt_success = True
break
except Exception:
pass
try:
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
except Exception:
logger.error("Unable to decrypt pdf")
if not decrypt_success:
logger.error(
"Encrypted PDF could not be decrypted, returning empty text."
)
return "", metadata, []
elif pdf_reader.is_encrypted:
logger.warning("No Password for an encrypted PDF, returning empty text.")
return "", metadata, []
# Basic PDF metadata
if pdf_reader.metadata is not None:

View File

@@ -33,20 +33,8 @@ def is_pdf_protected(file: IO[Any]) -> bool:
with preserve_position(file):
reader = PdfReader(file)
if not reader.is_encrypted:
return False
# PDFs with only an owner password (permission restrictions like
# print/copy disabled) use an empty user password — any viewer can open
# them without prompting. decrypt("") returns 0 only when a real user
# password is required. See https://github.com/onyx-dot-app/onyx/issues/9754
try:
return reader.decrypt("") == 0
except Exception:
logger.exception(
"Failed to evaluate PDF encryption; treating as password protected"
)
return True
return bool(reader.is_encrypted)
def is_docx_protected(file: IO[Any]) -> bool:

View File

@@ -185,21 +185,6 @@ def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
return False
def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool:
"""Check if the prompt contains any assistant messages with tool_calls.
When Anthropic's extended thinking is enabled, the API requires every
assistant message to start with a thinking block before any tool_use
blocks. Since we don't preserve thinking_blocks (they carry
cryptographic signatures that can't be reconstructed), we must skip
the thinking param whenever history contains prior tool-calling turns.
"""
from onyx.llm.models import AssistantMessage
msgs = prompt if isinstance(prompt, list) else [prompt]
return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs)
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
normalized_model_name = model_name.lower()
return any(
@@ -481,20 +466,7 @@ class LitellmLLM(LLM):
reasoning_effort
)
# Anthropic requires every assistant message with tool_use
# blocks to start with a thinking block that carries a
# cryptographic signature. We don't preserve those blocks
# across turns, so skip thinking when the history already
# contains tool-calling assistant messages. LiteLLM's
# modify_params workaround doesn't cover all providers
# (notably Bedrock).
can_enable_thinking = (
budget_tokens is not None
and not _prompt_contains_tool_call_history(prompt)
)
if can_enable_thinking:
assert budget_tokens is not None # mypy
if budget_tokens is not None:
if max_tokens is not None:
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
# and the minimum budget tokens is 1024

View File

@@ -6,7 +6,6 @@ from onyx.configs.app_configs import MCP_SERVER_ENABLED
from onyx.configs.app_configs import MCP_SERVER_HOST
from onyx.configs.app_configs import MCP_SERVER_PORT
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
logger = setup_logger()
@@ -17,7 +16,6 @@ def main() -> None:
logger.info("MCP server is disabled (MCP_SERVER_ENABLED=false)")
return
set_is_ee_based_on_env_variable()
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
from onyx.mcp_server.api import mcp_app

View File

@@ -44,12 +44,11 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
"""Raise OnyxError if endpoint_url could be used for SSRF.
Delegates to validate_outbound_http_url with https_only=True.
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
"""
try:
validate_outbound_http_url(endpoint_url, https_only=True)
except (SSRFException, ValueError) as e:
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
# ---------------------------------------------------------------------------
@@ -142,11 +141,19 @@ def _validate_endpoint(
)
return HookValidateResponse(status=HookValidateStatus.passed)
except httpx.TimeoutException as exc:
# Any timeout (connect, read, or write) means the configured timeout_seconds
# is too low for this endpoint. Report as timeout so the UI directs the user
# to increase the timeout setting.
# ConnectTimeout: TCP handshake never completed → cannot_connect.
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
if isinstance(exc, httpx.ConnectTimeout):
logger.warning(
"Hook endpoint validation: connect timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
logger.warning(
"Hook endpoint validation: timeout for %s",
"Hook endpoint validation: read/write timeout for %s",
endpoint_url,
exc_info=exc,
)

View File

@@ -1524,7 +1524,6 @@ def get_bifrost_available_models(
display_name=model_name,
max_input_tokens=model.get("context_length"),
supports_image_input=infer_vision_support(model_id),
supports_reasoning=is_reasoning_model(model_id, model_name),
)
)
except Exception as e:

View File

@@ -463,4 +463,3 @@ class BifrostFinalModelResponse(BaseModel):
display_name: str # Human-readable name from Bifrost API
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool

View File

@@ -147,7 +147,6 @@ class UserInfo(BaseModel):
is_anonymous_user: bool | None = None,
tenant_info: TenantInfo | None = None,
assistant_specific_configs: UserSpecificAssistantPreferences | None = None,
memories: list[MemoryItem] | None = None,
) -> "UserInfo":
return cls(
id=str(user.id),
@@ -192,7 +191,10 @@ class UserInfo(BaseModel):
role=user.personal_role or "",
use_memories=user.use_memories,
enable_memory_tool=user.enable_memory_tool,
memories=memories or [],
memories=[
MemoryItem(id=memory.id, content=memory.memory_text)
for memory in (user.memories or [])
],
user_preferences=user.user_preferences or "",
),
)

View File

@@ -57,7 +57,6 @@ from onyx.db.user_preferences import activate_user
from onyx.db.user_preferences import deactivate_user
from onyx.db.user_preferences import get_all_user_assistant_specific_configs
from onyx.db.user_preferences import get_latest_access_token_for_user
from onyx.db.user_preferences import get_memories_for_user
from onyx.db.user_preferences import update_assistant_preferences
from onyx.db.user_preferences import update_user_assistant_visibility
from onyx.db.user_preferences import update_user_auto_scroll
@@ -824,11 +823,6 @@ def verify_user_logged_in(
[],
),
)
memories = [
MemoryItem(id=memory.id, content=memory.memory_text)
for memory in get_memories_for_user(user.id, db_session)
]
user_info = UserInfo.from_model(
user,
current_token_created_at=token_created_at,
@@ -839,7 +833,6 @@ def verify_user_logged_in(
new_tenant=new_tenant,
invitation=tenant_invitation,
),
memories=memories,
)
return user_info
@@ -937,8 +930,7 @@ def update_user_personalization_api(
else user.enable_memory_tool
)
existing_memories = [
MemoryItem(id=memory.id, content=memory.memory_text)
for memory in get_memories_for_user(user.id, db_session)
MemoryItem(id=memory.id, content=memory.memory_text) for memory in user.memories
]
new_memories = (
request.memories if request.memories is not None else existing_memories

View File

@@ -12,6 +12,7 @@ stale, which is fine for monitoring dashboards.
import json
import threading
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -103,23 +104,25 @@ class _CachedCollector(Collector):
class QueueDepthCollector(_CachedCollector):
"""Reads Celery queue lengths from the broker Redis on each scrape."""
"""Reads Celery queue lengths from the broker Redis on each scrape.
Uses a Redis client factory (callable) rather than a stored client
reference so the connection is always fresh from Celery's pool.
"""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._celery_app: Any | None = None
self._get_redis: Callable[[], Redis] | None = None
def set_celery_app(self, app: Any) -> None:
"""Set the Celery app for broker Redis access."""
self._celery_app = app
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
"""Set a callable that returns a broker Redis client on demand."""
self._get_redis = factory
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._celery_app is None:
if self._get_redis is None:
return []
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
redis_client = self._get_redis()
depth = GaugeMetricFamily(
"onyx_queue_depth",
@@ -401,19 +404,17 @@ class RedisHealthCollector(_CachedCollector):
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._celery_app: Any | None = None
self._get_redis: Callable[[], Redis] | None = None
def set_celery_app(self, app: Any) -> None:
"""Set the Celery app for broker Redis access."""
self._celery_app = app
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
"""Set a callable that returns a broker Redis client on demand."""
self._get_redis = factory
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._celery_app is None:
if self._get_redis is None:
return []
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
redis_client = self._get_redis()
memory_used = GaugeMetricFamily(
"onyx_redis_memory_used_bytes",

View File

@@ -3,8 +3,12 @@
Called once by the monitoring celery worker after Redis and DB are ready.
"""
from collections.abc import Callable
from typing import Any
from celery import Celery
from prometheus_client.registry import REGISTRY
from redis import Redis
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
@@ -17,7 +21,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# Module-level singletons — these are lightweight objects (no connections or DB
# state) until configure() / set_celery_app() is called. Keeping them at
# state) until configure() / set_redis_factory() is called. Keeping them at
# module level ensures they survive the lifetime of the worker process and are
# only registered with the Prometheus registry once.
_queue_collector = QueueDepthCollector()
@@ -28,15 +32,72 @@ _worker_health_collector = WorkerHealthCollector()
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
"""Create a factory that returns a cached broker Redis client.
Reuses a single connection across scrapes to avoid leaking connections.
Reconnects automatically if the cached connection becomes stale.
"""
_cached_client: list[Redis | None] = [None]
# Keep a reference to the Kombu Connection so we can close it on
# reconnect (the raw Redis client outlives the Kombu wrapper).
_cached_kombu_conn: list[Any] = [None]
def _close_client(client: Redis) -> None:
"""Best-effort close of a Redis client."""
try:
client.close()
except Exception:
logger.debug("Failed to close stale Redis client", exc_info=True)
def _close_kombu_conn() -> None:
"""Best-effort close of the cached Kombu Connection."""
conn = _cached_kombu_conn[0]
if conn is not None:
try:
conn.close()
except Exception:
logger.debug("Failed to close Kombu connection", exc_info=True)
_cached_kombu_conn[0] = None
def _get_broker_redis() -> Redis:
client = _cached_client[0]
if client is not None:
try:
client.ping()
return client
except Exception:
logger.debug("Cached Redis client stale, reconnecting")
_close_client(client)
_cached_client[0] = None
_close_kombu_conn()
# Get a fresh Redis client from the broker connection.
# We hold this client long-term (cached above) rather than using a
# context manager, because we need it to persist across scrapes.
# The caching logic above ensures we only ever hold one connection,
# and we close it explicitly on reconnect.
conn = celery_app.broker_connection()
# kombu's Channel exposes .client at runtime (the underlying Redis
# client) but the type stubs don't declare it.
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
_cached_client[0] = new_client
_cached_kombu_conn[0] = conn
return new_client
return _get_broker_redis
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
"""Register all indexing pipeline collectors with the default registry.
Args:
celery_app: The Celery application instance. Used to obtain a
celery_app: The Celery application instance. Used to obtain a fresh
broker Redis client on each scrape for queue depth metrics.
"""
_queue_collector.set_celery_app(celery_app)
_redis_health_collector.set_celery_app(celery_app)
redis_factory = _make_broker_redis_factory(celery_app)
_queue_collector.set_redis_factory(redis_factory)
_redis_health_collector.set_redis_factory(redis_factory)
# Start the heartbeat monitor daemon thread — uses a single persistent
# connection to receive worker-heartbeat events.

View File

@@ -129,10 +129,6 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
return_value=mock_app,
),
patch(_PATCH_QUEUE_DEPTH, return_value=0),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
):
yield

View File

@@ -88,22 +88,10 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
Also patches ``celery_get_broker_client`` so the mock app doesn't need
a real broker URL.
"""
task_instance = task.run.__self__
with (
patch.object(
type(task_instance),
"app",
new_callable=PropertyMock,
return_value=mock_app,
),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield

View File

@@ -90,17 +90,8 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
task only.
"""
task_instance = task.run.__self__
with (
patch.object(
type(task_instance),
"app",
new_callable=PropertyMock,
return_value=mock_app,
),
patch(
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
return_value=MagicMock(),
),
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield

View File

@@ -103,11 +103,6 @@ _EXPECTED_CONFLUENCE_GROUPS = [
user_emails={"oauth@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="no yuhong allowed",
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
gives_anyone_access=False,
),
]

View File

@@ -1,58 +0,0 @@
import pytest
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
from onyx.db.federated import _reject_masked_credentials
class TestRejectMaskedCredentials:
"""Verify that masked credential values are never accepted for DB writes.
mask_string() has two output formats:
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
_reject_masked_credentials must catch both.
"""
def test_rejects_fully_masked_value(self) -> None:
masked = MASK_CREDENTIAL_CHAR * 12 # "••••••••••••"
with pytest.raises(ValueError, match="masked placeholder"):
_reject_masked_credentials({"client_id": masked})
def test_rejects_long_string_masked_value(self) -> None:
"""mask_string returns 'first4...last4' for long strings — the real
format used for OAuth credentials like client_id and client_secret."""
with pytest.raises(ValueError, match="masked placeholder"):
_reject_masked_credentials({"client_id": "1234...7890"})
def test_rejects_when_any_field_is_masked(self) -> None:
"""Even if client_id is real, a masked client_secret must be caught."""
with pytest.raises(ValueError, match="client_secret"):
_reject_masked_credentials(
{
"client_id": "1234567890.1234567890",
"client_secret": MASK_CREDENTIAL_CHAR * 12,
}
)
def test_accepts_real_credentials(self) -> None:
# Should not raise
_reject_masked_credentials(
{
"client_id": "1234567890.1234567890",
"client_secret": "test_client_secret_value",
}
)
def test_accepts_empty_dict(self) -> None:
# Should not raise — empty credentials are handled elsewhere
_reject_masked_credentials({})
def test_ignores_non_string_values(self) -> None:
# Non-string values (None, bool, int) should pass through
_reject_masked_credentials(
{
"client_id": "real_value",
"redirect_uri": None,
"some_flag": True,
}
)

View File

@@ -1,87 +0,0 @@
"""Tests for celery_get_broker_client singleton."""
from collections.abc import Iterator
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.background.celery import celery_redis
@pytest.fixture(autouse=True)
def reset_singleton() -> Iterator[None]:
"""Reset the module-level singleton between tests."""
celery_redis._broker_client = None
celery_redis._broker_url = None
yield
celery_redis._broker_client = None
celery_redis._broker_url = None
def _make_mock_app(broker_url: str = "redis://localhost:6379/15") -> MagicMock:
app = MagicMock()
app.conf.broker_url = broker_url
return app
class TestCeleryGetBrokerClient:
@patch("onyx.background.celery.celery_redis.Redis")
def test_creates_client_on_first_call(self, mock_redis_cls: MagicMock) -> None:
mock_client = MagicMock()
mock_redis_cls.from_url.return_value = mock_client
app = _make_mock_app()
result = celery_redis.celery_get_broker_client(app)
assert result is mock_client
call_args = mock_redis_cls.from_url.call_args
assert call_args[0][0] == "redis://localhost:6379/15"
assert call_args[1]["decode_responses"] is False
assert call_args[1]["socket_keepalive"] is True
assert call_args[1]["retry_on_timeout"] is True
@patch("onyx.background.celery.celery_redis.Redis")
def test_reuses_cached_client(self, mock_redis_cls: MagicMock) -> None:
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_cls.from_url.return_value = mock_client
app = _make_mock_app()
client1 = celery_redis.celery_get_broker_client(app)
client2 = celery_redis.celery_get_broker_client(app)
assert client1 is client2
# from_url called only once
assert mock_redis_cls.from_url.call_count == 1
@patch("onyx.background.celery.celery_redis.Redis")
def test_reconnects_on_ping_failure(self, mock_redis_cls: MagicMock) -> None:
stale_client = MagicMock()
stale_client.ping.side_effect = ConnectionError("disconnected")
fresh_client = MagicMock()
fresh_client.ping.return_value = True
mock_redis_cls.from_url.side_effect = [stale_client, fresh_client]
app = _make_mock_app()
# First call creates stale_client
client1 = celery_redis.celery_get_broker_client(app)
assert client1 is stale_client
# Second call: ping fails, creates fresh_client
client2 = celery_redis.celery_get_broker_client(app)
assert client2 is fresh_client
assert mock_redis_cls.from_url.call_count == 2
@patch("onyx.background.celery.celery_redis.Redis")
def test_uses_broker_url_from_app_config(self, mock_redis_cls: MagicMock) -> None:
mock_redis_cls.from_url.return_value = MagicMock()
app = _make_mock_app("redis://custom-host:6380/3")
celery_redis.celery_get_broker_client(app)
call_args = mock_redis_cls.from_url.call_args
assert call_args[0][0] == "redis://custom-host:6380/3"

View File

@@ -1,147 +0,0 @@
from typing import Any
from unittest.mock import MagicMock
import pytest
import requests
from jira import JIRA
from jira.resources import Issue
from onyx.connectors.jira.connector import bulk_fetch_issues
def _make_raw_issue(issue_id: str) -> dict[str, Any]:
return {
"id": issue_id,
"key": f"TEST-{issue_id}",
"fields": {"summary": f"Issue {issue_id}"},
}
def _mock_jira_client() -> MagicMock:
mock = MagicMock(spec=JIRA)
mock._options = {"server": "https://jira.example.com"}
mock._session = MagicMock()
mock._get_url = MagicMock(
return_value="https://jira.example.com/rest/api/3/issue/bulkfetch"
)
return mock
def test_bulk_fetch_success() -> None:
"""Happy path: all issues fetched in one request."""
client = _mock_jira_client()
raw = [_make_raw_issue("1"), _make_raw_issue("2"), _make_raw_issue("3")]
resp = MagicMock()
resp.json.return_value = {"issues": raw}
client._session.post.return_value = resp
result = bulk_fetch_issues(client, ["1", "2", "3"])
assert len(result) == 3
assert all(isinstance(r, Issue) for r in result)
client._session.post.assert_called_once()
def test_bulk_fetch_splits_on_json_error() -> None:
"""When the full batch fails with JSONDecodeError, sub-batches succeed."""
client = _mock_jira_client()
call_count = 0
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
nonlocal call_count
call_count += 1
ids = json["issueIdsOrKeys"]
if len(ids) > 2:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"Expecting ',' delimiter", "doc", 2294125
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
result = bulk_fetch_issues(client, ["1", "2", "3", "4"])
assert len(result) == 4
returned_ids = {r.raw["id"] for r in result}
assert returned_ids == {"1", "2", "3", "4"}
assert call_count > 1
def test_bulk_fetch_raises_on_single_unfetchable_issue() -> None:
"""A single issue that always fails JSON decode raises after splitting."""
client = _mock_jira_client()
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
ids = json["issueIdsOrKeys"]
if "bad" in ids:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"Expecting ',' delimiter", "doc", 100
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
with pytest.raises(requests.exceptions.JSONDecodeError):
bulk_fetch_issues(client, ["1", "bad", "2"])
def test_bulk_fetch_non_json_error_propagates() -> None:
"""Non-JSONDecodeError exceptions still propagate."""
client = _mock_jira_client()
resp = MagicMock()
resp.json.side_effect = ValueError("something else broke")
client._session.post.return_value = resp
try:
bulk_fetch_issues(client, ["1"])
assert False, "Expected ValueError to propagate"
except ValueError:
pass
def test_bulk_fetch_with_fields() -> None:
"""Fields parameter is forwarded correctly."""
client = _mock_jira_client()
raw = [_make_raw_issue("1")]
resp = MagicMock()
resp.json.return_value = {"issues": raw}
client._session.post.return_value = resp
bulk_fetch_issues(client, ["1"], fields="summary,description")
call_payload = client._session.post.call_args[1]["json"]
assert call_payload["fields"] == ["summary", "description"]
def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
"""With a 6-issue batch where one is bad, recursion isolates it and raises."""
client = _mock_jira_client()
bad_id = "BAD"
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
ids = json["issueIdsOrKeys"]
if bad_id in ids:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"truncated", "doc", 999
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
with pytest.raises(requests.exceptions.JSONDecodeError):
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])

View File

@@ -1,225 +0,0 @@
"""Tests for get_chat_sessions_by_user filtering behavior.
Verifies that failed chat sessions (those with only SYSTEM messages) are
correctly filtered out while preserving recently created sessions, matching
the behavior specified in PR #7233.
"""
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from unittest.mock import MagicMock
from uuid import UUID
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.models import ChatSession
def _make_session(
user_id: UUID,
time_created: datetime | None = None,
time_updated: datetime | None = None,
description: str = "",
) -> MagicMock:
"""Create a mock ChatSession with the given attributes."""
session = MagicMock(spec=ChatSession)
session.id = uuid4()
session.user_id = user_id
session.time_created = time_created or datetime.now(timezone.utc)
session.time_updated = time_updated or session.time_created
session.description = description
session.deleted = False
session.onyxbot_flow = False
session.project_id = None
return session
@pytest.fixture
def user_id() -> UUID:
return uuid4()
@pytest.fixture
def old_time() -> datetime:
"""A timestamp well outside the 5-minute leeway window."""
return datetime.now(timezone.utc) - timedelta(hours=1)
@pytest.fixture
def recent_time() -> datetime:
"""A timestamp within the 5-minute leeway window."""
return datetime.now(timezone.utc) - timedelta(minutes=2)
class TestGetChatSessionsByUser:
"""Tests for the failed chat filtering logic in get_chat_sessions_by_user."""
def test_filters_out_failed_sessions(
self, user_id: UUID, old_time: datetime
) -> None:
"""Sessions with only SYSTEM messages should be excluded."""
valid_session = _make_session(user_id, time_created=old_time)
failed_session = _make_session(user_id, time_created=old_time)
db_session = MagicMock(spec=Session)
# First execute: returns all sessions
# Second execute: returns only the valid session's ID (has non-system msgs)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [
valid_session,
failed_session,
]
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = [valid_session.id]
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert len(result) == 1
assert result[0].id == valid_session.id
def test_keeps_recent_sessions_without_messages(
self, user_id: UUID, recent_time: datetime
) -> None:
"""Recently created sessions should be kept even without messages."""
recent_session = _make_session(user_id, time_created=recent_time)
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [recent_session]
db_session.execute.side_effect = [mock_result_1]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert len(result) == 1
assert result[0].id == recent_session.id
# Should only have been called once — no second query needed
# because the recent session is within the leeway window
assert db_session.execute.call_count == 1
def test_include_failed_chats_skips_filtering(
self, user_id: UUID, old_time: datetime
) -> None:
"""When include_failed_chats=True, no filtering should occur."""
session_a = _make_session(user_id, time_created=old_time)
session_b = _make_session(user_id, time_created=old_time)
db_session = MagicMock(spec=Session)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [session_a, session_b]
db_session.execute.side_effect = [mock_result]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=True,
)
assert len(result) == 2
# Only one DB call — no second query for message validation
assert db_session.execute.call_count == 1
def test_limit_applied_after_filtering(
self, user_id: UUID, old_time: datetime
) -> None:
"""Limit should be applied after filtering, not before."""
sessions = [_make_session(user_id, time_created=old_time) for _ in range(5)]
valid_ids = [s.id for s in sessions[:3]]
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = sessions
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = valid_ids
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
limit=2,
)
assert len(result) == 2
# Should be the first 2 valid sessions (order preserved)
assert result[0].id == sessions[0].id
assert result[1].id == sessions[1].id
def test_mixed_recent_and_old_sessions(
self, user_id: UUID, old_time: datetime, recent_time: datetime
) -> None:
"""Mix of recent and old sessions should filter correctly."""
old_valid = _make_session(user_id, time_created=old_time)
old_failed = _make_session(user_id, time_created=old_time)
recent_no_msgs = _make_session(user_id, time_created=recent_time)
db_session = MagicMock(spec=Session)
mock_result_1 = MagicMock()
mock_result_1.scalars.return_value.all.return_value = [
old_valid,
old_failed,
recent_no_msgs,
]
mock_result_2 = MagicMock()
mock_result_2.scalars.return_value.all.return_value = [old_valid.id]
db_session.execute.side_effect = [mock_result_1, mock_result_2]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
result_ids = {cs.id for cs in result}
assert old_valid.id in result_ids
assert recent_no_msgs.id in result_ids
assert old_failed.id not in result_ids
def test_empty_result(self, user_id: UUID) -> None:
"""No sessions should return empty list without errors."""
db_session = MagicMock(spec=Session)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
db_session.execute.side_effect = [mock_result]
result = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
include_failed_chats=False,
)
assert result == []
assert db_session.execute.call_count == 1

View File

@@ -272,13 +272,13 @@ class TestUpsertVoiceProvider:
class TestDeleteVoiceProvider:
"""Tests for delete_voice_provider."""
def test_hard_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = provider
delete_voice_provider(mock_db_session, 1)
mock_db_session.delete.assert_called_once_with(provider)
assert provider.deleted is True
mock_db_session.flush.assert_called_once()
def test_does_nothing_when_provider_not_found(

View File

@@ -1,76 +0,0 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer <1083d595b1>
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 1
/Kids [ 4 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 5 0 R
/Parent 2 0 R
>>
endobj
5 0 obj
<<
/Length 42
>>
stream
,N<><6~<7E>)<29><><EFBFBD><EFBFBD><EFBFBD>u<EFBFBD> <0C><><EFBFBD>Zc'<27><>>8g<38><67><EFBFBD>n<EFBFBD><6E><EFBFBD><EFBFBD><EFBFBD>9"
endstream
endobj
6 0 obj
<<
/V 2
/R 3
/Length 128
/P 4294967292
/Filter /Standard
/O <6a340a292629053da84a6d8b19a5d505953b8b3fdac3d2d389fde0e354528d44>
/U <d6f0dc91c7b9de264a8d708515468e6528bf4e5e4e758a4164004e56fffa0108>
>>
endobj
xref
0 7
0000000000 65535 f
0000000015 00000 n
0000000059 00000 n
0000000118 00000 n
0000000167 00000 n
0000000348 00000 n
0000000440 00000 n
trailer
<<
/Size 7
/Root 3 0 R
/Info 1 0 R
/ID [ <6364336635356135633239323638353039306635656133623165313637366430> <6364336635356135633239323638353039306635656133623165313637366430> ]
/Encrypt 6 0 R
>>
startxref
655
%%EOF

View File

@@ -54,12 +54,6 @@ class TestReadPdfFile:
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="wrong")
assert text == ""
def test_owner_password_only_pdf_extracts_text(self) -> None:
"""A PDF encrypted with only an owner password (no user password)
should still yield its text content. Regression for #9754."""
text, _, _ = read_pdf_file(_load("owner_protected.pdf"))
assert "Hello World" in text
def test_empty_pdf(self) -> None:
text, _, _ = read_pdf_file(_load("empty.pdf"))
assert text.strip() == ""
@@ -123,12 +117,6 @@ class TestIsPdfProtected:
def test_protected_pdf(self) -> None:
assert is_pdf_protected(_load("encrypted.pdf")) is True
def test_owner_password_only_is_not_protected(self) -> None:
"""A PDF with only an owner password (permission restrictions) but no
user password should NOT be considered protected — any viewer can open
it without prompting for a password."""
assert is_pdf_protected(_load("owner_protected.pdf")) is False
def test_preserves_file_position(self) -> None:
pdf = _load("simple.pdf")
pdf.seek(42)

View File

@@ -1,79 +0,0 @@
import io
from pptx import Presentation # type: ignore[import-untyped]
from pptx.chart.data import CategoryChartData # type: ignore[import-untyped]
from pptx.enum.chart import XL_CHART_TYPE # type: ignore[import-untyped]
from pptx.util import Inches # type: ignore[import-untyped]
from onyx.file_processing.extract_file_text import pptx_to_text
def _make_pptx_with_chart() -> io.BytesIO:
"""Create an in-memory pptx with one text slide and one chart slide."""
prs = Presentation()
# Slide 1: text only
slide1 = prs.slides.add_slide(prs.slide_layouts[1])
slide1.shapes.title.text = "Introduction"
slide1.placeholders[1].text = "This is the first slide."
# Slide 2: chart
slide2 = prs.slides.add_slide(prs.slide_layouts[5]) # Blank layout
chart_data = CategoryChartData()
chart_data.categories = ["Q1", "Q2", "Q3"]
chart_data.add_series("Revenue", (100, 200, 300))
slide2.shapes.add_chart(
XL_CHART_TYPE.COLUMN_CLUSTERED,
Inches(1),
Inches(1),
Inches(6),
Inches(4),
chart_data,
)
buf = io.BytesIO()
prs.save(buf)
buf.seek(0)
return buf
def _make_pptx_without_chart() -> io.BytesIO:
"""Create an in-memory pptx with a single text-only slide."""
prs = Presentation()
slide = prs.slides.add_slide(prs.slide_layouts[1])
slide.shapes.title.text = "Hello World"
slide.placeholders[1].text = "Some content here."
buf = io.BytesIO()
prs.save(buf)
buf.seek(0)
return buf
class TestPptxToText:
def test_chart_is_omitted(self) -> None:
# Precondition
pptx_file = _make_pptx_with_chart()
# Under test
result = pptx_to_text(pptx_file)
# Postcondition
assert "Introduction" in result
assert "first slide" in result
assert "[chart omitted]" in result
# The actual chart data should NOT appear in the output.
assert "Revenue" not in result
assert "Q1" not in result
def test_text_only_pptx(self) -> None:
# Precondition
pptx_file = _make_pptx_without_chart()
# Under test
result = pptx_to_text(pptx_file)
# Postcondition
assert "Hello World" in result
assert "Some content" in result
assert "[chart omitted]" not in result

View File

@@ -11,7 +11,6 @@ from litellm.types.utils import ChatCompletionDeltaToolCall
from litellm.types.utils import Delta
from litellm.types.utils import Function as LiteLLMFunction
import onyx.llm.models
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLMUserIdentity
@@ -1480,147 +1479,6 @@ def test_bifrost_normalizes_api_base_in_model_kwargs() -> None:
assert llm._model_kwargs["api_base"] == "https://bifrost.example.com/v1"
def test_prompt_contains_tool_call_history_true() -> None:
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
messages: LanguageModelInput = [
UserMessage(content="What's the weather?"),
AssistantMessage(
content=None,
tool_calls=[
ToolCall(
id="tc_1",
function=FunctionCall(name="get_weather", arguments="{}"),
)
],
),
]
assert _prompt_contains_tool_call_history(messages) is True
def test_prompt_contains_tool_call_history_false_no_tools() -> None:
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
messages: LanguageModelInput = [
UserMessage(content="Hello"),
AssistantMessage(content="Hi there!"),
]
assert _prompt_contains_tool_call_history(messages) is False
def test_prompt_contains_tool_call_history_false_user_only() -> None:
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
messages: LanguageModelInput = [UserMessage(content="Hello")]
assert _prompt_contains_tool_call_history(messages) is False
def test_bedrock_claude_drops_thinking_when_thinking_blocks_missing() -> None:
"""When thinking is enabled but assistant messages with tool_calls lack
thinking_blocks, the thinking param must be dropped to avoid the Bedrock
BadRequestError about missing thinking blocks."""
llm = LitellmLLM(
api_key=None,
timeout=30,
model_provider=LlmProviderNames.BEDROCK,
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
max_input_tokens=200000,
)
messages: LanguageModelInput = [
UserMessage(content="What's the weather?"),
AssistantMessage(
content=None,
tool_calls=[
ToolCall(
id="tc_1",
function=FunctionCall(
name="get_weather",
arguments='{"city": "Paris"}',
),
)
],
),
onyx.llm.models.ToolMessage(
content="22°C sunny",
tool_call_id="tc_1",
),
]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
},
},
}
]
with (
patch("litellm.completion") as mock_completion,
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
):
mock_completion.return_value = []
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
kwargs = mock_completion.call_args.kwargs
assert "thinking" not in kwargs, (
"thinking param should be dropped when thinking_blocks are missing "
"from assistant messages with tool_calls"
)
def test_bedrock_claude_keeps_thinking_when_no_tool_history() -> None:
"""When thinking is enabled and there are no historical assistant messages
with tool_calls, the thinking param should be preserved."""
llm = LitellmLLM(
api_key=None,
timeout=30,
model_provider=LlmProviderNames.BEDROCK,
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
max_input_tokens=200000,
)
messages: LanguageModelInput = [
UserMessage(content="What's the weather?"),
]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
},
},
}
]
with (
patch("litellm.completion") as mock_completion,
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
):
mock_completion.return_value = []
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
kwargs = mock_completion.call_args.kwargs
assert "thinking" in kwargs, (
"thinking param should be preserved when no assistant messages "
"with tool_calls exist in history"
)
assert kwargs["thinking"]["type"] == "enabled"
def test_bifrost_claude_includes_allowed_openai_params() -> None:
llm = LitellmLLM(
api_key="test_key",

View File

@@ -3,7 +3,7 @@
Covers:
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
- _validate_endpoint: httpx exception → HookValidateStatus mapping
ConnectTimeout → timeout (any timeout directs user to increase timeout_seconds)
ConnectTimeout → cannot_connect (TCP handshake never completed)
ConnectError → cannot_connect (DNS / TLS failure)
ReadTimeout et al. → timeout (TCP connected, server slow)
Any other exc → cannot_connect
@@ -61,7 +61,7 @@ class TestCheckSsrfSafety:
def test_non_https_scheme_rejected(self, url: str) -> None:
with pytest.raises(OnyxError) as exc_info:
self._call(url)
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert "https" in (exc_info.value.detail or "").lower()
# --- private IP blocklist ---
@@ -87,7 +87,7 @@ class TestCheckSsrfSafety:
):
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
self._call("https://internal.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert ip in (exc_info.value.detail or "")
def test_public_ip_is_allowed(self) -> None:
@@ -106,7 +106,7 @@ class TestCheckSsrfSafety:
pytest.raises(OnyxError) as exc_info,
):
self._call("https://no-such-host.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
# ---------------------------------------------------------------------------
@@ -158,11 +158,13 @@ class TestValidateEndpoint:
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
def test_connect_timeout_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectTimeout("timed out")
)
assert self._call().status == HookValidateStatus.timeout
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize(

View File

@@ -1,6 +1,5 @@
"""Tests for indexing pipeline Prometheus collectors."""
from collections.abc import Iterator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -14,16 +13,6 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
@pytest.fixture(autouse=True)
def _mock_broker_client() -> Iterator[None]:
"""Patch celery_get_broker_client for all collector tests."""
with patch(
"onyx.background.celery.celery_redis.celery_get_broker_client",
return_value=MagicMock(),
):
yield
class TestQueueDepthCollector:
def test_returns_empty_when_factory_not_set(self) -> None:
collector = QueueDepthCollector()
@@ -35,7 +24,8 @@ class TestQueueDepthCollector:
def test_collects_queue_depths(self) -> None:
collector = QueueDepthCollector(cache_ttl=0)
collector.set_celery_app(MagicMock())
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
with (
patch(
@@ -70,8 +60,8 @@ class TestQueueDepthCollector:
def test_handles_redis_error_gracefully(self) -> None:
collector = QueueDepthCollector(cache_ttl=0)
MagicMock()
collector.set_celery_app(MagicMock())
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
with patch(
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
@@ -84,8 +74,8 @@ class TestQueueDepthCollector:
def test_caching_returns_stale_within_ttl(self) -> None:
collector = QueueDepthCollector(cache_ttl=60)
MagicMock()
collector.set_celery_app(MagicMock())
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
with (
patch(
@@ -108,10 +98,31 @@ class TestQueueDepthCollector:
assert first is second # Same object, from cache
def test_factory_called_each_scrape(self) -> None:
"""Verify the Redis factory is called on each fresh collect, not cached."""
collector = QueueDepthCollector(cache_ttl=0)
factory = MagicMock(return_value=MagicMock())
collector.set_redis_factory(factory)
with (
patch(
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
return_value=0,
),
patch(
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
return_value=set(),
),
):
collector.collect()
collector.collect()
assert factory.call_count == 2
def test_error_returns_stale_cache(self) -> None:
collector = QueueDepthCollector(cache_ttl=0)
MagicMock()
collector.set_celery_app(MagicMock())
mock_redis = MagicMock()
collector.set_redis_factory(lambda: mock_redis)
# First call succeeds
with (

View File

@@ -1,22 +1,96 @@
"""Tests for indexing pipeline setup."""
"""Tests for indexing pipeline setup (Redis factory caching)."""
from unittest.mock import MagicMock
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
class TestCollectorCeleryAppSetup:
def test_queue_depth_collector_uses_celery_app(self) -> None:
"""QueueDepthCollector.set_celery_app stores the app for broker access."""
collector = QueueDepthCollector()
mock_app = MagicMock()
collector.set_celery_app(mock_app)
assert collector._celery_app is mock_app
def _make_mock_app(client: MagicMock) -> MagicMock:
"""Create a mock Celery app whose broker_connection().channel().client
returns the given client."""
mock_app = MagicMock()
mock_conn = MagicMock()
mock_conn.channel.return_value.client = client
def test_redis_health_collector_uses_celery_app(self) -> None:
"""RedisHealthCollector.set_celery_app stores the app for broker access."""
collector = RedisHealthCollector()
mock_app = MagicMock()
collector.set_celery_app(mock_app)
assert collector._celery_app is mock_app
mock_app.broker_connection.return_value = mock_conn
return mock_app
class TestMakeBrokerRedisFactory:
def test_caches_redis_client_across_calls(self) -> None:
"""Factory should reuse the same client on subsequent calls."""
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_app = _make_mock_app(mock_client)
factory = _make_broker_redis_factory(mock_app)
client1 = factory()
client2 = factory()
assert client1 is client2
# broker_connection should only be called once
assert mock_app.broker_connection.call_count == 1
def test_reconnects_when_ping_fails(self) -> None:
"""Factory should create a new client if ping fails (stale connection)."""
mock_client_stale = MagicMock()
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
mock_client_fresh = MagicMock()
mock_client_fresh.ping.return_value = True
mock_app = _make_mock_app(mock_client_stale)
factory = _make_broker_redis_factory(mock_app)
# First call — creates and caches
client1 = factory()
assert client1 is mock_client_stale
assert mock_app.broker_connection.call_count == 1
# Switch to fresh client for next connection
mock_conn_fresh = MagicMock()
mock_conn_fresh.channel.return_value.client = mock_client_fresh
mock_app.broker_connection.return_value = mock_conn_fresh
# Second call — ping fails on stale, reconnects
client2 = factory()
assert client2 is mock_client_fresh
assert mock_app.broker_connection.call_count == 2
def test_reconnect_closes_stale_client(self) -> None:
"""When ping fails, the old client should be closed before reconnecting."""
mock_client_stale = MagicMock()
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
mock_client_fresh = MagicMock()
mock_client_fresh.ping.return_value = True
mock_app = _make_mock_app(mock_client_stale)
factory = _make_broker_redis_factory(mock_app)
# First call — creates and caches
factory()
# Switch to fresh client
mock_conn_fresh = MagicMock()
mock_conn_fresh.channel.return_value.client = mock_client_fresh
mock_app.broker_connection.return_value = mock_conn_fresh
# Second call — ping fails, should close stale client
factory()
mock_client_stale.close.assert_called_once()
def test_first_call_creates_connection(self) -> None:
"""First call should always create a new connection."""
mock_client = MagicMock()
mock_app = _make_mock_app(mock_client)
factory = _make_broker_redis_factory(mock_app)
client = factory()
assert client is mock_client
mock_app.broker_connection.assert_called_once()

View File

@@ -70,10 +70,6 @@ backend = [
"lazy_imports==1.0.1",
"lxml==5.3.0",
"Mako==1.2.4",
# NOTE: Do not update without understanding the patching behavior in
# get_markitdown_converter in
# backend/onyx/file_processing/extract_file_text.py and what impacts
# updating might have on this behavior.
"markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2",
"mcp[cli]==1.26.0",
"msal==1.34.0",

View File

@@ -1,22 +0,0 @@
import { cn } from "@opal/utils";
import type { IconProps } from "@opal/types";
const SvgBifrost = ({ size, className, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 37 46"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className={cn(className, "text-[#33C19E] dark:text-white")}
{...props}
>
<title>Bifrost</title>
<path
d="M27.6219 46H0V36.8H27.6219V46ZM36.8268 36.8H27.6219V27.6H36.8268V36.8ZM18.4146 27.6H9.2073V18.4H18.4146V27.6ZM36.8268 18.4H27.6219V9.2H36.8268V18.4ZM27.6219 9.2H0V0H27.6219V9.2Z"
fill="currentColor"
/>
</svg>
);
export default SvgBifrost;

View File

@@ -24,7 +24,6 @@ export { default as SvgAzure } from "@opal/icons/azure";
export { default as SvgBarChart } from "@opal/icons/bar-chart";
export { default as SvgBarChartSmall } from "@opal/icons/bar-chart-small";
export { default as SvgBell } from "@opal/icons/bell";
export { default as SvgBifrost } from "@opal/icons/bifrost";
export { default as SvgBlocks } from "@opal/icons/blocks";
export { default as SvgBookOpen } from "@opal/icons/book-open";
export { default as SvgBookmark } from "@opal/icons/bookmark";

View File

@@ -31,6 +31,7 @@ import { Credential } from "@/lib/connectors/credentials";
import { SettingsContext } from "@/providers/SettingsProvider";
import SourceTile from "@/components/SourceTile";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { ADMIN_ROUTES } from "@/lib/admin-routes";

View File

@@ -4,7 +4,7 @@ import { createApiKey, updateApiKey } from "./lib";
import Modal from "@/refresh-components/Modal";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { FormikField } from "@/refresh-components/form/FormikField";
@@ -79,7 +79,7 @@ export default function OnyxApiKeyForm({
{({ isSubmitting }) => (
<Form className="w-full overflow-visible">
<Modal.Body>
<Text as="p">
<Text as="p" color="text-05">
Choose a memorable name for your API key. This is optional and
can be added or changed later!
</Text>

View File

@@ -29,12 +29,10 @@ import {
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { Button } from "@opal/components";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
import Message from "@/refresh-components/messages/Message";
import { SvgEdit, SvgInfo, SvgKey, SvgRefreshCw } from "@opal/icons";
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
import { useBillingInformation } from "@/hooks/useBillingInformation";
import { BillingStatus, hasActiveSubscription } from "@/lib/billing/interfaces";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
const route = ADMIN_ROUTES.API_KEYS;
@@ -47,11 +45,6 @@ function Main() {
} = useSWR<APIKey[]>("/api/admin/api-key", errorHandlingFetcher);
const canCreateKeys = useCloudSubscription();
const { data: billingData } = useBillingInformation();
const isTrialing =
billingData !== undefined &&
hasActiveSubscription(billingData) &&
billingData.status === BillingStatus.TRIALING;
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
const [keyIsGenerating, setKeyIsGenerating] = useState(false);
@@ -83,16 +76,6 @@ function Main() {
const introSection = (
<div className="flex flex-col items-start gap-4">
{isTrialing && (
<Message
static
warning
close={false}
className="w-full"
text="Upgrade to a paid plan to create API keys."
description="Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
/>
)}
<Text as="p">
API Keys allow you to access Onyx APIs programmatically.
{canCreateKeys
@@ -103,9 +86,23 @@ function Main() {
<CreateButton onClick={() => setShowCreateUpdateForm(true)}>
Create API Key
</CreateButton>
) : isTrialing ? (
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
) : null}
) : (
<div className="flex flex-col gap-2 rounded-lg bg-background-tint-02 p-4">
<div className="flex items-center gap-1.5">
<Text as="p" text04>
Upgrade to a paid plan to create API keys.
</Text>
<Button
variant="none"
prominence="tertiary"
size="2xs"
icon={SvgInfo}
tooltip="API keys enable programmatic access to Onyx for service accounts and integrations. Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
/>
</div>
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
</div>
)}
</div>
);

View File

@@ -9,6 +9,7 @@ import Card from "@/refresh-components/cards/Card";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
import { Disabled } from "@opal/core";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Message from "@/refresh-components/messages/Message";
import InfoBlock from "@/refresh-components/messages/InfoBlock";

View File

@@ -5,6 +5,7 @@ import { Section } from "@/layouts/general-layouts";
import * as InputLayouts from "@/layouts/input-layouts";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Card from "@/refresh-components/cards/Card";
import Separator from "@/refresh-components/Separator";

View File

@@ -4,6 +4,7 @@ import { useState } from "react";
import Card from "@/refresh-components/cards/Card";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import InputFile from "@/refresh-components/inputs/InputFile";
import { Section } from "@/layouts/general-layouts";

View File

@@ -22,6 +22,7 @@ import type { IconProps } from "@opal/types";
import Card from "@/refresh-components/cards/Card";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { Section } from "@/layouts/general-layouts";

View File

@@ -1,387 +0,0 @@
/**
* Tests for BillingPage handleBillingReturn retry logic.
*
* The retry logic retries claimLicense up to 3 times with 2s backoff
* when returning from a Stripe checkout session. This prevents the user
* from getting stranded when the Stripe webhook fires concurrently with
* the browser redirect and the license isn't ready yet.
*/
import React from "react";
import { render, screen, waitFor } from "@tests/setup/test-utils";
import { act } from "@testing-library/react";
// ---- Stable mock objects (must be named with mock* prefix for jest hoisting) ----
// useRouter and useSearchParams must return the SAME reference each call, otherwise
// React's useEffect sees them as changed and re-runs the effect on every render.
const mockRouter = {
replace: jest.fn() as jest.Mock,
refresh: jest.fn() as jest.Mock,
};
const mockSearchParams = {
get: jest.fn() as jest.Mock,
};
const mockClaimLicense = jest.fn() as jest.Mock;
const mockRefreshBilling = jest.fn() as jest.Mock;
const mockRefreshLicense = jest.fn() as jest.Mock;
// ---- Mocks ----
jest.mock("next/navigation", () => ({
useRouter: () => mockRouter,
useSearchParams: () => mockSearchParams,
}));
jest.mock("@/layouts/settings-layouts", () => ({
Root: ({ children }: { children: React.ReactNode }) => (
<div data-testid="settings-root">{children}</div>
),
Header: () => <div data-testid="settings-header" />,
Body: ({ children }: { children: React.ReactNode }) => (
<div data-testid="settings-body">{children}</div>
),
}));
jest.mock("@/layouts/general-layouts", () => ({
Section: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
jest.mock("@opal/icons", () => ({
SvgArrowUpCircle: () => <svg />,
SvgWallet: () => <svg />,
}));
jest.mock("./PlansView", () => ({
__esModule: true,
default: () => <div data-testid="plans-view" />,
}));
jest.mock("./CheckoutView", () => ({
__esModule: true,
default: () => <div data-testid="checkout-view" />,
}));
jest.mock("./BillingDetailsView", () => ({
__esModule: true,
default: () => <div data-testid="billing-details-view" />,
}));
jest.mock("./LicenseActivationCard", () => ({
__esModule: true,
default: () => <div data-testid="license-activation-card" />,
}));
jest.mock("@/refresh-components/messages/Message", () => ({
__esModule: true,
default: ({
text,
description,
onClose,
}: {
text: string;
description?: string;
onClose?: () => void;
}) => (
<div data-testid="activating-banner">
<span data-testid="activating-banner-text">{text}</span>
{description && (
<span data-testid="activating-banner-description">{description}</span>
)}
{onClose && (
<button data-testid="activating-banner-close" onClick={onClose}>
Close
</button>
)}
</div>
),
}));
jest.mock("@/lib/billing", () => ({
useBillingInformation: jest.fn(),
useLicense: jest.fn(),
hasActiveSubscription: jest.fn().mockReturnValue(false),
claimLicense: (...args: unknown[]) => mockClaimLicense(...args),
}));
jest.mock("@/lib/constants", () => ({
NEXT_PUBLIC_CLOUD_ENABLED: false,
}));
// ---- Import after mocks ----
import BillingPage from "./page";
import { useBillingInformation, useLicense } from "@/lib/billing";
// ---- Test helpers ----
function setupHooks() {
(useBillingInformation as jest.Mock).mockReturnValue({
data: null,
isLoading: false,
error: null,
refresh: mockRefreshBilling,
});
(useLicense as jest.Mock).mockReturnValue({
data: null,
isLoading: false,
refresh: mockRefreshLicense,
});
}
// ---- Tests ----
describe("BillingPage — handleBillingReturn retry logic", () => {
beforeEach(() => {
jest.clearAllMocks();
jest.useFakeTimers();
setupHooks();
// Default: no billing-return params
mockSearchParams.get.mockReturnValue(null);
// Clear any activating state from prior tests
sessionStorage.clear();
});
afterEach(() => {
jest.useRealTimers();
jest.restoreAllMocks();
});
test("calls claimLicense once and refreshes on first-attempt success", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_test_123" : null
);
mockClaimLicense.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
expect(mockClaimLicense).toHaveBeenCalledWith("cs_test_123");
});
expect(mockRouter.refresh).toHaveBeenCalled();
expect(mockRefreshBilling).toHaveBeenCalled();
// URL cleaned up after checkout return
expect(mockRouter.replace).toHaveBeenCalledWith("/admin/billing", {
scroll: false,
});
});
test("retries after first failure and succeeds on second attempt", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_retry_test" : null
);
mockClaimLicense
.mockRejectedValueOnce(new Error("License not ready yet"))
.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(2);
});
// On eventual success, router and billing should be refreshed
expect(mockRouter.refresh).toHaveBeenCalled();
expect(mockRefreshBilling).toHaveBeenCalled();
});
test("retries all 3 times then navigates to details even on total failure", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_all_fail" : null
);
// All 3 attempts fail
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(3);
});
// User stays on plans view with the activating banner
await waitFor(() => {
expect(screen.getByTestId("plans-view")).toBeInTheDocument();
});
// refreshBilling still fires so billing state is up to date
expect(mockRefreshBilling).toHaveBeenCalled();
// Failure is logged
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining("Failed to sync license after billing return"),
expect.any(Error)
);
consoleSpy.mockRestore();
});
test("calls claimLicense without session_id on portal_return", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "portal_return" ? "true" : null
);
mockClaimLicense.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
// No session_id for portal returns — called with undefined
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
});
expect(mockRefreshBilling).toHaveBeenCalled();
});
test("does not call claimLicense when no billing-return params present", async () => {
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
expect(mockClaimLicense).not.toHaveBeenCalled();
});
test("shows activating banner and sets sessionStorage on 3x retry failure", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_all_fail" : null
);
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
});
expect(screen.getByTestId("activating-banner-text")).toHaveTextContent(
"Your license is still activating"
);
expect(
sessionStorage.getItem("billing_license_activating_until")
).not.toBeNull();
consoleSpy.mockRestore();
});
test("banner not rendered when no activating state", async () => {
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
});
test("banner shown on mount when sessionStorage key is set and not expired", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() + 120_000)
);
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
// Flush React effects — banner is visible from lazy state init, no timer advancement needed
await act(async () => {});
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
});
test("banner not shown on mount when sessionStorage key is expired", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() - 1000)
);
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
expect(
sessionStorage.getItem("billing_license_activating_until")
).toBeNull();
});
test("poll calls claimLicense after 15s and clears banner on success", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() + 120_000)
);
mockSearchParams.get.mockReturnValue(null);
// Poll attempt succeeds
mockClaimLicense.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
// Flush effects — banner visible from lazy state init
await act(async () => {});
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
// Advance past one poll interval (15s)
await act(async () => {
await jest.advanceTimersByTimeAsync(15_000);
});
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
expect(
sessionStorage.getItem("billing_license_activating_until")
).toBeNull();
expect(mockRefreshBilling).toHaveBeenCalled();
expect(mockRefreshLicense).toHaveBeenCalled();
expect(mockRouter.refresh).toHaveBeenCalled();
});
test("close button removes banner and clears sessionStorage", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() + 120_000)
);
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
// Flush effects — banner visible from lazy state init
await act(async () => {});
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
const closeButton = screen.getByTestId("activating-banner-close");
await act(async () => {
closeButton.click();
});
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
expect(
sessionStorage.getItem("billing_license_activating_until")
).toBeNull();
});
});

View File

@@ -5,6 +5,7 @@ import { useSearchParams, useRouter } from "next/navigation";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { Section } from "@/layouts/general-layouts";
import Button from "@/refresh-components/buttons/Button";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { SvgArrowUpCircle, SvgWallet } from "@opal/icons";
import type { IconProps } from "@opal/types";
@@ -17,7 +18,6 @@ import {
} from "@/lib/billing";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { useUser } from "@/providers/UserProvider";
import Message from "@/refresh-components/messages/Message";
import PlansView from "./PlansView";
import CheckoutView from "./CheckoutView";
@@ -25,9 +25,6 @@ import BillingDetailsView from "./BillingDetailsView";
import LicenseActivationCard from "./LicenseActivationCard";
import "./billing.css";
// sessionStorage key: value is a unix-ms expiry timestamp
const BILLING_ACTIVATING_KEY = "billing_license_activating_until";
// ----------------------------------------------------------------------------
// Types
// ----------------------------------------------------------------------------
@@ -109,7 +106,6 @@ export default function BillingPage() {
const [transitionType, setTransitionType] = useState<
"expand" | "collapse" | "fade"
>("fade");
const [isActivating, setIsActivating] = useState<boolean>(false);
const {
data: billingData,
@@ -160,17 +156,6 @@ export default function BillingPage() {
view,
]);
// Read activating state from sessionStorage after mount (avoids SSR hydration mismatch)
useEffect(() => {
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
if (!raw) return;
if (Number(raw) > Date.now()) {
setIsActivating(true);
} else {
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
}
}, []);
// Show license activation card when there's a Stripe error
useEffect(() => {
if (hasStripeError && !showLicenseActivationInput) {
@@ -188,96 +173,24 @@ export default function BillingPage() {
router.replace("/admin/billing", { scroll: false });
let cancelled = false;
const handleBillingReturn = async () => {
if (!NEXT_PUBLIC_CLOUD_ENABLED) {
// Retry up to 3 times with 2s backoff. The license may not be available
// immediately if the Stripe webhook hasn't finished processing yet
// (redirect and webhook fire nearly simultaneously).
let lastError: Error | null = null;
for (let attempt = 0; attempt < 3; attempt++) {
if (cancelled) return;
try {
// After checkout, exchange session_id for license; after portal, re-sync license
await claimLicense(sessionId ?? undefined);
if (cancelled) return;
refreshLicense();
// Refresh the page to update settings (including ee_features_enabled)
router.refresh();
// Navigate to billing details now that the license is active
changeView("details");
lastError = null;
break;
} catch (err) {
lastError = err instanceof Error ? err : new Error("Unknown error");
if (attempt < 2) {
await new Promise((resolve) => setTimeout(resolve, 2000));
}
}
}
if (cancelled) return;
if (lastError) {
console.error(
"Failed to sync license after billing return:",
lastError
);
// Show an activating banner on the plans view and keep retrying in the background.
sessionStorage.setItem(
BILLING_ACTIVATING_KEY,
String(Date.now() + 120_000)
);
setIsActivating(true);
changeView("plans");
try {
// After checkout, exchange session_id for license; after portal, re-sync license
await claimLicense(sessionId ?? undefined);
refreshLicense();
// Refresh the page to update settings (including ee_features_enabled)
router.refresh();
// Navigate to billing details now that the license is active
changeView("details");
} catch (error) {
console.error("Failed to sync license after billing return:", error);
}
}
if (!cancelled) refreshBilling();
refreshBilling();
};
handleBillingReturn();
return () => {
cancelled = true;
};
// changeView intentionally omitted: it only calls stable state setters and the
// effect runs at most once (when session_id/portal_return params are present).
}, [searchParams, router, refreshBilling, refreshLicense]); // eslint-disable-line react-hooks/exhaustive-deps
// Poll every 15s while activating, up to 2 minutes, to detect when the license arrives.
useEffect(() => {
if (!isActivating) return;
let requestInFlight = false;
const intervalId = setInterval(async () => {
if (requestInFlight) return;
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
if (!raw || Number(raw) <= Date.now()) {
// Expired — stop immediately without waiting for React cleanup
clearInterval(intervalId);
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
setIsActivating(false);
return;
}
requestInFlight = true;
try {
await claimLicense(undefined);
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
setIsActivating(false);
refreshLicense();
refreshBilling();
router.refresh();
changeView("details");
} catch (err) {
// License not ready yet — keep polling. Log so unexpected failures
// (network errors, 500s) are distinguishable from expected 404s.
console.debug("License activation poll: will retry", err);
} finally {
requestInFlight = false;
}
}, 15_000);
return () => clearInterval(intervalId);
}, [isActivating]); // eslint-disable-line react-hooks/exhaustive-deps
}, [searchParams, router, refreshBilling, refreshLicense]);
const handleRefresh = async () => {
await Promise.all([
@@ -474,22 +387,6 @@ export default function BillingPage() {
/>
<SettingsLayouts.Body>
<div className="flex flex-col items-center gap-6">
{isActivating && (
<Message
static
warning
large
text="Your license is still activating"
description="Your license is being processed. You'll be taken to billing details automatically once confirmed."
icon
close
onClose={() => {
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
setIsActivating(false);
}}
className="w-full"
/>
)}
{renderContent()}
{renderFooter()}
</div>

View File

@@ -7,6 +7,7 @@ import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import useSWR from "swr";
import { ThreeDotsLoader } from "@/components/Loading";
import * as SettingsLayouts from "@/layouts/settings-layouts";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { cn } from "@/lib/utils";
import { SvgLock } from "@opal/icons";

View File

@@ -1,11 +1,11 @@
"use client";
import { useState, useMemo, useEffect } from "react";
import { useState, useMemo } from "react";
import useSWR from "swr";
import { Text } from "@opal/components";
import { Select } from "@/refresh-components/cards";
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
import { toast } from "@/hooks/useToast";
import { Section } from "@/layouts/general-layouts";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
import {
@@ -17,17 +17,9 @@ import {
ImageGenerationConfigView,
setDefaultImageGenerationConfig,
unsetDefaultImageGenerationConfig,
deleteImageGenerationConfig,
} from "@/lib/configuration/imageConfigurationService";
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
import Message from "@/refresh-components/messages/Message";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { Button, Text } from "@opal/components";
import { SvgSlash, SvgUnplug } from "@opal/icons";
import { markdown } from "@opal/utils";
const NO_DEFAULT_VALUE = "__none__";
export default function ImageGenerationContent() {
const {
@@ -55,11 +47,6 @@ export default function ImageGenerationContent() {
);
const [editConfig, setEditConfig] =
useState<ImageGenerationConfigView | null>(null);
const [disconnectProvider, setDisconnectProvider] =
useState<ImageProvider | null>(null);
const [replacementProviderId, setReplacementProviderId] = useState<
string | null
>(null);
const connectedProviderIds = useMemo(() => {
return new Set(configs.map((c) => c.image_provider_id));
@@ -128,29 +115,6 @@ export default function ImageGenerationContent() {
modal.toggle(true);
};
const handleDisconnect = async () => {
if (!disconnectProvider) return;
try {
// If a replacement was selected (not "No Default"), activate it first
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
await setDefaultImageGenerationConfig(replacementProviderId);
}
await deleteImageGenerationConfig(disconnectProvider.image_provider_id);
toast.success(`${disconnectProvider.title} disconnected`);
refetchConfigs();
refetchProviders();
} catch (error) {
console.error("Failed to disconnect image generation provider:", error);
toast.error(
error instanceof Error ? error.message : "Failed to disconnect"
);
} finally {
setDisconnectProvider(null);
setReplacementProviderId(null);
}
};
const handleModalSuccess = () => {
toast.success("Provider configured successfully");
setEditConfig(null);
@@ -166,44 +130,12 @@ export default function ImageGenerationContent() {
);
}
// Compute replacement options when disconnecting an active provider
const isDisconnectingDefault =
disconnectProvider &&
defaultConfig?.image_provider_id === disconnectProvider.image_provider_id;
// Group connected replacement models by provider (excluding the model being disconnected)
const replacementGroups = useMemo(() => {
if (!disconnectProvider) return [];
return IMAGE_PROVIDER_GROUPS.map((group) => ({
...group,
providers: group.providers.filter(
(p) =>
p.image_provider_id !== disconnectProvider.image_provider_id &&
connectedProviderIds.has(p.image_provider_id)
),
})).filter((g) => g.providers.length > 0);
}, [disconnectProvider, connectedProviderIds]);
const needsReplacement = !!isDisconnectingDefault;
const hasReplacements = replacementGroups.length > 0;
// Auto-select first replacement when modal opens
useEffect(() => {
if (needsReplacement && !replacementProviderId && hasReplacements) {
const firstGroup = replacementGroups[0];
const firstModel = firstGroup?.providers[0];
if (firstModel) setReplacementProviderId(firstModel.image_provider_id);
}
}, [disconnectProvider]); // eslint-disable-line react-hooks/exhaustive-deps
return (
<>
<div className="flex flex-col gap-6">
{/* Section Header */}
<div className="flex flex-col gap-0.5">
<Text font="main-content-emphasis" color="text-05">
Image Generation Model
</Text>
<Text font="main-content-emphasis">Image Generation Model</Text>
<Text font="secondary-body" color="text-03">
Select a model to generate images in chat.
</Text>
@@ -241,11 +173,6 @@ export default function ImageGenerationContent() {
onSelect={() => handleSelect(provider)}
onDeselect={() => handleDeselect(provider)}
onEdit={() => handleEdit(provider)}
onDisconnect={
getStatus(provider) !== "disconnected"
? () => setDisconnectProvider(provider)
: undefined
}
/>
))}
</div>
@@ -253,108 +180,6 @@ export default function ImageGenerationContent() {
))}
</div>
{disconnectProvider && (
<ConfirmationModalLayout
icon={SvgUnplug}
title={`Disconnect ${disconnectProvider.title}`}
description="This will remove the stored credentials for this provider."
onClose={() => {
setDisconnectProvider(null);
setReplacementProviderId(null);
}}
submit={
<Button
variant="danger"
onClick={() => void handleDisconnect()}
disabled={
needsReplacement && hasReplacements && !replacementProviderId
}
>
Disconnect
</Button>
}
>
{needsReplacement ? (
hasReplacements ? (
<Section alignItems="start">
<Text as="p" color="text-03">
{markdown(
`**${disconnectProvider.title}** is currently the default image generation model. Session history will be preserved.`
)}
</Text>
<Section alignItems="start" gap={0.25}>
<Text as="p" color="text-04">
Set New Default
</Text>
<InputSelect
value={replacementProviderId ?? undefined}
onValueChange={(v) => setReplacementProviderId(v)}
>
<InputSelect.Trigger placeholder="Select a replacement model" />
<InputSelect.Content>
{replacementGroups.map((group) => (
<InputSelect.Group key={group.name}>
<InputSelect.Label>{group.name}</InputSelect.Label>
{group.providers.map((p) => (
<InputSelect.Item
key={p.image_provider_id}
value={p.image_provider_id}
icon={() => (
<ProviderIcon
provider={p.provider_name}
size={16}
/>
)}
>
{p.title}
</InputSelect.Item>
))}
</InputSelect.Group>
))}
<InputSelect.Separator />
<InputSelect.Item
value={NO_DEFAULT_VALUE}
icon={SvgSlash}
>
<span>
<b>No Default</b>
<span className="text-text-03">
{" "}
(Disable Image Generation)
</span>
</span>
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</Section>
</Section>
) : (
<>
<Text as="p" color="text-03">
{markdown(
`**${disconnectProvider.title}** is currently the default image generation model.`
)}
</Text>
<Text as="p" color="text-03">
Connect another provider to continue using image generation.
</Text>
</>
)
) : (
<>
<Text as="p" color="text-03">
{markdown(
`**${disconnectProvider.title}** models will no longer be used to generate images.`
)}
</Text>
<Text as="p" color="text-03">
Session history will be preserved.
</Text>
</>
)}
</ConfirmationModalLayout>
)}
{activeProvider && (
<modal.Provider>
<ImageGenerationConnectionModal

View File

@@ -8,6 +8,7 @@ import CreateButton from "@/refresh-components/buttons/CreateButton";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import { SvgX } from "@opal/icons";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
function ModelConfigurationRow({

View File

@@ -23,7 +23,6 @@ import {
BedrockModelResponse,
LMStudioModelResponse,
LiteLLMProxyModelResponse,
BifrostModelResponse,
ModelConfiguration,
LLMProviderName,
BedrockFetchParams,
@@ -31,9 +30,8 @@ import {
LMStudioFetchParams,
OpenRouterFetchParams,
LiteLLMProxyFetchParams,
BifrostFetchParams,
} from "@/interfaces/llm";
import { SvgAws, SvgBifrost, SvgOpenrouter } from "@opal/icons";
import { SvgAws, SvgOpenrouter } from "@opal/icons";
// Aggregator providers that host models from multiple vendors
export const AGGREGATOR_PROVIDERS = new Set([
@@ -43,7 +41,6 @@ export const AGGREGATOR_PROVIDERS = new Set([
"ollama_chat",
"lm_studio",
"litellm_proxy",
"bifrost",
"vertex_ai",
]);
@@ -81,7 +78,6 @@ export const getProviderIcon = (
bedrock_converse: SvgAws,
openrouter: SvgOpenrouter,
litellm_proxy: LiteLLMIcon,
bifrost: SvgBifrost,
vertex_ai: GeminiIcon,
};
@@ -267,11 +263,8 @@ export const fetchOpenRouterModels = async (
try {
const errorData = await response.json();
errorMessage = errorData.detail || errorData.message || errorMessage;
} catch (jsonError) {
console.warn(
"Failed to parse OpenRouter model fetch error response",
jsonError
);
} catch {
// ignore JSON parsing errors
}
return { models: [], error: errorMessage };
}
@@ -326,11 +319,8 @@ export const fetchLMStudioModels = async (
try {
const errorData = await response.json();
errorMessage = errorData.detail || errorData.message || errorMessage;
} catch (jsonError) {
console.warn(
"Failed to parse LM Studio model fetch error response",
jsonError
);
} catch {
// ignore JSON parsing errors
}
return { models: [], error: errorMessage };
}
@@ -353,64 +343,6 @@ export const fetchLMStudioModels = async (
}
};
/**
* Fetches Bifrost models directly without any form state dependencies.
* Uses snake_case params to match API structure.
*/
export const fetchBifrostModels = async (
params: BifrostFetchParams
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
const apiBase = params.api_base;
if (!apiBase) {
return { models: [], error: "API Base is required" };
}
try {
const response = await fetch("/api/admin/llm/bifrost/available-models", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
api_base: apiBase,
api_key: params.api_key,
provider_name: params.provider_name,
}),
signal: params.signal,
});
if (!response.ok) {
let errorMessage = "Failed to fetch models";
try {
const errorData = await response.json();
errorMessage = errorData.detail || errorData.message || errorMessage;
} catch (jsonError) {
console.warn(
"Failed to parse Bifrost model fetch error response",
jsonError
);
}
return { models: [], error: errorMessage };
}
const data: BifrostModelResponse[] = await response.json();
const models: ModelConfiguration[] = data.map((modelData) => ({
name: modelData.name,
display_name: modelData.display_name,
is_visible: true,
max_input_tokens: modelData.max_input_tokens,
supports_image_input: modelData.supports_image_input,
supports_reasoning: modelData.supports_reasoning,
}));
return { models };
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : "Unknown error";
return { models: [], error: errorMessage };
}
};
/**
* Fetches LiteLLM Proxy models directly without any form state dependencies.
* Uses snake_case params to match API structure.
@@ -524,13 +456,6 @@ export const fetchModels = async (
provider_name: formValues.name,
signal,
});
case LLMProviderName.BIFROST:
return fetchBifrostModels({
api_base: formValues.api_base,
api_key: formValues.api_key,
provider_name: formValues.name,
signal,
});
default:
return { models: [], error: `Unknown provider: ${providerName}` };
}
@@ -544,7 +469,6 @@ export function canProviderFetchModels(providerName?: string) {
case LLMProviderName.LM_STUDIO:
case LLMProviderName.OPENROUTER:
case LLMProviderName.LITELLM_PROXY:
case LLMProviderName.BIFROST:
return true;
default:
return false;

View File

@@ -1,25 +1,33 @@
"use client";
import Image from "next/image";
import { useEffect, useMemo, useState, useReducer } from "react";
import { useMemo, useState, useReducer } from "react";
import { InfoIcon } from "@/components/icons/icons";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { Select } from "@/refresh-components/cards";
import { Section } from "@/layouts/general-layouts";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { Content } from "@opal/layouts";
import useSWR from "swr";
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
import { ThreeDotsLoader } from "@/components/Loading";
import { Callout } from "@/components/ui/callout";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
import { Disabled } from "@opal/core";
import { cn } from "@/lib/utils";
import { toast } from "@/hooks/useToast";
import { SvgGlobe, SvgOnyxLogo, SvgSlash, SvgUnplug } from "@opal/icons";
import { Button } from "@opal/components";
import {
SvgArrowExchange,
SvgArrowRightCircle,
SvgCheckSquare,
SvgEdit,
SvgGlobe,
SvgOnyxLogo,
SvgX,
} from "@opal/icons";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import InputSelect from "@/refresh-components/inputs/InputSelect";
const route = ADMIN_ROUTES.WEB_SEARCH;
import {
SEARCH_PROVIDERS_URL,
SEARCH_PROVIDER_DETAILS,
@@ -51,10 +59,6 @@ import {
} from "@/app/admin/configuration/web-search/WebProviderModalReducer";
import { connectProviderFlow } from "@/app/admin/configuration/web-search/connectProviderFlow";
const NO_DEFAULT_VALUE = "__none__";
const route = ADMIN_ROUTES.WEB_SEARCH;
interface WebSearchProviderView {
id: number;
name: string;
@@ -73,151 +77,27 @@ interface WebContentProviderView {
has_api_key: boolean;
}
interface DisconnectTargetState {
id: number;
label: string;
category: "search" | "content";
providerType: string;
interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
isHovered: boolean;
onMouseEnter: () => void;
onMouseLeave: () => void;
children: React.ReactNode;
}
function WebSearchDisconnectModal({
disconnectTarget,
searchProviders,
contentProviders,
replacementProviderId,
onReplacementChange,
onClose,
onDisconnect,
}: {
disconnectTarget: DisconnectTargetState;
searchProviders: WebSearchProviderView[];
contentProviders: WebContentProviderView[];
replacementProviderId: string | null;
onReplacementChange: (id: string | null) => void;
onClose: () => void;
onDisconnect: () => void;
}) {
const isSearch = disconnectTarget.category === "search";
// Determine if the target is currently the active/selected provider
const isActive = isSearch
? searchProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
false
: contentProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
false;
// Find other configured providers as replacements
const replacementOptions = isSearch
? searchProviders.filter(
(p) => p.id !== disconnectTarget.id && p.id > 0 && p.has_api_key
)
: contentProviders.filter(
(p) =>
p.id !== disconnectTarget.id &&
p.provider_type !== "onyx_web_crawler" &&
p.id > 0 &&
p.has_api_key
);
const needsReplacement = isActive;
const hasReplacements = replacementOptions.length > 0;
const getLabel = (p: { name: string; provider_type: string }) => {
if (isSearch) {
const details =
SEARCH_PROVIDER_DETAILS[p.provider_type as WebSearchProviderType];
return details?.label ?? p.name ?? p.provider_type;
}
const details = CONTENT_PROVIDER_DETAILS[p.provider_type];
return details?.label ?? p.name ?? p.provider_type;
};
const categoryLabel = isSearch ? "search engine" : "web crawler";
const featureLabel = isSearch ? "web search" : "web crawling";
const disableLabel = isSearch ? "Disable Web Search" : "Disable Web Crawling";
// Auto-select first replacement when modal opens
useEffect(() => {
if (needsReplacement && hasReplacements && !replacementProviderId) {
const first = replacementOptions[0];
if (first) onReplacementChange(String(first.id));
}
}, []); // eslint-disable-line react-hooks/exhaustive-deps
function HoverIconButton({
isHovered,
onMouseEnter,
onMouseLeave,
children,
...buttonProps
}: HoverIconButtonProps) {
return (
<ConfirmationModalLayout
icon={SvgUnplug}
title={`Disconnect ${disconnectTarget.label}`}
description="This will remove the stored credentials for this provider."
onClose={onClose}
submit={
<Button
variant="danger"
onClick={onDisconnect}
disabled={
needsReplacement && hasReplacements && !replacementProviderId
}
>
Disconnect
</Button>
}
>
{needsReplacement ? (
hasReplacements ? (
<Section alignItems="start">
<Text as="p" text03>
<b>{disconnectTarget.label}</b> is currently the active{" "}
{categoryLabel}. Search history will be preserved.
</Text>
<Section alignItems="start" gap={0.25}>
<Text as="p" secondaryBody text03>
Set New Default
</Text>
<InputSelect
value={replacementProviderId ?? undefined}
onValueChange={(v) => onReplacementChange(v)}
>
<InputSelect.Trigger placeholder="Select a replacement provider" />
<InputSelect.Content>
{replacementOptions.map((p) => (
<InputSelect.Item key={p.id} value={String(p.id)}>
{getLabel(p)}
</InputSelect.Item>
))}
<InputSelect.Separator />
<InputSelect.Item value={NO_DEFAULT_VALUE} icon={SvgSlash}>
<span>
<b>No Default</b>
<span className="text-text-03"> ({disableLabel})</span>
</span>
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</Section>
</Section>
) : (
<>
<Text as="p" text03>
<b>{disconnectTarget.label}</b> is currently the active{" "}
{categoryLabel}.
</Text>
<Text as="p" text03>
Connect another provider to continue using {featureLabel}.
</Text>
</>
)
) : (
<>
<Text as="p" text03>
{isSearch ? "Web search" : "Web crawling"} will no longer be routed
through <b>{disconnectTarget.label}</b>.
</Text>
<Text as="p" text03>
Search history will be preserved.
</Text>
</>
)}
</ConfirmationModalLayout>
<div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
{/* TODO(@raunakab): migrate to opal Button once HoverIconButtonProps typing is resolved */}
<Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
{children}
</Button>
</div>
);
}
@@ -226,11 +106,6 @@ export default function Page() {
WebProviderModalReducer,
initialWebProviderModalState
);
const [disconnectTarget, setDisconnectTarget] =
useState<DisconnectTargetState | null>(null);
const [replacementProviderId, setReplacementProviderId] = useState<
string | null
>(null);
const [contentModal, dispatchContentModal] = useReducer(
WebProviderModalReducer,
initialWebProviderModalState
@@ -239,6 +114,8 @@ export default function Page() {
const [contentActivationError, setContentActivationError] = useState<
string | null
>(null);
const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);
const {
data: searchProvidersData,
error: searchProvidersError,
@@ -957,67 +834,6 @@ export default function Page() {
});
};
const handleDisconnectProvider = async () => {
if (!disconnectTarget) return;
const { id, category } = disconnectTarget;
try {
// If a replacement was selected (not "No Default"), activate it first
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
const repId = Number(replacementProviderId);
const activateEndpoint =
category === "search"
? `/api/admin/web-search/search-providers/${repId}/activate`
: `/api/admin/web-search/content-providers/${repId}/activate`;
const activateResp = await fetch(activateEndpoint, {
method: "POST",
headers: { "Content-Type": "application/json" },
});
if (!activateResp.ok) {
const errorBody = await activateResp.json().catch(() => ({}));
throw new Error(
typeof errorBody?.detail === "string"
? errorBody.detail
: "Failed to activate replacement provider."
);
}
}
const response = await fetch(
`/api/admin/web-search/${category}-providers/${id}`,
{ method: "DELETE" }
);
if (!response.ok) {
const errorBody = await response.json().catch((parseErr) => {
console.error("Failed to parse disconnect error response:", parseErr);
return {};
});
throw new Error(
typeof errorBody?.detail === "string"
? errorBody.detail
: "Failed to disconnect provider."
);
}
toast.success(`${disconnectTarget.label} disconnected`);
await mutateSearchProviders();
await mutateContentProviders();
} catch (error) {
console.error("Failed to disconnect web search provider:", error);
const message =
error instanceof Error ? error.message : "Unexpected error occurred.";
if (category === "search") {
setActivationError(message);
} else {
setContentActivationError(message);
}
} finally {
setDisconnectTarget(null);
setReplacementProviderId(null);
}
};
return (
<>
<SettingsLayouts.Root>
@@ -1079,79 +895,149 @@ export default function Page() {
provider
);
const isActive = provider?.is_active ?? false;
const isHighlighted = isActive;
const providerId = provider?.id;
const canOpenModal =
isBuiltInSearchProviderType(providerType);
const status: "disconnected" | "connected" | "selected" =
!isConfigured
? "disconnected"
: isActive
? "selected"
: "connected";
return (
<Select
key={`${key}-${providerType}`}
icon={() =>
logoSrc ? (
<Image
src={logoSrc}
alt={`${label} logo`}
width={16}
height={16}
/>
) : (
<SvgGlobe size={16} />
)
}
title={label}
description={subtitle}
status={status}
onConnect={
canOpenModal
const buttonState = (() => {
if (!provider || !isConfigured) {
return {
label: "Connect",
disabled: false,
icon: "arrow" as const,
onClick: canOpenModal
? () => {
openSearchModal(providerType, provider);
setActivationError(null);
}
: undefined
}
onSelect={
providerId
? () => {
void handleActivateSearchProvider(providerId);
}
: undefined
}
onDeselect={
providerId
: undefined,
};
}
if (isActive) {
return {
label: "Current Default",
disabled: false,
icon: "check" as const,
onClick: providerId
? () => {
void handleDeactivateSearchProvider(providerId);
}
: undefined
}
onEdit={
isConfigured && canOpenModal
? () => {
: undefined,
};
}
return {
label: "Set as Default",
disabled: false,
icon: "arrow-circle" as const,
onClick: providerId
? () => {
void handleActivateSearchProvider(providerId);
}
: undefined,
};
})();
const buttonKey = `search-${key}-${providerType}`;
const isButtonHovered = hoveredButtonKey === buttonKey;
const isCardClickable =
buttonState.icon === "arrow" &&
typeof buttonState.onClick === "function" &&
!buttonState.disabled;
const handleCardClick = () => {
if (isCardClickable) {
buttonState.onClick?.();
}
};
return (
<div
key={`${key}-${providerType}`}
onClick={isCardClickable ? handleCardClick : undefined}
className={cn(
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
isHighlighted
? "border-action-link-05"
: "border-border-01",
isCardClickable &&
"cursor-pointer hover:bg-background-tint-01 transition-colors"
)}
>
<div className="flex flex-1 items-start gap-1 px-2 py-1">
{renderLogo({
logoSrc,
alt: `${label} logo`,
size: 16,
isHighlighted,
})}
<Content
title={label}
description={subtitle}
sizePreset="main-ui"
variant="section"
/>
</div>
<div className="flex items-center justify-end gap-2">
{isConfigured && (
<OpalButton
icon={SvgEdit}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={() => {
if (!canOpenModal) return;
openSearchModal(
providerType as WebSearchProviderType,
provider
);
}}
aria-label={`Edit ${label}`}
/>
)}
{buttonState.icon === "check" ? (
<HoverIconButton
isHovered={isButtonHovered}
onMouseEnter={() => setHoveredButtonKey(buttonKey)}
onMouseLeave={() => setHoveredButtonKey(null)}
action={true}
tertiary
disabled={buttonState.disabled}
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
>
{buttonState.label}
</HoverIconButton>
) : (
<Disabled
disabled={
buttonState.disabled || !buttonState.onClick
}
: undefined
}
onDisconnect={
isConfigured && provider && provider.id > 0
? () =>
setDisconnectTarget({
id: provider.id,
label,
category: "search",
providerType,
})
: undefined
}
/>
>
<OpalButton
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
rightIcon={
buttonState.icon === "arrow"
? SvgArrowExchange
: buttonState.icon === "arrow-circle"
? SvgArrowRightCircle
: undefined
}
>
{buttonState.label}
</OpalButton>
</Disabled>
)}
</div>
</div>
);
}
)}
@@ -1191,81 +1077,161 @@ export default function Page() {
const isCurrentCrawler =
provider.provider_type === currentContentProviderType;
const status: "disconnected" | "connected" | "selected" =
!isConfigured
? "disconnected"
: isCurrentCrawler
? "selected"
: "connected";
const buttonState = (() => {
if (!isConfigured) {
return {
label: "Connect",
icon: "arrow" as const,
disabled: false,
onClick: () => {
openContentModal(provider.provider_type, provider);
setContentActivationError(null);
},
};
}
const canActivate =
providerId > 0 ||
provider.provider_type === "onyx_web_crawler" ||
isConfigured;
if (isCurrentCrawler) {
return {
label: "Current Crawler",
icon: "check" as const,
disabled: false,
onClick: () => {
void handleDeactivateContentProvider(
providerId,
provider.provider_type
);
},
};
}
const contentLogoSrc =
CONTENT_PROVIDER_DETAILS[provider.provider_type]?.logoSrc;
const canActivate =
providerId > 0 ||
provider.provider_type === "onyx_web_crawler" ||
isConfigured;
return {
label: "Set as Default",
icon: "arrow-circle" as const,
disabled: !canActivate,
onClick: canActivate
? () => {
void handleActivateContentProvider(provider);
}
: undefined,
};
})();
const contentButtonKey = `content-${provider.provider_type}-${provider.id}`;
const isContentButtonHovered =
hoveredButtonKey === contentButtonKey;
const isContentCardClickable =
buttonState.icon === "arrow" &&
typeof buttonState.onClick === "function" &&
!buttonState.disabled;
const handleContentCardClick = () => {
if (isContentCardClickable) {
buttonState.onClick?.();
}
};
return (
<Select
<div
key={`${provider.provider_type}-${provider.id}`}
icon={() =>
contentLogoSrc ? (
<Image
src={contentLogoSrc}
alt={`${label} logo`}
width={16}
height={16}
/>
) : provider.provider_type === "onyx_web_crawler" ? (
<SvgOnyxLogo size={16} />
onClick={
isContentCardClickable
? handleContentCardClick
: undefined
}
className={cn(
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
isCurrentCrawler
? "border-action-link-05"
: "border-border-01",
isContentCardClickable &&
"cursor-pointer hover:bg-background-tint-01 transition-colors"
)}
>
<div className="flex flex-1 items-start gap-1 px-2 py-1">
{renderLogo({
logoSrc:
CONTENT_PROVIDER_DETAILS[provider.provider_type]
?.logoSrc,
alt: `${label} logo`,
fallback:
provider.provider_type === "onyx_web_crawler" ? (
<SvgOnyxLogo size={16} />
) : undefined,
size: 16,
isHighlighted: isCurrentCrawler,
})}
<Content
title={label}
description={subtitle}
sizePreset="main-ui"
variant="section"
/>
</div>
<div className="flex items-center justify-end gap-2">
{provider.provider_type !== "onyx_web_crawler" &&
isConfigured && (
<OpalButton
icon={SvgEdit}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={() => {
openContentModal(
provider.provider_type,
provider
);
}}
aria-label={`Edit ${label}`}
/>
)}
{buttonState.icon === "check" ? (
<HoverIconButton
isHovered={isContentButtonHovered}
onMouseEnter={() =>
setHoveredButtonKey(contentButtonKey)
}
onMouseLeave={() => setHoveredButtonKey(null)}
action={true}
tertiary
disabled={buttonState.disabled}
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
>
{buttonState.label}
</HoverIconButton>
) : (
<SvgGlobe size={16} />
)
}
title={label}
description={subtitle}
status={status}
selectedLabel="Current Crawler"
onConnect={() => {
openContentModal(provider.provider_type, provider);
setContentActivationError(null);
}}
onSelect={
canActivate
? () => {
void handleActivateContentProvider(provider);
<Disabled
disabled={
buttonState.disabled || !buttonState.onClick
}
: undefined
}
onDeselect={() => {
void handleDeactivateContentProvider(
providerId,
provider.provider_type
);
}}
onEdit={
provider.provider_type !== "onyx_web_crawler" &&
isConfigured
? () => {
openContentModal(provider.provider_type, provider);
}
: undefined
}
onDisconnect={
provider.provider_type !== "onyx_web_crawler" &&
isConfigured &&
provider.id > 0
? () =>
setDisconnectTarget({
id: provider.id,
label,
category: "content",
providerType: provider.provider_type,
})
: undefined
}
/>
>
<OpalButton
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
rightIcon={
buttonState.icon === "arrow"
? SvgArrowExchange
: buttonState.icon === "arrow-circle"
? SvgArrowRightCircle
: undefined
}
>
{buttonState.label}
</OpalButton>
</Disabled>
)}
</div>
</div>
);
})}
</div>
@@ -1273,21 +1239,6 @@ export default function Page() {
</SettingsLayouts.Body>
</SettingsLayouts.Root>
{disconnectTarget && (
<WebSearchDisconnectModal
disconnectTarget={disconnectTarget}
searchProviders={searchProviders}
contentProviders={combinedContentProviders}
replacementProviderId={replacementProviderId}
onReplacementChange={setReplacementProviderId}
onClose={() => {
setDisconnectTarget(null);
setReplacementProviderId(null);
}}
onDisconnect={() => void handleDisconnectProvider()}
/>
)}
<WebProviderSetupModal
isOpen={selectedProviderType !== null}
onClose={() => {

View File

@@ -4,6 +4,7 @@ import { useState } from "react";
import { ValidSources } from "@/lib/types";
import { Section } from "@/layouts/general-layouts";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { Button } from "@opal/components";
import Separator from "@/refresh-components/Separator";

View File

@@ -10,7 +10,7 @@ import {
import { IndexAttemptError } from "./types";
import { localizeAndPrettify } from "@/lib/time";
import Button from "@/refresh-components/buttons/Button";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import { PageSelector } from "@/components/PageSelector";
import { useCallback, useEffect, useRef, useState, useMemo } from "react";
import { SvgAlertTriangle } from "@opal/icons";
@@ -113,11 +113,11 @@ export default function IndexAttemptErrorsModal({
<Modal.Body height="full">
{!isResolvingErrors && (
<div className="flex flex-col gap-2 flex-shrink-0">
<Text as="p">
<Text as="p" color="text-05">
Below are the errors encountered during indexing. Each row
represents a failed document or entity.
</Text>
<Text as="p">
<Text as="p" color="text-05">
Click the button below to kick off a full re-index to try and
resolve these errors. This full re-index may take much longer
than a normal update.

View File

@@ -21,6 +21,7 @@ import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { ThreeDotsLoader } from "@/components/Loading";
import Modal from "@/refresh-components/Modal";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import {
SvgCheck,

View File

@@ -6,6 +6,7 @@ import { useState } from "react";
import { toast } from "@/hooks/useToast";
import { triggerIndexing } from "@/app/admin/connector/[ccPairId]/lib";
import Modal from "@/refresh-components/Modal";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Separator from "@/refresh-components/Separator";
import { SvgRefreshCw } from "@opal/icons";

View File

@@ -7,6 +7,7 @@ import { SourceIcon } from "@/components/SourceIcon";
import { CCPairStatus, PermissionSyncStatus } from "@/components/Status";
import { toast } from "@/hooks/useToast";
import CredentialSection from "@/components/credentials/CredentialSection";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import {
updateConnectorCredentialPairName,

View File

@@ -59,6 +59,7 @@ import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import { deleteConnector } from "@/lib/connector";
import ConnectorDocsLink from "@/components/admin/connectors/ConnectorDocsLink";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { SvgKey, SvgAlertCircle } from "@opal/icons";
import SimpleTooltip from "@/refresh-components/SimpleTooltip";

View File

@@ -19,7 +19,6 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
import { Credential } from "@/lib/connectors/credentials";
import { useFederatedConnectors } from "@/lib/hooks";
import Text from "@/refresh-components/texts/Text";
import { useToastFromQuery } from "@/hooks/useToast";
export default function ConnectorWrapper({

View File

@@ -15,7 +15,7 @@ import * as InputLayouts from "@/layouts/input-layouts";
import { Content } from "@opal/layouts";
import CheckboxField from "@/refresh-components/form/LabeledCheckboxField";
import InputTextAreaField from "@/refresh-components/form/InputTextAreaField";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
// Define a general type for form values
type FormValues = Record<string, any>;
@@ -57,7 +57,7 @@ const TabsField: FC<TabsFieldProps> = ({
{/* Ensure there's at least one tab before rendering */}
{tabField.tabs.length === 0 ? (
<Text text03 secondaryBody>
<Text color="text-03" font="secondary-body">
No tabs to display.
</Text>
) : (
@@ -253,7 +253,7 @@ export const RenderField: FC<RenderFieldProps> = ({
)
) : field.type === "string_tab" ? (
<GeneralLayouts.Section>
<Text text03 secondaryBody>
<Text color="text-03" font="secondary-body">
{description}
</Text>
</GeneralLayouts.Section>

View File

@@ -2,6 +2,7 @@
import { useState } from "react";
import { Section } from "@/layouts/general-layouts";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Card from "@/refresh-components/cards/Card";
import { Button } from "@opal/components";

View File

@@ -11,7 +11,7 @@ import {
import Switch from "@/refresh-components/inputs/Switch";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import EmptyMessage from "@/refresh-components/EmptyMessage";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import { Section } from "@/layouts/general-layouts";
import {
DiscordChannelConfig,
@@ -95,9 +95,7 @@ export function DiscordChannelsTable({
width="fit"
>
<ChannelIcon width={16} height={16} />
<Text text04 mainUiBody>
{channel.channel_name}
</Text>
<Text>{channel.channel_name}</Text>
</Section>
</TableCell>
<TableCell>

View File

@@ -8,11 +8,10 @@ import { toast } from "@/hooks/useToast";
import { Section } from "@/layouts/general-layouts";
import { ContentAction } from "@opal/layouts";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Text from "@/refresh-components/texts/Text";
import { Button, Text } from "@opal/components";
import Card from "@/refresh-components/cards/Card";
import { Callout } from "@/components/ui/callout";
import Message from "@/refresh-components/messages/Message";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import { SvgServer } from "@opal/icons";
import InputSelect from "@/refresh-components/inputs/InputSelect";
@@ -121,7 +120,7 @@ function GuildDetailContent({
/>
{!isRegistered ? (
<Text text03 secondaryBody>
<Text color="text-03" font="secondary-body">
Channel configuration will be available after the server is
registered.
</Text>

View File

@@ -6,7 +6,7 @@ import { ErrorCallout } from "@/components/ErrorCallout";
import { toast } from "@/hooks/useToast";
import { Section } from "@/layouts/general-layouts";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import Modal from "@/refresh-components/Modal";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
@@ -76,7 +76,7 @@ function DiscordBotContent() {
description="This key will only be shown once!"
/>
<Modal.Body>
<Text text04 mainUiBody>
<Text>
Copy the command and send it from any text channel in your server!
</Text>
<Card variant="secondary">
@@ -85,8 +85,8 @@ function DiscordBotContent() {
justifyContent="between"
alignItems="center"
>
<Text text03 secondaryMono>
!register {registrationKey}
<Text color="text-03" font="secondary-mono">
{`!register ${registrationKey}`}
</Text>
<CopyIconButton
getCopyText={() => `!register ${registrationKey}`}
@@ -103,7 +103,7 @@ function DiscordBotContent() {
justifyContent="between"
alignItems="center"
>
<Text mainContentEmphasis text05>
<Text font="main-content-emphasis" color="text-05">
Server Configurations
</Text>
<CreateButton

View File

@@ -9,7 +9,7 @@ const route = ADMIN_ROUTES.INDEX_MIGRATION;
import Card from "@/refresh-components/cards/Card";
import { Content, ContentAction } from "@opal/layouts";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import Button from "@/refresh-components/buttons/Button";
import { errorHandlingFetcher } from "@/lib/fetcher";
@@ -38,10 +38,10 @@ function MigrationStatusSection() {
if (isLoading) {
return (
<Card>
<Text headingH3>Migration Status</Text>
<Text mainUiBody text03>
Loading...
<Text font="heading-h3" color="text-05">
Migration Status
</Text>
<Text color="text-03">Loading...</Text>
</Card>
);
}
@@ -49,10 +49,10 @@ function MigrationStatusSection() {
if (error) {
return (
<Card>
<Text headingH3>Migration Status</Text>
<Text mainUiBody text03>
Failed to load migration status.
<Text font="heading-h3" color="text-05">
Migration Status
</Text>
<Text color="text-03">Failed to load migration status.</Text>
</Card>
);
}
@@ -73,14 +73,16 @@ function MigrationStatusSection() {
return (
<Card>
<Text headingH3>Migration Status</Text>
<Text font="heading-h3" color="text-05">
Migration Status
</Text>
<ContentAction
title="Started"
sizePreset="main-ui"
variant="section"
rightChildren={
<Text mainUiBody>
<Text color="text-05">
{hasStarted ? formatTimestamp(data.created_at!) : "Not started"}
</Text>
}
@@ -91,7 +93,7 @@ function MigrationStatusSection() {
sizePreset="main-ui"
variant="section"
rightChildren={
<Text mainUiBody>
<Text color="text-05">
{progressPercentage !== null
? `${totalChunksMigrated} (approx. progress ${Math.round(
progressPercentage
@@ -106,7 +108,7 @@ function MigrationStatusSection() {
sizePreset="main-ui"
variant="section"
rightChildren={
<Text mainUiBody>
<Text color="text-05">
{hasCompleted
? formatTimestamp(data.migration_completed_at!)
: hasStarted
@@ -159,10 +161,10 @@ function RetrievalSourceSection() {
if (isLoading) {
return (
<Card>
<Text headingH3>Retrieval Source</Text>
<Text mainUiBody text03>
Loading...
<Text font="heading-h3" color="text-05">
Retrieval Source
</Text>
<Text color="text-03">Loading...</Text>
</Card>
);
}
@@ -170,10 +172,10 @@ function RetrievalSourceSection() {
if (error) {
return (
<Card>
<Text headingH3>Retrieval Source</Text>
<Text mainUiBody text03>
Failed to load retrieval settings.
<Text font="heading-h3" color="text-05">
Retrieval Source
</Text>
<Text color="text-03">Failed to load retrieval settings.</Text>
</Card>
);
}

View File

@@ -3,6 +3,7 @@
import React, { useRef, useState } from "react";
import Modal from "@/refresh-components/Modal";
import { Callout } from "@/components/ui/callout";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Separator from "@/refresh-components/Separator";
import Button from "@/refresh-components/buttons/Button";

View File

@@ -1,4 +1,5 @@
import Modal from "@/refresh-components/Modal";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { Button } from "@opal/components";
import { Callout } from "@/components/ui/callout";

View File

@@ -1,5 +1,6 @@
import Modal from "@/refresh-components/Modal";
import { Button } from "@opal/components";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { SvgAlertTriangle } from "@opal/icons";
export interface InstantSwitchConfirmModalProps {

View File

@@ -1,4 +1,5 @@
import Modal from "@/refresh-components/Modal";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { Callout } from "@/components/ui/callout";
import { Button } from "@opal/components";

View File

@@ -1,4 +1,5 @@
import React, { useRef, useState } from "react";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { Callout } from "@/components/ui/callout";
import { Button } from "@opal/components";

View File

@@ -1,5 +1,6 @@
import Modal from "@/refresh-components/Modal";
import { Button } from "@opal/components";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { CloudEmbeddingModel } from "@/components/embedding/interfaces";
import { SvgServer } from "@opal/icons";

View File

@@ -4,6 +4,7 @@ import { toast } from "@/hooks/useToast";
import EmbeddingModelSelection from "../EmbeddingModelSelectionForm";
import { useCallback, useEffect, useMemo, useState, useRef } from "react";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";

View File

@@ -8,6 +8,7 @@ import { ValidSources } from "@/lib/types";
import { FaCircleQuestion } from "react-icons/fa6";
import { CheckmarkIcon } from "@/components/icons/icons";
import { Button } from "@opal/components";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import { cn } from "@/lib/utils";

View File

@@ -28,6 +28,7 @@ import Title from "@/components/ui/title";
import { redirect } from "next/navigation";
import { useIsKGExposed } from "@/app/admin/kg/utils";
import KGEntityTypes from "@/app/admin/kg/KGEntityTypes";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { cn } from "@/lib/utils";
import { SvgSettings } from "@opal/icons";

View File

@@ -1,9 +1,8 @@
import { SvgDownload, SvgKey, SvgRefreshCw } from "@opal/icons";
import { Interactive } from "@opal/core";
import { Section } from "@/layouts/general-layouts";
import { Button } from "@opal/components";
import { Button, Text } from "@opal/components";
import { Disabled } from "@opal/core";
import Text from "@/refresh-components/texts/Text";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
import InputTextArea from "@/refresh-components/inputs/InputTextArea";
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
@@ -63,7 +62,7 @@ export default function ScimModal({
}
>
<Section alignItems="start" gap={0.5}>
<Text as="p" text03>
<Text as="p" color="text-03">
Your current SCIM token will be revoked and a new token will be
generated. You will need to update the token on your identity
provider before SCIM provisioning will resume.

View File

@@ -4,6 +4,7 @@ import { Section } from "@/layouts/general-layouts";
import Card from "@/refresh-components/cards/Card";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import Separator from "@/refresh-components/Separator";
import { timeAgo } from "@/lib/time";

View File

@@ -7,7 +7,7 @@ import { toast } from "@/hooks/useToast";
import { useScimToken } from "@/hooks/useScimToken";
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import { ThreeDotsLoader } from "@/components/Loading";
import type { ScimTokenCreatedResponse, ScimModalView } from "./interfaces";
@@ -43,7 +43,7 @@ function ScimContent() {
if (tokenError && !is404) {
return (
<Text as="p" text03>
<Text as="p" color="text-03">
Failed to load SCIM token status.
</Text>
);

View File

@@ -1,5 +1,6 @@
"use client";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";

View File

@@ -3,6 +3,7 @@
import Modal from "@/refresh-components/Modal";
import { SettingsContext } from "@/providers/SettingsProvider";
import { Button } from "@opal/components";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { FormField } from "@/refresh-components/form/FormField";
import Checkbox from "@/refresh-components/inputs/Checkbox";

View File

@@ -6,7 +6,7 @@ import {
GREETING_MESSAGES,
} from "@/lib/chat/greetingMessages";
import AgentAvatar from "@/refresh-components/avatars/AgentAvatar";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
import { useState, useEffect } from "react";
import { useSettingsContext } from "@/providers/SettingsProvider";
@@ -41,7 +41,7 @@ export default function WelcomeMessage({
content = (
<div data-testid="onyx-logo" className="flex flex-row items-center gap-4">
<Logo folded size={32} />
<Text as="p" headingH2>
<Text as="p" font="heading-h2" color="text-05">
{greeting}
</Text>
</div>
@@ -54,7 +54,7 @@ export default function WelcomeMessage({
className="flex flex-row items-center gap-3"
>
<AgentAvatar agent={agent} size={36} />
<Text as="p" headingH2>
<Text as="p" font="heading-h2" color="text-05">
{agent.name}
</Text>
</div>

View File

@@ -8,6 +8,7 @@ import { ChatSession } from "@/app/app/interfaces";
import AgentAvatar from "@/refresh-components/avatars/AgentAvatar";
import { useAgents } from "@/hooks/useAgents";
import { formatRelativeTime } from "./project_utils";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { cn } from "@/lib/utils";
import { UNNAMED_CHAT } from "@/lib/constants";

View File

@@ -12,6 +12,7 @@ import { Button } from "@opal/components";
import AddInstructionModal from "@/components/modals/AddInstructionModal";
import UserFilesModal from "@/components/modals/UserFilesModal";
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { FileCard, FileCardSkeleton } from "@/sections/cards/FileCard";

View File

@@ -1,5 +1,5 @@
import { cn } from "@/lib/utils";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import React, { useState, ReactNode, useCallback, useMemo, memo } from "react";
import { SvgCheck, SvgCode, SvgCopy } from "@opal/icons";
@@ -48,14 +48,14 @@ export const CodeBlock = memo(function CodeBlock({
{copied ? (
<div className="flex items-center space-x-2">
<SvgCheck height={14} width={14} stroke="currentColor" />
<Text as="p" secondaryMono>
<Text as="p" font="secondary-mono" color="text-05">
Copied!
</Text>
</div>
) : (
<div className="flex items-center space-x-2">
<SvgCopy height={14} width={14} stroke="currentColor" />
<Text as="p" secondaryMono>
<Text as="p" font="secondary-mono" color="text-05">
Copy
</Text>
</div>
@@ -131,7 +131,9 @@ export const CodeBlock = memo(function CodeBlock({
stroke="currentColor"
className="my-auto"
/>
<Text secondaryMono>{language}</Text>
<Text font="secondary-mono" color="text-05">
{language}
</Text>
{codeText && <CopyButton />}
</div>
)}

View File

@@ -4,6 +4,7 @@ import React, { useEffect, useMemo, useRef, useState } from "react";
import { FileDescriptor } from "@/app/app/interfaces";
import "katex/dist/katex.min.css";
import MessageSwitcher from "@/app/app/message/MessageSwitcher";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { cn } from "@/lib/utils";
import useScreenSize from "@/hooks/useScreenSize";

View File

@@ -14,6 +14,7 @@ import { SubQuestionDetail, CitationMap } from "../interfaces";
import { ValidSources } from "@/lib/types";
import { ProjectFile } from "../projects/projectsService";
import { BlinkingBar } from "./BlinkingBar";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import SourceTag from "@/refresh-components/buttons/source-tag/SourceTag";
import {

View File

@@ -1,5 +1,6 @@
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import { SvgChevronLeft, SvgChevronRight } from "@opal/icons";
const DISABLED_MESSAGE = "Wait for agent message to complete";

View File

@@ -20,7 +20,7 @@ import { usePacedTurnGroups } from "@/app/app/message/messageComponents/timeline
import MessageToolbar from "@/app/app/message/messageComponents/MessageToolbar";
import { LlmDescriptor, LlmManager } from "@/lib/hooks";
import { Message } from "@/app/app/interfaces";
import Text from "@/refresh-components/texts/Text";
import { Text } from "@opal/components";
import { AgentTimeline } from "@/app/app/message/messageComponents/timeline/AgentTimeline";
import { useVoiceMode } from "@/providers/VoiceModeProvider";
import { getTextContent } from "@/app/app/services/packetUtils";
@@ -319,7 +319,7 @@ const AgentMessage = React.memo(function AgentMessage({
{/* Show stopped message when user cancelled and no display content */}
{pacedDisplayGroups.length === 0 &&
stopReason === StopReason.USER_CANCELLED && (
<Text as="p" secondaryBody text04>
<Text as="p" font="secondary-body">
User has stopped generation
</Text>
)}

View File

@@ -10,6 +10,7 @@ import {
} from "../../../services/streamingModels";
import { MessageRenderer, RenderType } from "../interfaces";
import { buildImgUrl } from "../../../components/files/images/utils";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import {
SvgActions,

View File

@@ -1,4 +1,5 @@
import React, { useEffect, useMemo, useRef, useState } from "react";
// TODO(@raunakab): migrate this `refresh-components/Text` to `@opal/components` Text
import Text from "@/refresh-components/texts/Text";
import {

Some files were not shown because too many files have changed in this diff Show More