Compare commits

..

6 Commits

82 changed files with 2987 additions and 981 deletions

View File

@@ -0,0 +1,37 @@
"""add cache_store table
Revision ID: 2664261bfaab
Revises: 4a1e4b1c89d2
Create Date: 2026-02-27 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2664261bfaab"
down_revision = "4a1e4b1c89d2"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
    """Create the ``cache_store`` key-value table and its TTL index."""
    table_name = "cache_store"
    op.create_table(
        table_name,
        sa.Column("key", sa.String(), nullable=False),
        sa.Column("value", sa.LargeBinary(), nullable=True),
        sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
        sa.PrimaryKeyConstraint("key"),
    )
    # Partial index: only rows that actually carry a TTL are indexed, which
    # keeps the index small for the periodic expiry-cleanup sweep.
    op.create_index(
        "ix_cache_store_expires",
        table_name,
        ["expires_at"],
        postgresql_where=sa.text("expires_at IS NOT NULL"),
    )
def downgrade() -> None:
    """Drop the TTL index first, then the ``cache_store`` table itself."""
    op.drop_index("ix_cache_store_expires", table_name="cache_store")
    op.drop_table("cache_store")

View File

@@ -0,0 +1,34 @@
"""make scim_user_mapping.external_id nullable
Revision ID: a3b8d9e2f1c4
Revises: 2664261bfaab
Create Date: 2026-03-02
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "a3b8d9e2f1c4"
down_revision = "2664261bfaab"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Relax ``scim_user_mapping.external_id`` to allow NULL values."""
    # RFC 7643 treats externalId as optional, so the column must accept NULL.
    op.alter_column("scim_user_mapping", "external_id", nullable=True)
def downgrade() -> None:
    """Re-apply NOT NULL to ``scim_user_mapping.external_id``.

    Rows with a NULL external_id cannot satisfy the restored constraint,
    so they are removed first — those mappings are lost on downgrade.
    """
    op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL")
    op.alter_column("scim_user_mapping", "external_id", nullable=False)

View File

@@ -126,12 +126,16 @@ class ScimDAL(DAL):
def create_user_mapping(
self,
external_id: str,
external_id: str | None,
user_id: UUID,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
"""Create a SCIM mapping for a user.
``external_id`` may be ``None`` when the IdP omits it (RFC 7643
allows this). The mapping still marks the user as SCIM-managed.
"""
f = fields or ScimMappingFields()
mapping = ScimUserMapping(
external_id=external_id,
@@ -270,8 +274,13 @@ class ScimDAL(DAL):
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
query = select(User).where(
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
# Inner-join with ScimUserMapping so only SCIM-managed users appear.
# Pre-existing system accounts (anonymous, admin, etc.) are excluded
# unless they were explicitly linked via SCIM provisioning.
query = (
select(User)
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
)
if scim_filter:
@@ -321,34 +330,37 @@ class ScimDAL(DAL):
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> None:
"""Create, update, or delete the external ID mapping for a user.
"""Sync the SCIM mapping for a user.
If a mapping already exists, its fields are updated (including
setting ``external_id`` to ``None`` when the IdP omits it).
If no mapping exists and ``new_external_id`` is provided, a new
mapping is created. A mapping is never deleted here — SCIM-managed
users must retain their mapping to remain visible in ``GET /Users``.
When *fields* is provided, all mapping fields are written
unconditionally — including ``None`` values — so that a caller can
clear a previously-set field (e.g. removing a department).
"""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
else:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
elif mapping:
self.delete_user_mapping(mapping.id)
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
elif new_external_id:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
def _get_user_mappings_batch(
self, user_ids: list[UUID]

View File

@@ -423,15 +423,63 @@ def create_user(
email = user_resource.userName.strip()
# Enforce seat limit
# Check for existing user — if they exist but aren't SCIM-managed yet,
# link them to the IdP rather than rejecting with 409.
external_id: str | None = user_resource.externalId
scim_username: str = user_resource.userName.strip()
fields: ScimMappingFields = _fields_from_resource(user_resource)
existing_user = dal.get_user_by_email(email)
if existing_user:
existing_mapping = dal.get_user_mapping_by_user_id(existing_user.id)
if existing_mapping:
return _scim_error_response(409, f"User with email {email} already exists")
# Adopt pre-existing user into SCIM management.
# Reactivating a deactivated user consumes a seat, so enforce the
# seat limit the same way replace_user does.
if user_resource.active and not existing_user.is_active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
personal_name = _scim_name_to_str(user_resource.name)
dal.update_user(
existing_user,
is_active=user_resource.active,
**({"personal_name": personal_name} if personal_name else {}),
)
try:
dal.create_user_mapping(
external_id=external_id,
user_id=existing_user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
return _scim_resource_response(
provider.build_user_resource(
existing_user,
external_id,
scim_username=scim_username,
fields=fields,
),
status_code=201,
)
# Only enforce seat limit for net-new users — adopting a pre-existing
# user doesn't consume a new seat.
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
# Check for existing user
if dal.get_user_by_email(email):
return _scim_error_response(409, f"User with email {email} already exists")
# Create user with a random password (SCIM users authenticate via IdP)
personal_name = _scim_name_to_str(user_resource.name)
user = User(
@@ -449,21 +497,21 @@ def create_user(
dal.rollback()
return _scim_error_response(409, f"User with email {email} already exists")
# Create SCIM mapping when externalId is provided — this is how the IdP
# correlates this user on subsequent requests. Per RFC 7643, externalId
# is optional and assigned by the provisioning client.
external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
if external_id:
# Always create a SCIM mapping so that the user is marked as
# SCIM-managed. externalId may be None (RFC 7643 says it's optional).
try:
dal.create_user_mapping(
external_id=external_id,
user_id=user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
return _scim_resource_response(
provider.build_user_resource(

View File

@@ -170,7 +170,10 @@ class ScimProvider(ABC):
formatted=user.personal_name or "",
)
if not user.personal_name:
return ScimName(givenName="", familyName="", formatted="")
# Derive a reasonable name from the email so that SCIM spec tests
# see non-empty givenName / familyName for every user resource.
local = user.email.split("@")[0] if user.email else ""
return ScimName(givenName=local, familyName="", formatted=local)
parts = user.personal_name.split(" ", 1)
return ScimName(
givenName=parts[0],

View File

@@ -520,6 +520,7 @@ def process_user_file_impl(
task_logger.exception(
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
raise
finally:
if file_lock is not None and file_lock.owned():
file_lock.release()
@@ -675,6 +676,7 @@ def delete_user_file_impl(
task_logger.exception(
f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
raise
finally:
if file_lock is not None and file_lock.owned():
file_lock.release()
@@ -849,6 +851,7 @@ def project_sync_user_file_impl(
task_logger.exception(
f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
)
raise
finally:
if file_lock is not None and file_lock.owned():
file_lock.release()

View File

@@ -59,6 +59,12 @@ def _run_auto_llm_update() -> None:
sync_llm_models_from_github(db_session)
def _run_cache_cleanup() -> None:
    """Purge expired rows from the Postgres ``cache_store`` table.

    Only scheduled when ``CACHE_BACKEND`` is postgres (see
    ``_build_periodic_tasks``).
    """
    # Imported lazily so Redis-backed deployments never load the
    # Postgres cache backend module.
    from onyx.cache.postgres_backend import cleanup_expired_cache_entries
    cleanup_expired_cache_entries()
def _run_scheduled_eval() -> None:
from onyx.configs.app_configs import BRAINTRUST_API_KEY
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
@@ -100,12 +106,26 @@ def _run_scheduled_eval() -> None:
)
# Sweep cadence for the Postgres cache garbage collector (5 minutes).
_CACHE_CLEANUP_INTERVAL_SECONDS = 300
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
from onyx.cache.interface import CacheBackendType
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import CACHE_BACKEND
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
tasks: list[_PeriodicTaskDef] = []
if CACHE_BACKEND == CacheBackendType.POSTGRES:
tasks.append(
_PeriodicTaskDef(
name="cache-cleanup",
interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS,
lock_id=PERIODIC_TASK_LOCK_BASE + 2,
run_fn=_run_cache_cleanup,
)
)
if AUTO_LLM_CONFIG_URL:
tasks.append(
_PeriodicTaskDef(

View File

@@ -75,31 +75,41 @@ def _claim_next_processing_file(db_session: Session) -> UUID | None:
return file_id
def _claim_next_deleting_file(db_session: Session) -> UUID | None:
def _claim_next_deleting_file(
db_session: Session,
exclude_ids: set[UUID] | None = None,
) -> UUID | None:
"""Claim the next DELETING file.
No status transition needed — the impl deletes the row on success.
The short-lived FOR UPDATE lock prevents concurrent claims.
*exclude_ids* prevents re-processing the same file if the impl fails.
"""
file_id = db_session.execute(
stmt = (
select(UserFile.id)
.where(UserFile.status == UserFileStatus.DELETING)
.order_by(UserFile.created_at)
.limit(1)
.with_for_update(skip_locked=True)
).scalar_one_or_none()
# Commit to release the row lock promptly.
)
if exclude_ids:
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
file_id = db_session.execute(stmt).scalar_one_or_none()
db_session.commit()
return file_id
def _claim_next_sync_file(db_session: Session) -> UUID | None:
def _claim_next_sync_file(
db_session: Session,
exclude_ids: set[UUID] | None = None,
) -> UUID | None:
"""Claim the next file needing project/persona sync.
No status transition needed — the impl clears the sync flags on
success. The short-lived FOR UPDATE lock prevents concurrent claims.
*exclude_ids* prevents re-processing the same file if the impl fails.
"""
file_id = db_session.execute(
stmt = (
select(UserFile.id)
.where(
sa.and_(
@@ -113,7 +123,10 @@ def _claim_next_sync_file(db_session: Session) -> UUID | None:
.order_by(UserFile.created_at)
.limit(1)
.with_for_update(skip_locked=True)
).scalar_one_or_none()
)
if exclude_ids:
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
file_id = db_session.execute(stmt).scalar_one_or_none()
db_session.commit()
return file_id
@@ -135,11 +148,14 @@ def drain_processing_loop(tenant_id: str) -> None:
file_id = _claim_next_processing_file(session)
if file_id is None:
break
process_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
try:
process_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
except Exception:
logger.exception(f"Failed to process user file {file_id}")
def drain_delete_loop(tenant_id: str) -> None:
@@ -149,16 +165,21 @@ def drain_delete_loop(tenant_id: str) -> None:
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
failed: set[UUID] = set()
while True:
with get_session_with_current_tenant() as session:
file_id = _claim_next_deleting_file(session)
file_id = _claim_next_deleting_file(session, exclude_ids=failed)
if file_id is None:
break
delete_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
try:
delete_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
except Exception:
logger.exception(f"Failed to delete user file {file_id}")
failed.add(file_id)
def drain_project_sync_loop(tenant_id: str) -> None:
@@ -168,13 +189,18 @@ def drain_project_sync_loop(tenant_id: str) -> None:
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
failed: set[UUID] = set()
while True:
with get_session_with_current_tenant() as session:
file_id = _claim_next_sync_file(session)
file_id = _claim_next_sync_file(session, exclude_ids=failed)
if file_id is None:
break
project_sync_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
try:
project_sync_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
except Exception:
logger.exception(f"Failed to sync user file {file_id}")
failed.add(file_id)

View File

@@ -12,9 +12,15 @@ def _build_redis_backend(tenant_id: str) -> CacheBackend:
return RedisCacheBackend(redis_pool.get_client(tenant_id))
def _build_postgres_backend(tenant_id: str) -> CacheBackend:
    """Build a Postgres-backed cache for *tenant_id*."""
    # Imported lazily so Redis deployments never import the Postgres backend.
    from onyx.cache.postgres_backend import PostgresCacheBackend
    return PostgresCacheBackend(tenant_id)
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
CacheBackendType.REDIS: _build_redis_backend,
# CacheBackendType.POSTGRES will be added in a follow-up PR.
CacheBackendType.POSTGRES: _build_postgres_backend,
}

View File

@@ -1,6 +1,9 @@
import abc
from enum import Enum
# Sentinel return values for CacheBackend.ttl(), mirroring Redis TTL codes.
TTL_KEY_NOT_FOUND = -2  # key missing (or already expired)
TTL_NO_EXPIRY = -1  # key exists but has no expiry set
class CacheBackendType(str, Enum):
REDIS = "redis"
@@ -26,6 +29,14 @@ class CacheLock(abc.ABC):
def owned(self) -> bool:
raise NotImplementedError
def __enter__(self) -> "CacheLock":
    """Context-manager entry: acquire the lock or raise if it can't be taken."""
    acquired = self.acquire()
    if not acquired:
        raise RuntimeError("Failed to acquire lock")
    return self
def __exit__(self, *args: object) -> None:
    """Context-manager exit: always release, even when the body raised."""
    self.release()
class CacheBackend(abc.ABC):
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
@@ -65,7 +76,11 @@ class CacheBackend(abc.ABC):
@abc.abstractmethod
def ttl(self, key: str) -> int:
"""Return remaining TTL in seconds. -1 if no expiry, -2 if key missing."""
"""Return remaining TTL in seconds.
Returns ``TTL_NO_EXPIRY`` (-1) if key exists without expiry,
``TTL_KEY_NOT_FOUND`` (-2) if key is missing or expired.
"""
raise NotImplementedError
# -- distributed lock --------------------------------------------------

323
backend/onyx/cache/postgres_backend.py vendored Normal file
View File

@@ -0,0 +1,323 @@
"""PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments.
Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks
for distributed locking, and a polling loop for the BLPOP pattern.
"""
import hashlib
import struct
import time
import uuid
from contextlib import AbstractContextManager
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheLock
from onyx.cache.interface import TTL_KEY_NOT_FOUND
from onyx.cache.interface import TTL_NO_EXPIRY
from onyx.db.models import CacheStore
_LIST_KEY_PREFIX = "_q:"  # namespace for emulated-list item rows in cache_store
# ASCII: ':' (0x3A) < ';' (0x3B). Upper bound for range queries so [prefix+, prefix;)
# captures all list-item keys (e.g. _q:mylist:123:uuid) without including other
# lists whose names share a prefix (e.g. _q:mylist2:...).
_LIST_KEY_RANGE_TERMINATOR = ";"
# Safety TTL on queued list items so abandoned queues get garbage-collected.
_LIST_ITEM_TTL_SECONDS = 3600
_LOCK_POLL_INTERVAL = 0.1  # seconds between advisory-lock acquisition retries
_BLPOP_POLL_INTERVAL = 0.25  # seconds between blpop polling queries
def _list_item_key(key: str) -> str:
    """Build a unique cache_store key for one pushed list item.

    The nanosecond timestamp preserves FIFO ordering under blpop's
    lexicographic scan; the random UUID suffix prevents collisions when
    concurrent rpush calls land on the same nanosecond.
    """
    timestamp = time.time_ns()
    nonce = uuid.uuid4().hex
    return f"{_LIST_KEY_PREFIX}{key}:{timestamp}:{nonce}"
def _to_bytes(value: str | bytes | int | float) -> bytes:
if isinstance(value, bytes):
return value
return str(value).encode()
# ------------------------------------------------------------------
# Lock
# ------------------------------------------------------------------
class PostgresCacheLock(CacheLock):
    """Advisory-lock-based distributed lock.

    Uses ``get_session_with_tenant`` for connection lifecycle. The lock is
    tied to the session's connection; releasing or closing the session
    frees it.

    NOTE: Unlike Redis locks, advisory locks do not auto-expire after
    ``timeout`` seconds. They are released when ``release()`` is
    called or when the session is closed.
    """

    def __init__(self, lock_id: int, timeout: float | None, tenant_id: str) -> None:
        # 64-bit signed id passed to pg_try_advisory_lock/pg_advisory_unlock.
        self._lock_id = lock_id
        # Default blocking timeout for acquire(); None means wait forever.
        self._timeout = timeout
        self._tenant_id = tenant_id
        self._session_cm: AbstractContextManager[Session] | None = None
        self._session: Session | None = None
        self._acquired = False

    def acquire(
        self,
        blocking: bool = True,
        blocking_timeout: float | None = None,
    ) -> bool:
        """Try to take the advisory lock.

        Args:
            blocking: When False, make a single attempt and return.
            blocking_timeout: Max seconds to wait; overrides the constructor
                ``timeout`` when given (0 means "give up immediately").
                When both are ``None`` the call blocks indefinitely.

        Returns:
            True when the lock was acquired, False on timeout.

        Raises:
            RuntimeError: If the lock is already held by this instance
                (re-acquiring would leak the session backing the lock).
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        if self._acquired:
            raise RuntimeError("Lock is already acquired; release() it first")

        self._session_cm = get_session_with_tenant(tenant_id=self._tenant_id)
        self._session = self._session_cm.__enter__()
        try:
            if not blocking:
                return self._try_lock()
            # Explicit None checks: a blocking_timeout of 0 must mean
            # "don't wait", not "fall back to self._timeout" / "wait forever"
            # as the previous truthiness tests (`or` / `if effective_timeout`)
            # implied.
            effective_timeout = (
                blocking_timeout if blocking_timeout is not None else self._timeout
            )
            deadline = (
                (time.monotonic() + effective_timeout)
                if effective_timeout is not None
                else None
            )
            while True:
                if self._try_lock():
                    return True
                if deadline is not None and time.monotonic() >= deadline:
                    return False
                time.sleep(_LOCK_POLL_INTERVAL)
        finally:
            # Never keep a session open when the lock was not taken.
            if not self._acquired:
                self._close_session()

    def release(self) -> None:
        """Release the advisory lock and close the backing session."""
        if not self._acquired or self._session is None:
            return
        try:
            self._session.execute(select(func.pg_advisory_unlock(self._lock_id)))
        finally:
            self._acquired = False
            self._close_session()

    def owned(self) -> bool:
        """Whether this instance currently holds the lock."""
        return self._acquired

    def _close_session(self) -> None:
        # Exit the session context manager exactly once, then drop references
        # so a failed acquire cannot reuse a stale session.
        if self._session_cm is not None:
            try:
                self._session_cm.__exit__(None, None, None)
            finally:
                self._session_cm = None
                self._session = None

    def _try_lock(self) -> bool:
        """Single non-blocking ``pg_try_advisory_lock`` attempt."""
        assert self._session is not None
        got_lock = self._session.execute(
            select(func.pg_try_advisory_lock(self._lock_id))
        ).scalar()
        if got_lock:
            self._acquired = True
            return True
        return False
# ------------------------------------------------------------------
# Backend
# ------------------------------------------------------------------
class PostgresCacheBackend(CacheBackend):
    """``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL.

    Each operation opens and closes its own database session so the backend
    is safe to share across threads. Tenant isolation is handled by
    SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``).
    """

    def __init__(self, tenant_id: str) -> None:
        self._tenant_id = tenant_id

    # -- basic key/value ---------------------------------------------------

    def get(self, key: str) -> bytes | None:
        """Return the value for *key*, or ``None`` if missing or expired."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        stmt = select(CacheStore.value).where(
            CacheStore.key == key,
            # Expired-but-not-yet-GC'd rows must behave as missing.
            or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()),
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            value = session.execute(stmt).scalar_one_or_none()
        if value is None:
            return None
        return bytes(value)

    def set(
        self,
        key: str,
        value: str | bytes | int | float,
        ex: int | None = None,
    ) -> None:
        """Upsert *key* with *value*; ``ex`` is an optional TTL in seconds."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        value_bytes = _to_bytes(value)
        expires_at = (
            datetime.now(timezone.utc) + timedelta(seconds=ex)
            if ex is not None
            else None
        )
        # ON CONFLICT upsert mirrors Redis SET: overwrite value AND TTL.
        stmt = (
            pg_insert(CacheStore)
            .values(key=key, value=value_bytes, expires_at=expires_at)
            .on_conflict_do_update(
                index_elements=[CacheStore.key],
                set_={"value": value_bytes, "expires_at": expires_at},
            )
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            session.execute(stmt)
            session.commit()

    def delete(self, key: str) -> None:
        """Remove *key*; deleting a missing key is a no-op."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            session.execute(delete(CacheStore).where(CacheStore.key == key))
            session.commit()

    def exists(self, key: str) -> bool:
        """Return True when *key* exists and is not expired."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        stmt = (
            select(CacheStore.key)
            .where(
                CacheStore.key == key,
                or_(
                    CacheStore.expires_at.is_(None),
                    CacheStore.expires_at > func.now(),
                ),
            )
            .limit(1)
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            return session.execute(stmt).first() is not None

    # -- TTL ---------------------------------------------------------------

    def expire(self, key: str, seconds: int) -> None:
        """Set or refresh the TTL for *key*.

        Matches Redis ``EXPIRE``: a missing key is a no-op. The liveness
        predicate also prevents resurrecting a row that has already expired
        but has not yet been garbage-collected — without it, a late
        ``expire()`` would bring stale data back to life.
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds)
        stmt = (
            update(CacheStore)
            .where(
                CacheStore.key == key,
                or_(
                    CacheStore.expires_at.is_(None),
                    CacheStore.expires_at > func.now(),
                ),
            )
            .values(expires_at=new_exp)
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            session.execute(stmt)
            session.commit()

    def ttl(self, key: str) -> int:
        """Return remaining TTL in seconds.

        Returns ``TTL_NO_EXPIRY`` (-1) if the key exists without expiry,
        ``TTL_KEY_NOT_FOUND`` (-2) if the key is missing or expired.
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        stmt = select(CacheStore.expires_at).where(CacheStore.key == key)
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            result = session.execute(stmt).first()
        if result is None:
            return TTL_KEY_NOT_FOUND
        expires_at: datetime | None = result[0]
        if expires_at is None:
            return TTL_NO_EXPIRY
        remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
        if remaining <= 0:
            # Row exists but is past expiry — report it as missing.
            return TTL_KEY_NOT_FOUND
        return int(remaining)

    # -- distributed lock --------------------------------------------------

    def lock(self, name: str, timeout: float | None = None) -> CacheLock:
        """Return an advisory-lock-based distributed lock for *name*."""
        return PostgresCacheLock(
            self._lock_id_for(name), timeout, tenant_id=self._tenant_id
        )

    # -- blocking list (MCP OAuth BLPOP pattern) ---------------------------

    def rpush(self, key: str, value: str | bytes) -> None:
        """Append *value* to the emulated list *key* (as one keyed row)."""
        self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS)

    def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
        """Pop the oldest item from the first non-empty list in *keys*.

        Polls the table until an item appears or *timeout* elapses; returns
        ``(list_key, value)`` like Redis BLPOP, or ``None`` on timeout.

        Raises:
            ValueError: If ``timeout <= 0`` — an uninterruptible infinite
                poll loop is not supported.
        """
        if timeout <= 0:
            raise ValueError(
                "PostgresCacheBackend.blpop requires timeout > 0. "
                "timeout=0 would block the calling thread indefinitely "
                "with no way to interrupt short of process termination."
            )
        from onyx.db.engine.sql_engine import get_session_with_tenant

        deadline = time.monotonic() + timeout
        while True:
            for key in keys:
                # Range scan over the list's namespace; ordering by key gives
                # FIFO because item keys embed a nanosecond timestamp.
                lower = f"{_LIST_KEY_PREFIX}{key}:"
                upper = f"{_LIST_KEY_PREFIX}{key}{_LIST_KEY_RANGE_TERMINATOR}"
                stmt = (
                    select(CacheStore)
                    .where(
                        CacheStore.key >= lower,
                        CacheStore.key < upper,
                        or_(
                            CacheStore.expires_at.is_(None),
                            CacheStore.expires_at > func.now(),
                        ),
                    )
                    .order_by(CacheStore.key)
                    .limit(1)
                    # skip_locked so two concurrent poppers never claim the
                    # same item; each sees the next unclaimed row instead.
                    .with_for_update(skip_locked=True)
                )
                with get_session_with_tenant(tenant_id=self._tenant_id) as session:
                    row = session.execute(stmt).scalars().first()
                    if row is not None:
                        value = bytes(row.value) if row.value else b""
                        # Delete + commit while still holding the row lock.
                        session.delete(row)
                        session.commit()
                        return (key.encode(), value)
            if time.monotonic() >= deadline:
                return None
            time.sleep(_BLPOP_POLL_INTERVAL)

    # -- helpers -----------------------------------------------------------

    def _lock_id_for(self, name: str) -> int:
        """Map *name* to a 64-bit signed int for ``pg_advisory_lock``.

        The tenant id is folded into the hash so equal lock names in
        different tenants map to different advisory lock ids.
        """
        h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
        return struct.unpack("q", h[:8])[0]
# ------------------------------------------------------------------
# Periodic cleanup
# ------------------------------------------------------------------
def cleanup_expired_cache_entries() -> None:
    """Delete rows whose ``expires_at`` is in the past.

    Called by the periodic poller every 5 minutes.
    """
    from onyx.db.engine.sql_engine import get_session_with_current_tenant

    # One bulk DELETE; reads already filter expired rows, so this is purely
    # garbage collection and needs no row-level coordination.
    expired_rows = delete(CacheStore).where(
        CacheStore.expires_at.is_not(None),
        CacheStore.expires_at < func.now(),
    )
    with get_session_with_current_tenant() as session:
        session.execute(expired_rows)
        session.commit()

View File

@@ -1,57 +1,52 @@
from uuid import UUID
from redis.client import Redis
from onyx.cache.interface import CacheBackend
# Redis key prefixes for chat message processing
PREFIX = "chatprocessing"
FENCE_PREFIX = f"{PREFIX}_fence"
FENCE_TTL = 30 * 60 # 30 minutes
def _get_fence_key(chat_session_id: UUID) -> str:
"""
Generate the Redis key for a chat session processing a message.
"""Generate the cache key for a chat session processing fence.
Args:
chat_session_id: The UUID of the chat session
Returns:
The fence key string (tenant_id is automatically added by the Redis client)
The fence key string. Tenant isolation is handled automatically
by the cache backend (Redis key-prefixing or Postgres schema routing).
"""
return f"{FENCE_PREFIX}_{chat_session_id}"
def set_processing_status(
chat_session_id: UUID, redis_client: Redis, value: bool
chat_session_id: UUID, cache: CacheBackend, value: bool
) -> None:
"""
Set or clear the fence for a chat session processing a message.
"""Set or clear the fence for a chat session processing a message.
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
If the key exists, a message is being processed.
Args:
chat_session_id: The UUID of the chat session
redis_client: The Redis client to use
cache: Tenant-aware cache backend
value: True to set the fence, False to clear it
"""
fence_key = _get_fence_key(chat_session_id)
if value:
redis_client.set(fence_key, 0, ex=FENCE_TTL)
cache.set(fence_key, 0, ex=FENCE_TTL)
else:
redis_client.delete(fence_key)
cache.delete(fence_key)
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
"""
Check if the chat session is processing a message.
def is_chat_session_processing(chat_session_id: UUID, cache: CacheBackend) -> bool:
"""Check if the chat session is processing a message.
Args:
chat_session_id: The UUID of the chat session
redis_client: The Redis client to use
cache: Tenant-aware cache backend
Returns:
True if the chat session is processing a message, False otherwise
"""
fence_key = _get_fence_key(chat_session_id)
return bool(redis_client.exists(fence_key))
return cache.exists(_get_fence_key(chat_session_id))

View File

@@ -11,9 +11,10 @@ from contextvars import Token
from uuid import UUID
from pydantic import BaseModel
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.cache.factory import get_cache_backend
from onyx.cache.interface import CacheBackend
from onyx.chat.chat_processing_checker import set_processing_status
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_state import run_chat_loop_with_state_containers
@@ -79,7 +80,6 @@ from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.redis.redis_pool import get_redis_client
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
@@ -448,7 +448,7 @@ def handle_stream_message_objects(
llm: LLM | None = None
chat_session: ChatSession | None = None
redis_client: Redis | None = None
cache: CacheBackend | None = None
user_id = user.id
if user.is_anonymous:
@@ -809,19 +809,19 @@ def handle_stream_message_objects(
)
simple_chat_history.insert(0, summary_simple)
redis_client = get_redis_client()
cache = get_cache_backend()
reset_cancel_status(
chat_session.id,
redis_client,
cache,
)
def check_is_connected() -> bool:
return check_stop_signal(chat_session.id, redis_client)
return check_stop_signal(chat_session.id, cache)
set_processing_status(
chat_session_id=chat_session.id,
redis_client=redis_client,
cache=cache,
value=True,
)
@@ -968,10 +968,10 @@ def handle_stream_message_objects(
reset_llm_mock_response(mock_response_token)
try:
if redis_client is not None and chat_session is not None:
if cache is not None and chat_session is not None:
set_processing_status(
chat_session_id=chat_session.id,
redis_client=redis_client,
cache=cache,
value=False,
)
except Exception:

View File

@@ -1,65 +1,58 @@
from uuid import UUID
from redis.client import Redis
from onyx.cache.interface import CacheBackend
# Redis key prefixes for chat session stop signals
PREFIX = "chatsessionstop"
FENCE_PREFIX = f"{PREFIX}_fence"
FENCE_TTL = 10 * 60 # 10 minutes - defensive TTL to prevent memory leaks
FENCE_TTL = 10 * 60 # 10 minutes
def _get_fence_key(chat_session_id: UUID) -> str:
"""
Generate the Redis key for a chat session stop signal fence.
"""Generate the cache key for a chat session stop signal fence.
Args:
chat_session_id: The UUID of the chat session
Returns:
The fence key string (tenant_id is automatically added by the Redis client)
The fence key string. Tenant isolation is handled automatically
by the cache backend (Redis key-prefixing or Postgres schema routing).
"""
return f"{FENCE_PREFIX}_{chat_session_id}"
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
"""
Set or clear the stop signal fence for a chat session.
def set_fence(chat_session_id: UUID, cache: CacheBackend, value: bool) -> None:
"""Set or clear the stop signal fence for a chat session.
Args:
chat_session_id: The UUID of the chat session
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
cache: Tenant-aware cache backend
value: True to set the fence (stop signal), False to clear it
"""
fence_key = _get_fence_key(chat_session_id)
if not value:
redis_client.delete(fence_key)
cache.delete(fence_key)
return
redis_client.set(fence_key, 0, ex=FENCE_TTL)
cache.set(fence_key, 0, ex=FENCE_TTL)
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
"""
Check if the chat session should continue (not stopped).
def is_connected(chat_session_id: UUID, cache: CacheBackend) -> bool:
"""Check if the chat session should continue (not stopped).
Args:
chat_session_id: The UUID of the chat session to check
redis_client: Redis client to use for checking the stop signal (tenant-aware client that auto-prefixes keys)
cache: Tenant-aware cache backend
Returns:
True if the session should continue, False if it should stop
"""
fence_key = _get_fence_key(chat_session_id)
return not bool(redis_client.exists(fence_key))
return not cache.exists(_get_fence_key(chat_session_id))
def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
"""
Clear the stop signal for a chat session.
def reset_cancel_status(chat_session_id: UUID, cache: CacheBackend) -> None:
"""Clear the stop signal for a chat session.
Args:
chat_session_id: The UUID of the chat session
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
cache: Tenant-aware cache backend
"""
fence_key = _get_fence_key(chat_session_id)
redis_client.delete(fence_key)
cache.delete(_get_fence_key(chat_session_id))

View File

@@ -4926,7 +4926,9 @@ class ScimUserMapping(Base):
__tablename__ = "scim_user_mapping"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
external_id: Mapped[str | None] = mapped_column(
String, unique=True, index=True, nullable=True
)
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
@@ -4983,3 +4985,25 @@ class CodeInterpreterServer(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
class CacheStore(Base):
    """Key-value cache table used by ``PostgresCacheBackend``.

    Replaces Redis for simple KV caching, locks, and list operations
    when ``CACHE_BACKEND=postgres`` (NO_VECTOR_DB deployments).

    Intentionally separate from ``KVStore``:
    - Stores raw bytes (LargeBinary) vs JSONB, matching Redis semantics.
    - Has ``expires_at`` for TTL; rows are periodically garbage-collected.
    - Holds ephemeral data (tokens, stop signals, lock state) not
      persistent application config, so cleanup can be aggressive.
    """

    __tablename__ = "cache_store"

    # Primary key: one row per cache key.
    key: Mapped[str] = mapped_column(String, primary_key=True)
    # Raw payload bytes (Redis-style values); NULL is permitted.
    value: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    # Absolute expiry timestamp; NULL means the entry has no TTL
    # (per the GC note in the class docstring).
    expires_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

View File

@@ -52,7 +52,7 @@ def create_user_files(
) -> CategorizedFilesResult:
# Categorize the files
categorized_files = categorize_uploaded_files(files, db_session)
categorized_files = categorize_uploaded_files(files)
# NOTE: At the moment, zip metadata is not used for user files.
# Should revisit to decide whether this should be a feature.
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)

View File

@@ -4,39 +4,33 @@ import base64
import json
import uuid
from typing import Any
from typing import cast
from typing import Dict
from typing import Optional
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Redis key prefix for OAuth state
OAUTH_STATE_PREFIX = "federated_oauth"
# Default TTL for OAuth state (5 minutes)
OAUTH_STATE_TTL = 300
OAUTH_STATE_TTL = 300 # 5 minutes
class OAuthSession:
"""Represents an OAuth session stored in Redis."""
"""Represents an OAuth session stored in the cache backend."""
def __init__(
self,
federated_connector_id: int,
user_id: str,
redirect_uri: Optional[str] = None,
additional_data: Optional[Dict[str, Any]] = None,
redirect_uri: str | None = None,
additional_data: dict[str, Any] | None = None,
):
self.federated_connector_id = federated_connector_id
self.user_id = user_id
self.redirect_uri = redirect_uri
self.additional_data = additional_data or {}
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for Redis storage."""
def to_dict(self) -> dict[str, Any]:
return {
"federated_connector_id": self.federated_connector_id,
"user_id": self.user_id,
@@ -45,8 +39,7 @@ class OAuthSession:
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OAuthSession":
"""Create from dictionary retrieved from Redis."""
def from_dict(cls, data: dict[str, Any]) -> "OAuthSession":
return cls(
federated_connector_id=data["federated_connector_id"],
user_id=data["user_id"],
@@ -58,31 +51,27 @@ class OAuthSession:
def generate_oauth_state(
federated_connector_id: int,
user_id: str,
redirect_uri: Optional[str] = None,
additional_data: Optional[Dict[str, Any]] = None,
redirect_uri: str | None = None,
additional_data: dict[str, Any] | None = None,
ttl: int = OAUTH_STATE_TTL,
) -> str:
"""
Generate a secure state parameter and store session data in Redis.
Generate a secure state parameter and store session data in the cache backend.
Args:
federated_connector_id: ID of the federated connector
user_id: ID of the user initiating OAuth
redirect_uri: Optional redirect URI after OAuth completion
additional_data: Any additional data to store with the session
ttl: Time-to-live in seconds for the Redis key
ttl: Time-to-live in seconds for the cache key
Returns:
Base64-encoded state parameter
"""
# Generate a random UUID for the state
state_uuid = uuid.uuid4()
state_b64 = base64.urlsafe_b64encode(state_uuid.bytes).decode("utf-8").rstrip("=")
# Convert UUID to base64 for URL-safe state parameter
state_bytes = state_uuid.bytes
state_b64 = base64.urlsafe_b64encode(state_bytes).decode("utf-8").rstrip("=")
# Create session object
session = OAuthSession(
federated_connector_id=federated_connector_id,
user_id=user_id,
@@ -90,15 +79,9 @@ def generate_oauth_state(
additional_data=additional_data,
)
# Store in Redis with TTL
redis_client = get_redis_client()
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
redis_client.set(
redis_key,
json.dumps(session.to_dict()),
ex=ttl,
)
cache = get_cache_backend()
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
cache.set(cache_key, json.dumps(session.to_dict()), ex=ttl)
logger.info(
f"Generated OAuth state for federated_connector_id={federated_connector_id}, "
@@ -125,18 +108,15 @@ def verify_oauth_state(state: str) -> OAuthSession:
state_bytes = base64.urlsafe_b64decode(padded_state)
state_uuid = uuid.UUID(bytes=state_bytes)
# Look up in Redis
redis_client = get_redis_client()
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
cache = get_cache_backend()
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
session_data = cast(bytes, redis_client.get(redis_key))
session_data = cache.get(cache_key)
if not session_data:
raise ValueError(f"OAuth state not found in Redis: {state}")
raise ValueError(f"OAuth state not found: {state}")
# Delete the key after retrieval (one-time use)
redis_client.delete(redis_key)
cache.delete(cache_key)
# Parse and return session
session_dict = json.loads(session_data)
return OAuthSession.from_dict(session_dict)

View File

@@ -1,13 +1,11 @@
import json
from typing import cast
from redis.client import Redis
from onyx.cache.interface import CacheBackend
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import KVStore
from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
@@ -20,22 +18,27 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self, redis_client: Redis | None = None) -> None:
# If no redis_client is provided, fall back to the context var
if redis_client is not None:
self.redis_client = redis_client
else:
self.redis_client = get_redis_client()
def __init__(self, cache: CacheBackend | None = None) -> None:
self._cache = cache
def _get_cache(self) -> CacheBackend:
if self._cache is None:
from onyx.cache.factory import get_cache_backend
self._cache = get_cache_backend()
return self._cache
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# Not encrypted in Redis, but encrypted in Postgres
# Not encrypted in Cache backend (typically Redis), but encrypted in Postgres
try:
self.redis_client.set(
self._get_cache().set(
REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
)
except Exception as e:
# Fallback gracefully to Postgres if Redis fails
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
# Fallback gracefully to Postgres if Cache backend fails
logger.error(
f"Failed to set value in Cache backend for key '{key}': {str(e)}"
)
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
@@ -53,16 +56,12 @@ class PgRedisKVStore(KeyValueStore):
def load(self, key: str, refresh_cache: bool = False) -> JSON_ro:
if not refresh_cache:
try:
redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
if redis_value:
if not isinstance(redis_value, bytes):
raise ValueError(
f"Redis value for key '{key}' is not a bytes object"
)
return json.loads(redis_value.decode("utf-8"))
cached = self._get_cache().get(REDIS_KEY_PREFIX + key)
if cached is not None:
return json.loads(cached.decode("utf-8"))
except Exception as e:
logger.error(
f"Failed to get value from Redis for key '{key}': {str(e)}"
f"Failed to get value from cache for key '{key}': {str(e)}"
)
with get_session_with_current_tenant() as db_session:
@@ -79,21 +78,21 @@ class PgRedisKVStore(KeyValueStore):
value = None
try:
self.redis_client.set(
self._get_cache().set(
REDIS_KEY_PREFIX + key,
json.dumps(value),
ex=KV_REDIS_KEY_EXPIRATION,
)
except Exception as e:
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
logger.error(f"Failed to set value in cache for key '{key}': {str(e)}")
return cast(JSON_ro, value)
def delete(self, key: str) -> None:
try:
self.redis_client.delete(REDIS_KEY_PREFIX + key)
self._get_cache().delete(REDIS_KEY_PREFIX + key)
except Exception as e:
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
logger.error(f"Failed to delete value from cache for key '{key}': {str(e)}")
with get_session_with_current_tenant() as db_session:
result = db_session.query(KVStore).filter_by(key=key).delete()

View File

@@ -13,44 +13,38 @@ from datetime import datetime
import httpx
from sqlalchemy.orm import Session
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.db.llm import fetch_auto_mode_providers
from onyx.db.llm import sync_auto_mode_models
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Redis key for caching the last updated timestamp (per-tenant)
_REDIS_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
_CACHE_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
_CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours
def _get_cached_last_updated_at() -> datetime | None:
"""Get the cached last_updated_at timestamp from Redis."""
try:
redis_client = get_redis_client()
value = redis_client.get(_REDIS_KEY_LAST_UPDATED_AT)
if value and isinstance(value, bytes):
# Value is bytes, decode to string then parse as ISO format
value = get_cache_backend().get(_CACHE_KEY_LAST_UPDATED_AT)
if value is not None:
return datetime.fromisoformat(value.decode("utf-8"))
except Exception as e:
logger.warning(f"Failed to get cached last_updated_at from Redis: {e}")
logger.warning(f"Failed to get cached last_updated_at: {e}")
return None
def _set_cached_last_updated_at(updated_at: datetime) -> None:
"""Set the cached last_updated_at timestamp in Redis."""
try:
redis_client = get_redis_client()
# Store as ISO format string, with 24 hour expiration
redis_client.set(
_REDIS_KEY_LAST_UPDATED_AT,
get_cache_backend().set(
_CACHE_KEY_LAST_UPDATED_AT,
updated_at.isoformat(),
ex=60 * 60 * 24, # 24 hours
ex=_CACHE_TTL_SECONDS,
)
except Exception as e:
logger.warning(f"Failed to set cached last_updated_at in Redis: {e}")
logger.warning(f"Failed to set cached last_updated_at: {e}")
def fetch_llm_recommendations_from_github(
@@ -148,9 +142,8 @@ def sync_llm_models_from_github(
def reset_cache() -> None:
"""Reset the cache timestamp in Redis. Useful for testing."""
"""Reset the cache timestamp. Useful for testing."""
try:
redis_client = get_redis_client()
redis_client.delete(_REDIS_KEY_LAST_UPDATED_AT)
get_cache_backend().delete(_CACHE_KEY_LAST_UPDATED_AT)
except Exception as e:
logger.warning(f"Failed to reset cache in Redis: {e}")
logger.warning(f"Failed to reset cache: {e}")

View File

@@ -7,14 +7,13 @@ from PIL import UnidentifiedImageError
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.password_validation import is_file_password_protected
from onyx.llm.factory import get_default_llm
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -117,9 +116,7 @@ def estimate_image_tokens_for_upload(
pass
def categorize_uploaded_files(
files: list[UploadFile], db_session: Session
) -> CategorizedFiles:
def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
"""
Categorize uploaded files based on text extractability and tokenized length.
@@ -131,11 +128,11 @@ def categorize_uploaded_files(
"""
results = CategorizedFiles()
default_model = fetch_default_llm_model(db_session)
llm = get_default_llm()
model_name = default_model.name if default_model else None
provider_type = default_model.llm_provider.provider if default_model else None
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
tokenizer = get_tokenizer(
model_name=llm.config.model_name, provider_type=llm.config.model_provider
)
# Check if threshold checks should be skipped
skip_threshold = False

View File

@@ -8,10 +8,10 @@ import httpx
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.cache.factory import get_shared_cache_backend
from onyx.configs.app_configs import INSTANCE_TYPE
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.release_notes import create_release_notifications_for_versions
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.server.features.release_notes.constants import AUTO_REFRESH_THRESHOLD_SECONDS
from onyx.server.features.release_notes.constants import FETCH_TIMEOUT
from onyx.server.features.release_notes.constants import GITHUB_CHANGELOG_RAW_URL
@@ -113,60 +113,46 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
def get_cached_etag() -> str | None:
"""Get the cached GitHub ETag from Redis."""
redis_client = get_shared_redis_client()
cache = get_shared_cache_backend()
try:
etag = redis_client.get(REDIS_KEY_ETAG)
etag = cache.get(REDIS_KEY_ETAG)
if etag:
return etag.decode("utf-8") if isinstance(etag, bytes) else str(etag)
return etag.decode("utf-8")
return None
except Exception as e:
logger.error(f"Failed to get cached etag from Redis: {e}")
logger.error(f"Failed to get cached etag: {e}")
return None
def get_last_fetch_time() -> datetime | None:
"""Get the last fetch timestamp from Redis."""
redis_client = get_shared_redis_client()
cache = get_shared_cache_backend()
try:
fetched_at_str = redis_client.get(REDIS_KEY_FETCHED_AT)
if not fetched_at_str:
raw = cache.get(REDIS_KEY_FETCHED_AT)
if not raw:
return None
decoded = (
fetched_at_str.decode("utf-8")
if isinstance(fetched_at_str, bytes)
else str(fetched_at_str)
)
last_fetch = datetime.fromisoformat(decoded)
# Defensively ensure timezone awareness
# fromisoformat() returns naive datetime if input lacks timezone
last_fetch = datetime.fromisoformat(raw.decode("utf-8"))
if last_fetch.tzinfo is None:
# Assume UTC for naive datetimes
last_fetch = last_fetch.replace(tzinfo=timezone.utc)
else:
# Convert to UTC if timezone-aware
last_fetch = last_fetch.astimezone(timezone.utc)
return last_fetch
except Exception as e:
logger.error(f"Failed to get last fetch time from Redis: {e}")
logger.error(f"Failed to get last fetch time from cache: {e}")
return None
def save_fetch_metadata(etag: str | None) -> None:
"""Save ETag and fetch timestamp to Redis."""
redis_client = get_shared_redis_client()
cache = get_shared_cache_backend()
now = datetime.now(timezone.utc)
try:
redis_client.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
cache.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
if etag:
redis_client.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
cache.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
except Exception as e:
logger.error(f"Failed to save fetch metadata to Redis: {e}")
logger.error(f"Failed to save fetch metadata to cache: {e}")
def is_cache_stale() -> bool:
@@ -196,11 +182,10 @@ def ensure_release_notes_fresh_and_notify(db_session: Session) -> None:
if not is_cache_stale():
return
# Acquire lock to prevent concurrent fetches
redis_client = get_shared_redis_client()
lock = redis_client.lock(
cache = get_shared_cache_backend()
lock = cache.lock(
OnyxRedisLocks.RELEASE_NOTES_FETCH_LOCK,
timeout=90, # 90 second timeout for the lock
timeout=90,
)
# Non-blocking acquire - if we can't get the lock, another request is handling it

View File

@@ -479,10 +479,20 @@ def put_llm_provider(
@admin_router.delete("/provider/{provider_id}")
def delete_llm_provider(
    provider_id: int,
    force: bool = Query(False),
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Delete an LLM provider.

    Unless ``force`` is set, refuses with HTTP 400 to delete the provider
    that backs the current default LLM model. Responds with HTTP 404 when
    the provider does not exist (signaled via ``ValueError`` from
    ``remove_llm_provider``).
    """
    try:
        if not force:
            # Guard the default provider so the deployment does not
            # silently lose its default LLM. The HTTPException raised
            # here is not a ValueError, so it propagates past `except`.
            model = fetch_default_llm_model(db_session)
            if model and model.llm_provider_id == provider_id:
                raise HTTPException(
                    status_code=400,
                    detail="Cannot delete the default LLM provider",
                )
        remove_llm_provider(db_session, provider_id)
    except ValueError as e:
        # Chain the original cause for debuggability (PEP 3134 / B904).
        raise HTTPException(status_code=404, detail=str(e)) from e

View File

@@ -13,13 +13,13 @@ from fastapi import Request
from fastapi import Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.pat import get_hashed_pat_from_request
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_user
from onyx.cache.factory import get_cache_backend
from onyx.chat.chat_processing_checker import is_chat_session_processing
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import convert_chat_history_basic
@@ -67,7 +67,6 @@ from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.redis.redis_pool import get_redis_client
from onyx.secondary_llm_flows.chat_session_naming import generate_chat_session_name
from onyx.server.api_key_usage import check_api_key_usage
from onyx.server.query_and_chat.models import ChatFeedbackRequest
@@ -330,7 +329,7 @@ def get_chat_session(
]
try:
is_processing = is_chat_session_processing(session_id, get_redis_client())
is_processing = is_chat_session_processing(session_id, get_cache_backend())
# Edit the last message to indicate loading (Overriding default message value)
if is_processing and chat_message_details:
last_msg = chat_message_details[-1]
@@ -927,11 +926,10 @@ async def search_chats(
def stop_chat_session(
chat_session_id: UUID,
user: User = Depends(current_user), # noqa: ARG001
redis_client: Redis = Depends(get_redis_client),
) -> dict[str, str]:
"""
Stop a chat session by setting a stop signal in Redis.
Stop a chat session by setting a stop signal.
This endpoint is called by the frontend when the user clicks the stop button.
"""
set_fence(chat_session_id, redis_client, True)
set_fence(chat_session_id, get_cache_backend(), True)
return {"message": "Chat session stopped"}

View File

@@ -1,3 +1,4 @@
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
@@ -6,11 +7,8 @@ from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -33,30 +31,22 @@ def load_settings() -> Settings:
logger.error(f"Error loading settings from KV store: {str(e)}")
settings = Settings()
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
redis_client = get_redis_client(tenant_id=tenant_id)
cache = get_cache_backend()
try:
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
if value is not None:
assert isinstance(value, bytes)
anonymous_user_enabled = int(value.decode("utf-8")) == 1
else:
# Default to False
anonymous_user_enabled = False
# Optionally store the default back to Redis
redis_client.set(
OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL
)
cache.set(OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL)
except Exception as e:
# Log the error and reset to default
logger.error(f"Error loading anonymous user setting from Redis: {str(e)}")
logger.error(f"Error loading anonymous user setting from cache: {str(e)}")
anonymous_user_enabled = False
settings.anonymous_user_enabled = anonymous_user_enabled
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
# Override user knowledge setting if disabled via environment variable
if DISABLE_USER_KNOWLEDGE:
settings.user_knowledge_enabled = False
@@ -66,11 +56,10 @@ def load_settings() -> Settings:
def store_settings(settings: Settings) -> None:
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
redis_client = get_redis_client(tenant_id=tenant_id)
cache = get_cache_backend()
if settings.anonymous_user_enabled is not None:
redis_client.set(
cache.set(
OnyxRedisLocks.ANONYMOUS_USER_ENABLED,
"1" if settings.anonymous_user_enabled else "0",
ex=SETTINGS_TTL,

View File

@@ -13,9 +13,11 @@ the correct files.
from collections.abc import Generator
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import UUID
from uuid import uuid4
import pytest
import sqlalchemy as sa
from sqlalchemy.orm import Session
from onyx.background.periodic_poller import recover_stuck_user_files
@@ -55,6 +57,32 @@ def _create_user_file(
return uf
def _fake_delete_impl(
    user_file_id: str, tenant_id: str, redis_locking: bool  # noqa: ARG001
) -> None:
    """Mock side-effect: remove the UserFile row so the drain loop terminates."""
    from onyx.db.engine.sql_engine import get_session_with_current_tenant

    target = UUID(user_file_id)
    with get_session_with_current_tenant() as session:
        session.execute(sa.delete(UserFile).where(UserFile.id == target))
        session.commit()
def _fake_sync_impl(
    user_file_id: str, tenant_id: str, redis_locking: bool  # noqa: ARG001
) -> None:
    """Mock side-effect: clear both sync flags so the drain loop terminates."""
    from onyx.db.engine.sql_engine import get_session_with_current_tenant

    target = UUID(user_file_id)
    stmt = (
        sa.update(UserFile)
        .where(UserFile.id == target)
        .values(needs_project_sync=False, needs_persona_sync=False)
    )
    with get_session_with_current_tenant() as session:
        session.execute(stmt)
        session.commit()
@pytest.fixture()
def _cleanup_user_files(db_session: Session) -> Generator[list[UserFile], None, None]:
"""Track created UserFile rows and delete them after each test."""
@@ -125,9 +153,9 @@ class TestRecoverDeletingFiles:
) -> None:
user = create_test_user(db_session, "recovery_del")
uf = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
_cleanup_user_files.append(uf)
# Row is deleted by _fake_delete_impl, so no cleanup needed.
mock_impl = MagicMock()
mock_impl = MagicMock(side_effect=_fake_delete_impl)
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
@@ -155,7 +183,7 @@ class TestRecoverSyncFiles:
)
_cleanup_user_files.append(uf)
mock_impl = MagicMock()
mock_impl = MagicMock(side_effect=_fake_sync_impl)
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
@@ -179,7 +207,7 @@ class TestRecoverSyncFiles:
)
_cleanup_user_files.append(uf)
mock_impl = MagicMock()
mock_impl = MagicMock(side_effect=_fake_sync_impl)
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
@@ -217,3 +245,108 @@ class TestRecoveryMultipleFiles:
f"Expected all {len(files)} files to be recovered. "
f"Missing: {expected_ids - called_ids}"
)
class TestTransientFailures:
    """Drain loops skip failed files, process the rest, and terminate."""

    def test_processing_failure_skips_and_continues(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
        _cleanup_user_files: list[UserFile],
    ) -> None:
        # Two PROCESSING files: one whose impl raises every time, one healthy.
        user = create_test_user(db_session, "fail_proc")
        uf_fail = _create_user_file(
            db_session, user.id, status=UserFileStatus.PROCESSING
        )
        uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
        _cleanup_user_files.extend([uf_fail, uf_ok])
        fail_id = str(uf_fail.id)

        def side_effect(
            *, user_file_id: str, tenant_id: str, redis_locking: bool  # noqa: ARG001
        ) -> None:
            # Simulate a one-off failure for exactly one file.
            if user_file_id == fail_id:
                raise RuntimeError("transient failure")

        mock_impl = MagicMock(side_effect=side_effect)
        with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
            recover_stuck_user_files(TEST_TENANT_ID)

        # Each file must be attempted exactly once: the failing file is
        # skipped (not retried forever) and the healthy one is processed.
        called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
        assert fail_id in called_ids, "Failed file should have been attempted"
        assert str(uf_ok.id) in called_ids, "Healthy file should have been processed"
        assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
        assert called_ids.count(str(uf_ok.id)) == 1

    def test_delete_failure_skips_and_continues(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
        _cleanup_user_files: list[UserFile],
    ) -> None:
        user = create_test_user(db_session, "fail_del")
        uf_fail = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
        uf_ok = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
        # uf_ok's row is removed by _fake_delete_impl below, so only the
        # failing file needs explicit cleanup.
        _cleanup_user_files.append(uf_fail)
        fail_id = str(uf_fail.id)

        def side_effect(
            *, user_file_id: str, tenant_id: str, redis_locking: bool
        ) -> None:
            if user_file_id == fail_id:
                raise RuntimeError("transient failure")
            # Healthy files are actually deleted so the drain loop can end.
            _fake_delete_impl(user_file_id, tenant_id, redis_locking)

        mock_impl = MagicMock(side_effect=side_effect)
        with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
            recover_stuck_user_files(TEST_TENANT_ID)

        called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
        assert fail_id in called_ids, "Failed file should have been attempted"
        assert str(uf_ok.id) in called_ids, "Healthy file should have been deleted"
        assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
        assert called_ids.count(str(uf_ok.id)) == 1

    def test_sync_failure_skips_and_continues(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
        _cleanup_user_files: list[UserFile],
    ) -> None:
        user = create_test_user(db_session, "fail_sync")
        uf_fail = _create_user_file(
            db_session,
            user.id,
            status=UserFileStatus.COMPLETED,
            needs_project_sync=True,
        )
        uf_ok = _create_user_file(
            db_session,
            user.id,
            status=UserFileStatus.COMPLETED,
            needs_persona_sync=True,
        )
        _cleanup_user_files.extend([uf_fail, uf_ok])
        fail_id = str(uf_fail.id)

        def side_effect(
            *, user_file_id: str, tenant_id: str, redis_locking: bool
        ) -> None:
            if user_file_id == fail_id:
                raise RuntimeError("transient failure")
            # Healthy files get their sync flags cleared so the loop ends.
            _fake_sync_impl(user_file_id, tenant_id, redis_locking)

        mock_impl = MagicMock(side_effect=side_effect)
        with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
            recover_stuck_user_files(TEST_TENANT_ID)

        called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
        assert fail_id in called_ids, "Failed file should have been attempted"
        assert str(uf_ok.id) in called_ids, "Healthy file should have been synced"
        assert called_ids.count(fail_id) == 1, "Failed file retried — infinite loop"
        assert called_ids.count(str(uf_ok.id)) == 1

View File

@@ -0,0 +1,57 @@
"""Fixtures for cache backend tests.
Requires a running PostgreSQL instance (and Redis for parity tests).
Run with::
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/cache/
"""
from collections.abc import Generator
import pytest
from onyx.cache.interface import CacheBackend
from onyx.cache.postgres_backend import PostgresCacheBackend
from onyx.cache.redis_backend import RedisCacheBackend
from onyx.db.engine.sql_engine import SqlEngine
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from tests.external_dependency_unit.constants import TEST_TENANT_ID
@pytest.fixture(scope="session", autouse=True)
def _init_db() -> Generator[None, None, None]:
    """Initialize DB engine. Assumes Postgres has migrations applied (e.g. via docker compose)."""
    # Session-scoped: the engine is created once for the whole test run.
    SqlEngine.init_engine(pool_size=5, max_overflow=2)
    yield
@pytest.fixture(autouse=True)
def _tenant_context() -> Generator[None, None, None]:
    """Pin the tenant contextvar to the test tenant for the duration of a test."""
    reset_token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
    try:
        yield
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(reset_token)
@pytest.fixture
def pg_cache() -> PostgresCacheBackend:
    # Fresh Postgres-backed cache bound to the test tenant.
    return PostgresCacheBackend(TEST_TENANT_ID)
@pytest.fixture
def redis_cache() -> RedisCacheBackend:
    # NOTE(review): function-local import — presumably so collecting this
    # conftest does not require Redis to be configured; confirm.
    from onyx.redis.redis_pool import redis_pool

    return RedisCacheBackend(redis_pool.get_client(TEST_TENANT_ID))
@pytest.fixture(params=["postgres", "redis"], ids=["postgres", "redis"])
def cache(
    request: pytest.FixtureRequest,
    pg_cache: PostgresCacheBackend,
    redis_cache: RedisCacheBackend,
) -> CacheBackend:
    """Parameterized backend: every test using this fixture runs twice."""
    return pg_cache if request.param == "postgres" else redis_cache

View File

@@ -0,0 +1,100 @@
"""Parameterized tests that run the same CacheBackend operations against
both Redis and PostgreSQL, asserting identical return values.
Each test runs twice (once per backend) via the ``cache`` fixture defined
in conftest.py.
"""
import time
from uuid import uuid4
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import TTL_KEY_NOT_FOUND
from onyx.cache.interface import TTL_NO_EXPIRY
def _key() -> str:
return f"parity_{uuid4().hex[:12]}"
class TestKVParity:
    """Plain get/set/delete/exists semantics must agree across backends."""

    def test_get_missing(self, cache: CacheBackend) -> None:
        assert cache.get(_key()) is None

    def test_get_set(self, cache: CacheBackend) -> None:
        key = _key()
        cache.set(key, b"value")
        assert cache.get(key) == b"value"

    def test_overwrite(self, cache: CacheBackend) -> None:
        key = _key()
        for payload in (b"a", b"b"):
            cache.set(key, payload)
        assert cache.get(key) == b"b"

    def test_set_string(self, cache: CacheBackend) -> None:
        # str inputs come back as encoded bytes, mirroring Redis.
        key = _key()
        cache.set(key, "hello")
        assert cache.get(key) == b"hello"

    def test_set_int(self, cache: CacheBackend) -> None:
        # int inputs come back as their ASCII byte form, mirroring Redis.
        key = _key()
        cache.set(key, 42)
        assert cache.get(key) == b"42"

    def test_delete(self, cache: CacheBackend) -> None:
        key = _key()
        cache.set(key, b"x")
        cache.delete(key)
        assert cache.get(key) is None

    def test_exists(self, cache: CacheBackend) -> None:
        key = _key()
        # Truthiness check (not `is False`): backends may return int counts.
        assert not cache.exists(key)
        cache.set(key, b"x")
        assert cache.exists(key)
class TestTTLParity:
    """TTL reporting and expiry must behave identically on both backends."""

    def test_ttl_missing(self, cache: CacheBackend) -> None:
        assert cache.ttl(_key()) == TTL_KEY_NOT_FOUND

    def test_ttl_no_expiry(self, cache: CacheBackend) -> None:
        key = _key()
        cache.set(key, b"x")
        assert cache.ttl(key) == TTL_NO_EXPIRY

    def test_ttl_remaining(self, cache: CacheBackend) -> None:
        key = _key()
        cache.set(key, b"x", ex=10)
        # Allow slack for wall-clock time elapsed between set() and ttl().
        assert 8 <= cache.ttl(key) <= 10

    def test_set_with_ttl_expires(self, cache: CacheBackend) -> None:
        key = _key()
        cache.set(key, b"x", ex=1)
        assert cache.get(key) == b"x"
        time.sleep(1.5)
        assert cache.get(key) is None
class TestLockParity:
    """Lock acquire/release and ownership reporting parity."""

    def test_acquire_release(self, cache: CacheBackend) -> None:
        lock_name = f"parity_lock_{uuid4().hex[:8]}"
        lock = cache.lock(lock_name)
        assert lock.acquire(blocking=False)
        assert lock.owned()
        lock.release()
        assert not lock.owned()
class TestListParity:
    """rpush/blpop list-queue parity, including the timeout path."""

    def test_rpush_blpop(self, cache: CacheBackend) -> None:
        queue = f"parity_list_{uuid4().hex[:8]}"
        cache.rpush(queue, b"item")
        popped = cache.blpop([queue], timeout=1)
        assert popped is not None
        # blpop returns (key, value) pairs, Redis-style.
        assert popped[1] == b"item"

    def test_blpop_timeout(self, cache: CacheBackend) -> None:
        empty_queue = f"parity_empty_{uuid4().hex[:8]}"
        assert cache.blpop([empty_queue], timeout=1) is None

View File

@@ -0,0 +1,129 @@
"""Tests for PgRedisKVStore's cache layer integration with CacheBackend.
Verifies that the KV store correctly uses the CacheBackend for caching
in front of PostgreSQL: cache hits, cache misses falling through to PG,
cache population after PG reads, cache invalidation on delete, and
graceful degradation when the cache backend raises.
Requires running PostgreSQL.
"""
import json
from collections.abc import Generator
from unittest.mock import MagicMock
import pytest
from sqlalchemy import delete
from onyx.cache.interface import CacheBackend
from onyx.cache.postgres_backend import PostgresCacheBackend
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import CacheStore
from onyx.db.models import KVStore
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.key_value_store.store import PgRedisKVStore
from onyx.key_value_store.store import REDIS_KEY_PREFIX
from tests.external_dependency_unit.constants import TEST_TENANT_ID
@pytest.fixture(autouse=True)
def _clean_kv() -> Generator[None, None, None]:
    """Autouse fixture: wipe KVStore and CacheStore rows after each test."""
    yield
    # Cleanup happens post-test so in-test assertions still see the rows.
    with get_session_with_tenant(tenant_id=TEST_TENANT_ID) as session:
        session.execute(delete(KVStore))
        session.execute(delete(CacheStore))
        session.commit()
@pytest.fixture
def kv_store(pg_cache: PostgresCacheBackend) -> PgRedisKVStore:
    """KV store wired to the Postgres-backed cache fixture."""
    return PgRedisKVStore(cache=pg_cache)
class TestStoreAndLoad:
    """store()/load() behaviour across the two tiers (cache in front of PG)."""

    def test_store_populates_cache_and_pg(
        self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
    ) -> None:
        # store() must write both tiers: the cache entry is JSON under the
        # Redis-compatible key prefix, and load() reads the value back.
        kv_store.store("k1", {"hello": "world"})
        cached = pg_cache.get(REDIS_KEY_PREFIX + "k1")
        assert cached is not None
        assert json.loads(cached) == {"hello": "world"}
        loaded = kv_store.load("k1")
        assert loaded == {"hello": "world"}

    def test_load_returns_cached_value_without_pg_hit(
        self, pg_cache: PostgresCacheBackend
    ) -> None:
        """If the cache already has the value, PG should not be queried."""
        # Seed only the cache; PG has no row for this key, so a successful
        # load() proves the value came from the cache tier.
        pg_cache.set(REDIS_KEY_PREFIX + "cached_only", json.dumps({"from": "cache"}))
        kv = PgRedisKVStore(cache=pg_cache)
        assert kv.load("cached_only") == {"from": "cache"}

    def test_load_falls_through_to_pg_on_cache_miss(
        self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
    ) -> None:
        # Evict the cache entry, then confirm load() falls back to PG and
        # repopulates the cache with the same JSON payload.
        kv_store.store("k2", [1, 2, 3])
        pg_cache.delete(REDIS_KEY_PREFIX + "k2")
        assert pg_cache.get(REDIS_KEY_PREFIX + "k2") is None
        loaded = kv_store.load("k2")
        assert loaded == [1, 2, 3]
        repopulated = pg_cache.get(REDIS_KEY_PREFIX + "k2")
        assert repopulated is not None
        assert json.loads(repopulated) == [1, 2, 3]

    def test_load_with_refresh_cache_skips_cache(
        self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
    ) -> None:
        # Poison the cache with a stale value; refresh_cache=True must
        # bypass it and return the authoritative PG value.
        kv_store.store("k3", "original")
        pg_cache.set(REDIS_KEY_PREFIX + "k3", json.dumps("stale"))
        loaded = kv_store.load("k3", refresh_cache=True)
        assert loaded == "original"
class TestDelete:
    """delete() must invalidate both tiers and raise for unknown keys."""

    def test_delete_removes_from_cache_and_pg(
        self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
    ) -> None:
        kv_store.store("del_me", "bye")
        kv_store.delete("del_me")
        # Gone from the cache tier...
        assert pg_cache.get(REDIS_KEY_PREFIX + "del_me") is None
        # ...and from PG (load falls through and finds nothing).
        with pytest.raises(KvKeyNotFoundError):
            kv_store.load("del_me")

    def test_delete_missing_key_raises(self, kv_store: PgRedisKVStore) -> None:
        with pytest.raises(KvKeyNotFoundError):
            kv_store.delete("nonexistent")
class TestCacheFailureGracefulDegradation:
    """The KV store must keep working against PG when the cache backend errors."""

    def test_store_succeeds_when_cache_set_raises(self) -> None:
        # Writer's cache is broken: store() should still persist to PG.
        failing_cache = MagicMock(spec=CacheBackend)
        failing_cache.set.side_effect = ConnectionError("cache down")
        kv = PgRedisKVStore(cache=failing_cache)
        kv.store("resilient", {"data": True})
        # Read back through a separate store whose cache always misses,
        # forcing the load to come from PG.
        working_cache = MagicMock(spec=CacheBackend)
        working_cache.get.return_value = None
        kv_reader = PgRedisKVStore(cache=working_cache)
        loaded = kv_reader.load("resilient")
        assert loaded == {"data": True}

    def test_load_falls_through_when_cache_get_raises(self) -> None:
        # Both get and set raise: the store/load round-trip must still
        # succeed via PG alone.
        failing_cache = MagicMock(spec=CacheBackend)
        failing_cache.get.side_effect = ConnectionError("cache down")
        failing_cache.set.side_effect = ConnectionError("cache down")
        kv = PgRedisKVStore(cache=failing_cache)
        kv.store("survive", 42)
        loaded = kv.load("survive")
        assert loaded == 42

View File

@@ -0,0 +1,229 @@
"""Tests for PostgresCacheBackend against real PostgreSQL.
Covers every method on the backend: KV CRUD, TTL behaviour, advisory
locks (acquire / release / contention), list operations (rpush / blpop),
and the periodic cleanup function.
"""
import time
from uuid import uuid4
from sqlalchemy import select
from onyx.cache.interface import TTL_KEY_NOT_FOUND
from onyx.cache.interface import TTL_NO_EXPIRY
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
from onyx.cache.postgres_backend import PostgresCacheBackend
from onyx.db.models import CacheStore
def _key() -> str:
return f"test_{uuid4().hex[:12]}"
# ------------------------------------------------------------------
# Basic KV
# ------------------------------------------------------------------
class TestKV:
    """Basic key/value CRUD against the real PostgreSQL-backed cache."""

    def test_get_set(self, pg_cache: PostgresCacheBackend) -> None:
        k = _key()
        pg_cache.set(k, b"hello")
        assert pg_cache.get(k) == b"hello"

    def test_get_missing(self, pg_cache: PostgresCacheBackend) -> None:
        assert pg_cache.get(_key()) is None

    def test_set_overwrite(self, pg_cache: PostgresCacheBackend) -> None:
        # Second set() on the same key replaces the value (upsert).
        k = _key()
        pg_cache.set(k, b"first")
        pg_cache.set(k, b"second")
        assert pg_cache.get(k) == b"second"

    def test_set_string_value(self, pg_cache: PostgresCacheBackend) -> None:
        # str inputs are stored and returned as their encoded bytes.
        k = _key()
        pg_cache.set(k, "string_val")
        assert pg_cache.get(k) == b"string_val"

    def test_set_int_value(self, pg_cache: PostgresCacheBackend) -> None:
        # Numeric inputs round-trip as their ASCII byte representation.
        k = _key()
        pg_cache.set(k, 42)
        assert pg_cache.get(k) == b"42"

    def test_delete(self, pg_cache: PostgresCacheBackend) -> None:
        k = _key()
        pg_cache.set(k, b"to_delete")
        pg_cache.delete(k)
        assert pg_cache.get(k) is None

    def test_delete_missing_is_noop(self, pg_cache: PostgresCacheBackend) -> None:
        # Deleting an unknown key must not raise.
        pg_cache.delete(_key())

    def test_exists(self, pg_cache: PostgresCacheBackend) -> None:
        k = _key()
        assert not pg_cache.exists(k)
        pg_cache.set(k, b"x")
        assert pg_cache.exists(k)
# ------------------------------------------------------------------
# TTL
# ------------------------------------------------------------------
class TestTTL:
    """Expiry behaviour: ex= on set, ttl() sentinels, expire(), and exists()."""

    def test_set_with_ttl_expires(self, pg_cache: PostgresCacheBackend) -> None:
        k = _key()
        pg_cache.set(k, b"ephemeral", ex=1)
        assert pg_cache.get(k) == b"ephemeral"
        # Sleep past the 1s TTL; the value must then be unreadable.
        time.sleep(1.5)
        assert pg_cache.get(k) is None

    def test_ttl_no_expiry(self, pg_cache: PostgresCacheBackend) -> None:
        k = _key()
        pg_cache.set(k, b"forever")
        assert pg_cache.ttl(k) == TTL_NO_EXPIRY

    def test_ttl_missing_key(self, pg_cache: PostgresCacheBackend) -> None:
        assert pg_cache.ttl(_key()) == TTL_KEY_NOT_FOUND

    def test_ttl_remaining(self, pg_cache: PostgresCacheBackend) -> None:
        # Loose bounds (8..10) absorb clock skew between set() and ttl().
        k = _key()
        pg_cache.set(k, b"x", ex=10)
        remaining = pg_cache.ttl(k)
        assert 8 <= remaining <= 10

    def test_ttl_expired_key(self, pg_cache: PostgresCacheBackend) -> None:
        # A lapsed key must look identical to a never-existing key.
        k = _key()
        pg_cache.set(k, b"x", ex=1)
        time.sleep(1.5)
        assert pg_cache.ttl(k) == TTL_KEY_NOT_FOUND

    def test_expire_adds_ttl(self, pg_cache: PostgresCacheBackend) -> None:
        # expire() attaches a TTL to a previously non-expiring key.
        k = _key()
        pg_cache.set(k, b"x")
        assert pg_cache.ttl(k) == TTL_NO_EXPIRY
        pg_cache.expire(k, 10)
        assert 8 <= pg_cache.ttl(k) <= 10

    def test_exists_respects_ttl(self, pg_cache: PostgresCacheBackend) -> None:
        k = _key()
        pg_cache.set(k, b"x", ex=1)
        assert pg_cache.exists(k)
        time.sleep(1.5)
        assert not pg_cache.exists(k)
# ------------------------------------------------------------------
# Locks
# ------------------------------------------------------------------
class TestLock:
    """Advisory-lock behaviour: acquire/release, contention, context manager,
    and blocking-timeout semantics against the Postgres backend."""

    def test_acquire_release(self, pg_cache: PostgresCacheBackend) -> None:
        lock = pg_cache.lock(f"lock_{uuid4().hex[:8]}")
        assert lock.acquire(blocking=False)
        assert lock.owned()
        lock.release()
        assert not lock.owned()

    def test_contention(self, pg_cache: PostgresCacheBackend) -> None:
        # Two handles on the same lock name are mutually exclusive; releasing
        # one allows the other to acquire.
        name = f"contention_{uuid4().hex[:8]}"
        lock1 = pg_cache.lock(name)
        lock2 = pg_cache.lock(name)
        assert lock1.acquire(blocking=False)
        assert not lock2.acquire(blocking=False)
        lock1.release()
        assert lock2.acquire(blocking=False)
        lock2.release()

    def test_context_manager(self, pg_cache: PostgresCacheBackend) -> None:
        # The lock is held inside the `with` body and released on exit.
        with pg_cache.lock(f"ctx_{uuid4().hex[:8]}") as lock:
            assert lock.owned()
        assert not lock.owned()

    def test_blocking_timeout(self, pg_cache: PostgresCacheBackend) -> None:
        name = f"timeout_{uuid4().hex[:8]}"
        holder = pg_cache.lock(name)
        # FIX: assert the initial acquire succeeded. Previously the return
        # value was discarded, so a failed acquire would silently make the
        # timeout measurement below meaningless (waiter would acquire
        # immediately and the elapsed-time assertion would fail confusingly).
        assert holder.acquire(blocking=False)
        waiter = pg_cache.lock(name, timeout=0.3)
        start = time.monotonic()
        assert not waiter.acquire(blocking=True, blocking_timeout=0.3)
        elapsed = time.monotonic() - start
        # Allow a little scheduler slack below the nominal 0.3s timeout.
        assert elapsed >= 0.25
        holder.release()
# ------------------------------------------------------------------
# List (rpush / blpop)
# ------------------------------------------------------------------
class TestList:
    """List-queue operations: rpush/blpop round-trip, timeout, FIFO, multi-key."""

    def test_rpush_blpop(self, pg_cache: PostgresCacheBackend) -> None:
        k = f"list_{uuid4().hex[:8]}"
        pg_cache.rpush(k, b"item1")
        result = pg_cache.blpop([k], timeout=1)
        assert result is not None
        # blpop yields (key-as-bytes, value), matching Redis.
        assert result == (k.encode(), b"item1")

    def test_blpop_timeout(self, pg_cache: PostgresCacheBackend) -> None:
        result = pg_cache.blpop([f"empty_{uuid4().hex[:8]}"], timeout=1)
        assert result is None

    def test_fifo_order(self, pg_cache: PostgresCacheBackend) -> None:
        k = f"fifo_{uuid4().hex[:8]}"
        pg_cache.rpush(k, b"first")
        # Small sleep keeps insertion timestamps distinct — presumably the
        # backend orders by timestamp; TODO confirm against implementation.
        time.sleep(0.01)
        pg_cache.rpush(k, b"second")
        r1 = pg_cache.blpop([k], timeout=1)
        r2 = pg_cache.blpop([k], timeout=1)
        assert r1 is not None and r1[1] == b"first"
        assert r2 is not None and r2[1] == b"second"

    def test_multiple_keys(self, pg_cache: PostgresCacheBackend) -> None:
        # With several keys, blpop returns the pair from whichever key has
        # data (here only k2).
        k1 = f"mk1_{uuid4().hex[:8]}"
        k2 = f"mk2_{uuid4().hex[:8]}"
        pg_cache.rpush(k2, b"from_k2")
        result = pg_cache.blpop([k1, k2], timeout=1)
        assert result is not None
        assert result == (k2.encode(), b"from_k2")
# ------------------------------------------------------------------
# Cleanup
# ------------------------------------------------------------------
class TestCleanup:
    """cleanup_expired_cache_entries() removes only rows past their expiry."""

    def test_removes_expired_rows(self, pg_cache: PostgresCacheBackend) -> None:
        # Local import keeps the DB-session dependency out of module scope.
        from onyx.db.engine.sql_engine import get_session_with_current_tenant

        k = _key()
        pg_cache.set(k, b"stale", ex=1)
        time.sleep(1.5)
        cleanup_expired_cache_entries()
        # Query the table directly: the row must be physically gone, not
        # merely hidden by the backend's expiry filtering on get().
        stmt = select(CacheStore.key).where(CacheStore.key == k)
        with get_session_with_current_tenant() as session:
            row = session.execute(stmt).first()
        assert row is None, "expired row should be physically deleted"

    def test_preserves_unexpired_rows(self, pg_cache: PostgresCacheBackend) -> None:
        k = _key()
        pg_cache.set(k, b"fresh", ex=300)
        cleanup_expired_cache_entries()
        assert pg_cache.get(k) == b"fresh"

    def test_preserves_no_ttl_rows(self, pg_cache: PostgresCacheBackend) -> None:
        k = _key()
        pg_cache.set(k, b"permanent")
        cleanup_expired_cache_entries()
        assert pg_cache.get(k) == b"permanent"

View File

@@ -386,6 +386,261 @@ def test_delete_llm_provider(
assert provider_data is None
def test_delete_default_llm_provider_rejected(reset: None) -> None:  # noqa: ARG001
    """Deleting the default LLM provider should return 400."""
    admin_user = UserManager.create(name="admin_user")

    # Create a provider
    response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        headers=admin_user.headers,
        json={
            "name": "test-provider-default-delete",
            "provider": LlmProviderNames.OPENAI,
            # Dummy key in the expected sk- format; no real API call is made.
            "api_key": "sk-000000000000000000000000000000000000000000000000",
            "model_configurations": [
                ModelConfigurationUpsertRequest(
                    name="gpt-4o-mini", is_visible=True
                ).model_dump()
            ],
            "is_public": True,
            "groups": [],
        },
    )
    assert response.status_code == 200
    created_provider = response.json()

    # Set this provider as the default
    set_default_response = requests.post(
        f"{API_SERVER_URL}/admin/llm/default",
        headers=admin_user.headers,
        json={
            "provider_id": created_provider["id"],
            "model_name": "gpt-4o-mini",
        },
    )
    assert set_default_response.status_code == 200

    # Attempt to delete the default provider — should be rejected
    delete_response = requests.delete(
        f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
        headers=admin_user.headers,
    )
    assert delete_response.status_code == 400
    assert "Cannot delete the default LLM provider" in delete_response.json()["detail"]

    # Verify provider still exists
    provider_data = _get_provider_by_id(admin_user, created_provider["id"])
    assert provider_data is not None
def test_delete_non_default_llm_provider_with_default_set(
    reset: None,  # noqa: ARG001
) -> None:
    """Deleting a non-default provider should succeed even when a default is set."""
    admin_user = UserManager.create(name="admin_user")

    # Create two providers: one to become the default, one to delete.
    response_default = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        headers=admin_user.headers,
        json={
            "name": "default-provider",
            "provider": LlmProviderNames.OPENAI,
            "api_key": "sk-000000000000000000000000000000000000000000000000",
            "model_configurations": [
                ModelConfigurationUpsertRequest(
                    name="gpt-4o-mini", is_visible=True
                ).model_dump()
            ],
            "is_public": True,
            "groups": [],
        },
    )
    assert response_default.status_code == 200
    default_provider = response_default.json()

    response_other = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        headers=admin_user.headers,
        json={
            "name": "other-provider",
            "provider": LlmProviderNames.OPENAI,
            "api_key": "sk-000000000000000000000000000000000000000000000000",
            "model_configurations": [
                ModelConfigurationUpsertRequest(
                    name="gpt-4o", is_visible=True
                ).model_dump()
            ],
            "is_public": True,
            "groups": [],
        },
    )
    assert response_other.status_code == 200
    other_provider = response_other.json()

    # Set the first provider as default
    set_default_response = requests.post(
        f"{API_SERVER_URL}/admin/llm/default",
        headers=admin_user.headers,
        json={
            "provider_id": default_provider["id"],
            "model_name": "gpt-4o-mini",
        },
    )
    assert set_default_response.status_code == 200

    # Delete the non-default provider — should succeed
    delete_response = requests.delete(
        f"{API_SERVER_URL}/admin/llm/provider/{other_provider['id']}",
        headers=admin_user.headers,
    )
    assert delete_response.status_code == 200

    # Verify the non-default provider is gone
    provider_data = _get_provider_by_id(admin_user, other_provider["id"])
    assert provider_data is None

    # Verify the default provider still exists
    default_data = _get_provider_by_id(admin_user, default_provider["id"])
    assert default_data is not None
def test_force_delete_default_llm_provider(
    reset: None,  # noqa: ARG001
) -> None:
    """Force-deleting the default LLM provider should succeed."""
    admin_user = UserManager.create(name="admin_user")

    # Create a provider
    response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        headers=admin_user.headers,
        json={
            "name": "test-provider-force-delete",
            "provider": LlmProviderNames.OPENAI,
            "api_key": "sk-000000000000000000000000000000000000000000000000",
            "model_configurations": [
                ModelConfigurationUpsertRequest(
                    name="gpt-4o-mini", is_visible=True
                ).model_dump()
            ],
            "is_public": True,
            "groups": [],
        },
    )
    assert response.status_code == 200
    created_provider = response.json()

    # Set this provider as the default
    set_default_response = requests.post(
        f"{API_SERVER_URL}/admin/llm/default",
        headers=admin_user.headers,
        json={
            "provider_id": created_provider["id"],
            "model_name": "gpt-4o-mini",
        },
    )
    assert set_default_response.status_code == 200

    # Attempt to delete without force — should be rejected
    delete_response = requests.delete(
        f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
        headers=admin_user.headers,
    )
    assert delete_response.status_code == 400

    # Force delete — the ?force=true query param overrides the default guard
    force_delete_response = requests.delete(
        f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}?force=true",
        headers=admin_user.headers,
    )
    assert force_delete_response.status_code == 200

    # Verify provider is gone
    provider_data = _get_provider_by_id(admin_user, created_provider["id"])
    assert provider_data is None
def test_delete_default_vision_provider_clears_vision_default(
    reset: None,  # noqa: ARG001
) -> None:
    """Deleting the default vision provider should succeed and clear the vision default."""
    admin_user = UserManager.create(name="admin_user")

    # Create a text provider and set it as default (so we have a default text provider)
    text_response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        headers=admin_user.headers,
        json={
            "name": "text-provider",
            "provider": LlmProviderNames.OPENAI,
            "api_key": "sk-000000000000000000000000000000000000000000000001",
            "model_configurations": [
                ModelConfigurationUpsertRequest(
                    name="gpt-4o-mini", is_visible=True
                ).model_dump()
            ],
            "is_public": True,
            "groups": [],
        },
    )
    assert text_response.status_code == 200
    text_provider = text_response.json()
    _set_default_provider(admin_user, text_provider["id"], "gpt-4o-mini")

    # Create a vision provider and set it as default vision
    vision_response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        headers=admin_user.headers,
        json={
            "name": "vision-provider",
            "provider": LlmProviderNames.OPENAI,
            "api_key": "sk-000000000000000000000000000000000000000000000002",
            "model_configurations": [
                ModelConfigurationUpsertRequest(
                    name="gpt-4o",
                    is_visible=True,
                    # Image support is what qualifies this model as a vision default.
                    supports_image_input=True,
                ).model_dump()
            ],
            "is_public": True,
            "groups": [],
        },
    )
    assert vision_response.status_code == 200
    vision_provider = vision_response.json()
    _set_default_vision_provider(admin_user, vision_provider["id"], "gpt-4o")

    # Verify vision default is set
    data = _get_providers_admin(admin_user)
    assert data is not None
    _, _, vision_default = _unpack_data(data)
    assert vision_default is not None
    assert vision_default["provider_id"] == vision_provider["id"]

    # Delete the vision provider — should succeed (only text default is protected)
    delete_response = requests.delete(
        f"{API_SERVER_URL}/admin/llm/provider/{vision_provider['id']}",
        headers=admin_user.headers,
    )
    assert delete_response.status_code == 200

    # Verify the vision provider is gone
    provider_data = _get_provider_by_id(admin_user, vision_provider["id"])
    assert provider_data is None

    # Verify there is no default vision provider
    data = _get_providers_admin(admin_user)
    assert data is not None
    _, text_default, vision_default = _unpack_data(data)
    assert vision_default is None

    # Verify the text default is still intact
    assert text_default is not None
    assert text_default["provider_id"] == text_provider["id"]
def test_duplicate_provider_name_rejected(reset: None) -> None: # noqa: ARG001
"""Creating a provider with a name that already exists should return 400."""
admin_user = UserManager.create(name="admin_user")

View File

@@ -0,0 +1,166 @@
"""Unit tests for stop_signal_checker and chat_processing_checker.
These modules are safety-critical — they control whether a chat stream
continues or stops. The tests use a simple in-memory CacheBackend stub
so no external services are needed.
"""
from uuid import uuid4
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheLock
from onyx.chat.chat_processing_checker import is_chat_session_processing
from onyx.chat.chat_processing_checker import set_processing_status
from onyx.chat.stop_signal_checker import FENCE_TTL
from onyx.chat.stop_signal_checker import is_connected
from onyx.chat.stop_signal_checker import reset_cancel_status
from onyx.chat.stop_signal_checker import set_fence
class _MemoryCacheBackend(CacheBackend):
    """Dict-backed CacheBackend stub for unit tests.

    Only the KV surface is implemented; locks and list operations raise
    ``NotImplementedError`` because the code under test never calls them.
    TTLs are accepted but ignored.
    """

    def __init__(self) -> None:
        self._store: dict[str, bytes] = {}

    def get(self, key: str) -> bytes | None:
        return self._store.get(key)

    def set(
        self,
        key: str,
        value: str | bytes | int | float,
        ex: int | None = None,  # noqa: ARG002
    ) -> None:
        # Normalize to bytes up front, mirroring real backends.
        encoded = value if isinstance(value, bytes) else str(value).encode()
        self._store[key] = encoded

    def delete(self, key: str) -> None:
        self._store.pop(key, None)

    def exists(self, key: str) -> bool:
        return key in self._store

    def expire(self, key: str, seconds: int) -> None:
        pass

    def ttl(self, key: str) -> int:
        # Redis-style sentinels: -2 for a missing key, -1 for no expiry.
        if key in self._store:
            return -1
        return -2

    def lock(self, name: str, timeout: float | None = None) -> CacheLock:
        raise NotImplementedError

    def rpush(self, key: str, value: str | bytes) -> None:
        raise NotImplementedError

    def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
        raise NotImplementedError
# ── stop_signal_checker ──────────────────────────────────────────────
class TestSetFence:
    """set_fence() creates/removes the per-session stop fence key."""

    def test_set_fence_true_creates_key(self) -> None:
        cache = _MemoryCacheBackend()
        sid = uuid4()
        set_fence(sid, cache, True)
        # A present fence means the client is considered disconnected.
        assert not is_connected(sid, cache)

    def test_set_fence_false_removes_key(self) -> None:
        cache = _MemoryCacheBackend()
        sid = uuid4()
        set_fence(sid, cache, True)
        set_fence(sid, cache, False)
        assert is_connected(sid, cache)

    def test_set_fence_false_noop_when_absent(self) -> None:
        # Clearing a fence that was never set must not raise or create state.
        cache = _MemoryCacheBackend()
        sid = uuid4()
        set_fence(sid, cache, False)
        assert is_connected(sid, cache)

    def test_set_fence_uses_ttl(self) -> None:
        """Verify set_fence passes ex=FENCE_TTL to cache.set."""
        calls: list[dict[str, object]] = []
        cache = _MemoryCacheBackend()
        original_set = cache.set

        def tracking_set(
            key: str,
            value: str | bytes | int | float,
            ex: int | None = None,
        ) -> None:
            # Record the TTL argument, then delegate to the real stub.
            calls.append({"key": key, "ex": ex})
            original_set(key, value, ex=ex)

        cache.set = tracking_set  # type: ignore[method-assign]
        set_fence(uuid4(), cache, True)
        assert len(calls) == 1
        assert calls[0]["ex"] == FENCE_TTL
class TestIsConnected:
    """is_connected() reads the fence: absent key => connected."""

    def test_connected_when_no_fence(self) -> None:
        cache = _MemoryCacheBackend()
        assert is_connected(uuid4(), cache)

    def test_disconnected_when_fence_set(self) -> None:
        cache = _MemoryCacheBackend()
        sid = uuid4()
        set_fence(sid, cache, True)
        assert not is_connected(sid, cache)

    def test_sessions_are_isolated(self) -> None:
        # Fencing one session must not affect another.
        cache = _MemoryCacheBackend()
        sid1, sid2 = uuid4(), uuid4()
        set_fence(sid1, cache, True)
        assert not is_connected(sid1, cache)
        assert is_connected(sid2, cache)
class TestResetCancelStatus:
    """reset_cancel_status() clears any fence for the session."""

    def test_clears_fence(self) -> None:
        cache = _MemoryCacheBackend()
        sid = uuid4()
        set_fence(sid, cache, True)
        reset_cancel_status(sid, cache)
        assert is_connected(sid, cache)

    def test_noop_when_no_fence(self) -> None:
        # Resetting a session with no fence must not raise.
        cache = _MemoryCacheBackend()
        reset_cancel_status(uuid4(), cache)
# ── chat_processing_checker ──────────────────────────────────────────
class TestSetProcessingStatus:
    """set_processing_status() toggles the per-session processing flag."""

    def test_set_true_marks_processing(self) -> None:
        cache = _MemoryCacheBackend()
        sid = uuid4()
        set_processing_status(sid, cache, True)
        assert is_chat_session_processing(sid, cache)

    def test_set_false_clears_processing(self) -> None:
        cache = _MemoryCacheBackend()
        sid = uuid4()
        set_processing_status(sid, cache, True)
        set_processing_status(sid, cache, False)
        assert not is_chat_session_processing(sid, cache)
class TestIsChatSessionProcessing:
    """is_chat_session_processing() defaults to False and is per-session."""

    def test_not_processing_by_default(self) -> None:
        cache = _MemoryCacheBackend()
        assert not is_chat_session_processing(uuid4(), cache)

    def test_sessions_are_isolated(self) -> None:
        # Marking one session must not leak into another.
        cache = _MemoryCacheBackend()
        sid1, sid2 = uuid4(), uuid4()
        set_processing_status(sid1, cache, True)
        assert is_chat_session_processing(sid1, cache)
        assert not is_chat_session_processing(sid2, cache)

View File

@@ -0,0 +1,163 @@
"""Unit tests for federated OAuth state generation and verification.
Uses unittest.mock to patch get_cache_backend so no external services
are needed. Verifies the generate -> verify round-trip, one-time-use
semantics, TTL propagation, and error handling.
"""
from unittest.mock import patch
import pytest
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheLock
from onyx.federated_connectors.oauth_utils import generate_oauth_state
from onyx.federated_connectors.oauth_utils import OAUTH_STATE_TTL
from onyx.federated_connectors.oauth_utils import OAuthSession
from onyx.federated_connectors.oauth_utils import verify_oauth_state
class _MemoryCacheBackend(CacheBackend):
    """Dict-backed CacheBackend stub that also records every set() call.

    ``set_calls`` collects ``{"key": ..., "ex": ...}`` dicts so tests can
    assert on TTL propagation. Locks and list operations are unimplemented
    because the OAuth helpers never use them.
    """

    def __init__(self) -> None:
        self._store: dict[str, bytes] = {}
        self.set_calls: list[dict[str, object]] = []

    def get(self, key: str) -> bytes | None:
        return self._store.get(key)

    def set(
        self,
        key: str,
        value: str | bytes | int | float,
        ex: int | None = None,
    ) -> None:
        # Record the call (for TTL assertions) before storing.
        self.set_calls.append({"key": key, "ex": ex})
        encoded = value if isinstance(value, bytes) else str(value).encode()
        self._store[key] = encoded

    def delete(self, key: str) -> None:
        self._store.pop(key, None)

    def exists(self, key: str) -> bool:
        return key in self._store

    def expire(self, key: str, seconds: int) -> None:
        pass

    def ttl(self, key: str) -> int:
        # Redis-style sentinels: -2 for a missing key, -1 for no expiry.
        if key in self._store:
            return -1
        return -2

    def lock(self, name: str, timeout: float | None = None) -> CacheLock:
        raise NotImplementedError

    def rpush(self, key: str, value: str | bytes) -> None:
        raise NotImplementedError

    def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
        raise NotImplementedError
def _patched(cache: _MemoryCacheBackend):  # type: ignore[no-untyped-def]
    """Patch get_cache_backend in oauth_utils to return the given stub."""
    target = "onyx.federated_connectors.oauth_utils.get_cache_backend"
    return patch(target, return_value=cache)
class TestGenerateAndVerifyRoundTrip:
    """generate_oauth_state() output must be verifiable and carry all fields."""

    def test_round_trip_basic(self) -> None:
        cache = _MemoryCacheBackend()
        with _patched(cache):
            state = generate_oauth_state(
                federated_connector_id=42,
                user_id="user-abc",
            )
            session = verify_oauth_state(state)
        assert session.federated_connector_id == 42
        assert session.user_id == "user-abc"
        # Optional fields default to None / empty dict when omitted.
        assert session.redirect_uri is None
        assert session.additional_data == {}

    def test_round_trip_with_all_fields(self) -> None:
        cache = _MemoryCacheBackend()
        with _patched(cache):
            state = generate_oauth_state(
                federated_connector_id=7,
                user_id="user-xyz",
                redirect_uri="https://example.com/callback",
                additional_data={"scope": "read"},
            )
            session = verify_oauth_state(state)
        assert session.federated_connector_id == 7
        assert session.user_id == "user-xyz"
        assert session.redirect_uri == "https://example.com/callback"
        assert session.additional_data == {"scope": "read"}
class TestOneTimeUse:
    """A state token is consumed by verification — replay must fail."""

    def test_verify_deletes_state(self) -> None:
        cache = _MemoryCacheBackend()
        with _patched(cache):
            state = generate_oauth_state(federated_connector_id=1, user_id="u")
            verify_oauth_state(state)
            # Second verification of the same state must raise.
            with pytest.raises(ValueError, match="OAuth state not found"):
                verify_oauth_state(state)
class TestTTLPropagation:
    """The TTL passed to generate_oauth_state must reach cache.set as ex=."""

    def test_default_ttl(self) -> None:
        cache = _MemoryCacheBackend()
        with _patched(cache):
            generate_oauth_state(federated_connector_id=1, user_id="u")
        assert len(cache.set_calls) == 1
        assert cache.set_calls[0]["ex"] == OAUTH_STATE_TTL

    def test_custom_ttl(self) -> None:
        cache = _MemoryCacheBackend()
        with _patched(cache):
            generate_oauth_state(federated_connector_id=1, user_id="u", ttl=600)
        assert cache.set_calls[0]["ex"] == 600
class TestVerifyInvalidState:
    """Verification of an expired/unknown state raises a ValueError."""

    def test_missing_state_raises(self) -> None:
        cache = _MemoryCacheBackend()
        with _patched(cache):
            state = generate_oauth_state(federated_connector_id=1, user_id="u")
            # Manually clear the cache to simulate expiration
            cache._store.clear()
            with pytest.raises(ValueError, match="OAuth state not found"):
                verify_oauth_state(state)
class TestOAuthSessionSerialization:
    """OAuthSession.to_dict()/from_dict() must round-trip and apply defaults."""

    def test_to_dict_from_dict_round_trip(self) -> None:
        session = OAuthSession(
            federated_connector_id=5,
            user_id="u-123",
            redirect_uri="https://redir.example.com",
            additional_data={"key": "val"},
        )
        d = session.to_dict()
        restored = OAuthSession.from_dict(d)
        assert restored.federated_connector_id == 5
        assert restored.user_id == "u-123"
        assert restored.redirect_uri == "https://redir.example.com"
        assert restored.additional_data == {"key": "val"}

    def test_from_dict_defaults(self) -> None:
        # Only the two required keys: optional fields must get defaults.
        minimal = {"federated_connector_id": 1, "user_id": "u"}
        session = OAuthSession.from_dict(minimal)
        assert session.redirect_uri is None
        assert session.additional_data == {}

View File

@@ -117,7 +117,10 @@ class TestOktaProvider:
user = _make_mock_user(personal_name=None)
result = provider.build_user_resource(user, None)
assert result.name == ScimName(givenName="", familyName="", formatted="")
# Falls back to deriving name from email local part
assert result.name == ScimName(
givenName="test", familyName="", formatted="test"
)
assert result.displayName is None
def test_build_user_resource_scim_username_preserves_case(self) -> None:

View File

@@ -215,7 +215,7 @@ class TestCreateUser:
mock_dal.commit.assert_called_once()
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_missing_external_id_creates_user_without_mapping(
def test_missing_external_id_still_creates_mapping(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
@@ -223,6 +223,7 @@ class TestCreateUser:
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
"""Mapping is always created to mark user as SCIM-managed."""
mock_dal.get_user_by_email.return_value = None
resource = make_scim_user(externalId=None)
@@ -236,11 +237,11 @@ class TestCreateUser:
parsed = parse_scim_user(result, status=201)
assert parsed.userName is not None
mock_dal.add_user.assert_called_once()
mock_dal.create_user_mapping.assert_not_called()
mock_dal.create_user_mapping.assert_called_once()
mock_dal.commit.assert_called_once()
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_duplicate_email_returns_409(
def test_duplicate_scim_managed_email_returns_409(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
@@ -248,7 +249,12 @@ class TestCreateUser:
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
mock_dal.get_user_by_email.return_value = make_db_user()
"""409 only when the existing user already has a SCIM mapping."""
existing = make_db_user()
mock_dal.get_user_by_email.return_value = existing
mock_dal.get_user_mapping_by_user_id.return_value = make_user_mapping(
user_id=existing.id
)
resource = make_scim_user()
result = create_user(
@@ -260,6 +266,40 @@ class TestCreateUser:
assert_scim_error(result, 409)
    @patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
    def test_existing_user_without_mapping_gets_linked(
        self,
        mock_seats: MagicMock,  # noqa: ARG002
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        provider: ScimProvider,
    ) -> None:
        """Pre-existing user without SCIM mapping gets adopted (linked)."""
        existing = make_db_user(email="admin@example.com", personal_name=None)
        mock_dal.get_user_by_email.return_value = existing
        # No mapping yet => this user is not SCIM-managed and can be adopted.
        mock_dal.get_user_mapping_by_user_id.return_value = None
        resource = make_scim_user(userName="admin@example.com", externalId="ext-admin")
        result = create_user(
            user_resource=resource,
            _token=mock_token,
            provider=provider,
            db_session=mock_db_session,
        )
        parsed = parse_scim_user(result, status=201)
        assert parsed.userName == "admin@example.com"
        # Should NOT create a new user — reuse existing
        mock_dal.add_user.assert_not_called()
        # Should sync is_active and personal_name from the SCIM request
        # ("Test User" presumably comes from make_scim_user's default name —
        # confirm against the helper if this assertion drifts).
        mock_dal.update_user.assert_called_once_with(
            existing, is_active=True, personal_name="Test User"
        )
        # Should create a SCIM mapping for the existing user
        mock_dal.create_user_mapping.assert_called_once()
        mock_dal.commit.assert_called_once()
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_integrity_error_returns_409(
self,

View File

@@ -1,15 +1,17 @@
"use client";
import { SvgMcp } from "@opal/icons";
import MCPPageContent from "@/sections/actions/MCPPageContent";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.MCP_ACTIONS]!;
export default function Main() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgMcp}
title="MCP Actions"
icon={route.icon}
title={route.title}
description="Connect MCP (Model Context Protocol) servers to add custom actions and tools for your agents."
separator
/>

View File

@@ -1,15 +1,17 @@
"use client";
import { SvgActions } from "@opal/icons";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import OpenApiPageContent from "@/sections/actions/OpenApiPageContent";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.OPENAPI_ACTIONS]!;
export default function Main() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgActions}
title="OpenAPI Actions"
icon={route.icon}
title={route.title}
description="Connect OpenAPI servers to add custom actions and tools for your agents."
separator
/>

View File

@@ -1,5 +1,5 @@
"use client";
import { AdminPageTitle } from "@/components/admin/Title";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { SourceCategory, SourceMetadata } from "@/lib/search/interfaces";
import { listSourceMetadata } from "@/lib/sources";
import Button from "@/refresh-components/buttons/Button";
@@ -32,7 +32,7 @@ import { SettingsContext } from "@/providers/SettingsProvider";
import SourceTile from "@/components/SourceTile";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import Text from "@/refresh-components/texts/Text";
import { SvgUploadCloud } from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
function SourceTileTooltipWrapper({
sourceMetadata,
preSelect,
@@ -124,6 +124,7 @@ function SourceTileTooltipWrapper({
}
export default function Page() {
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.ADD_CONNECTOR]!;
const sources = useMemo(() => listSourceMetadata(), []);
const [rawSearchTerm, setSearchTerm] = useState("");
@@ -248,61 +249,37 @@ export default function Page() {
};
return (
<>
<AdminPageTitle
icon={SvgUploadCloud}
title="Add Connector"
farRightElement={
<SettingsLayouts.Root width="full">
<SettingsLayouts.Header
icon={route.icon}
title={route.title}
rightChildren={
<Button href="/admin/indexing/status" primary>
See Connectors
</Button>
}
separator
/>
<SettingsLayouts.Body>
<InputTypeIn
type="text"
placeholder="Search Connectors"
ref={searchInputRef}
value={rawSearchTerm} // keep the input bound to immediate state
onChange={(event) => setSearchTerm(event.target.value)}
onKeyDown={handleKeyPress}
className="w-96 flex-none"
/>
<InputTypeIn
type="text"
placeholder="Search Connectors"
ref={searchInputRef}
value={rawSearchTerm} // keep the input bound to immediate state
onChange={(event) => setSearchTerm(event.target.value)}
onKeyDown={handleKeyPress}
className="w-96 flex-none"
/>
{dedupedPopular.length > 0 && (
<div className="pt-8">
<Text as="p" headingH3>
Popular
</Text>
<div className="flex flex-wrap gap-4 p-4">
{dedupedPopular.map((source) => (
<SourceTileTooltipWrapper
preSelect={false}
key={source.internalName}
sourceMetadata={source}
federatedConnectors={federatedConnectors}
slackCredentials={slackCredentials}
/>
))}
</div>
</div>
)}
{Object.entries(categorizedSources)
.filter(([_, sources]) => sources.length > 0)
.map(([category, sources], categoryInd) => (
<div key={category} className="pt-8">
{dedupedPopular.length > 0 && (
<div className="pt-8">
<Text as="p" headingH3>
{category}
Popular
</Text>
<div className="flex flex-wrap gap-4 p-4">
{sources.map((source, sourceInd) => (
{dedupedPopular.map((source) => (
<SourceTileTooltipWrapper
preSelect={
(searchTerm?.length ?? 0) > 0 &&
categoryInd == 0 &&
sourceInd == 0
}
preSelect={false}
key={source.internalName}
sourceMetadata={source}
federatedConnectors={federatedConnectors}
@@ -311,7 +288,33 @@ export default function Page() {
))}
</div>
</div>
))}
</>
)}
{Object.entries(categorizedSources)
.filter(([_, sources]) => sources.length > 0)
.map(([category, sources], categoryInd) => (
<div key={category} className="pt-8">
<Text as="p" headingH3>
{category}
</Text>
<div className="flex flex-wrap gap-4 p-4">
{sources.map((source, sourceInd) => (
<SourceTileTooltipWrapper
preSelect={
(searchTerm?.length ?? 0) > 0 &&
categoryInd == 0 &&
sourceInd == 0
}
key={source.internalName}
sourceMetadata={source}
federatedConnectors={federatedConnectors}
slackCredentials={slackCredentials}
/>
))}
</div>
</div>
))}
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -4,14 +4,14 @@ import { PersonasTable } from "./PersonaTable";
import Text from "@/components/ui/text";
import Title from "@/components/ui/title";
import Separator from "@/refresh-components/Separator";
import { AdminPageTitle } from "@/components/admin/Title";
import { SubLabel } from "@/components/Field";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { useAdminPersonas } from "@/hooks/useAdminPersonas";
import { Persona } from "./interfaces";
import { ThreeDotsLoader } from "@/components/Loading";
import { ErrorCallout } from "@/components/ErrorCallout";
import { SvgOnyxOctagon } from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import { useState, useEffect } from "react";
import Pagination from "@/refresh-components/Pagination";
@@ -120,6 +120,7 @@ function MainContent({
}
export default function Page() {
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.AGENTS]!;
const [currentPage, setCurrentPage] = useState(1);
const { personas, totalItems, isLoading, error, refresh } = useAdminPersonas({
pageNum: currentPage - 1, // Backend uses 0-indexed pages
@@ -127,31 +128,33 @@ export default function Page() {
});
return (
<>
<AdminPageTitle icon={SvgOnyxOctagon} title="Agents" />
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
{isLoading && <ThreeDotsLoader />}
<SettingsLayouts.Body>
{isLoading && <ThreeDotsLoader />}
{error && (
<ErrorCallout
errorTitle="Failed to load agents"
errorMsg={
error?.info?.message ||
error?.info?.detail ||
"An unknown error occurred"
}
/>
)}
{error && (
<ErrorCallout
errorTitle="Failed to load agents"
errorMsg={
error?.info?.message ||
error?.info?.detail ||
"An unknown error occurred"
}
/>
)}
{!isLoading && !error && (
<MainContent
personas={personas}
totalItems={totalItems}
currentPage={currentPage}
onPageChange={setCurrentPage}
refreshPersonas={refresh}
/>
)}
</>
{!isLoading && !error && (
<MainContent
personas={personas}
totalItems={totalItems}
currentPage={currentPage}
onPageChange={setCurrentPage}
refreshPersonas={refresh}
/>
)}
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,8 +1,8 @@
"use client";
import { ThreeDotsLoader } from "@/components/Loading";
import { AdminPageTitle } from "@/components/admin/Title";
import { errorHandlingFetcher } from "@/lib/fetcher";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ErrorCallout } from "@/components/ErrorCallout";
import useSWR, { mutate } from "swr";
import Separator from "@/refresh-components/Separator";
@@ -32,6 +32,9 @@ import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
import Text from "@/refresh-components/texts/Text";
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.API_KEYS]!;
function Main() {
const {
@@ -233,10 +236,11 @@ function Main() {
export default function Page() {
return (
<>
<AdminPageTitle title="API Keys" icon={SvgKey} />
<Main />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header title={route.title} icon={route.icon} separator />
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -4,9 +4,8 @@ import CardSection from "@/components/admin/CardSection";
import { useRouter } from "next/navigation";
import { useState } from "react";
import { SlackTokensForm } from "./SlackTokensForm";
import { SourceIcon } from "@/components/SourceIcon";
import { AdminPageTitle } from "@/components/admin/Title";
import { ValidSources } from "@/lib/types";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { SvgSlack } from "@opal/icons";
export const NewSlackBotForm = () => {
const [formValues] = useState({
@@ -19,20 +18,19 @@ export const NewSlackBotForm = () => {
const router = useRouter();
return (
<div>
<AdminPageTitle
icon={<SourceIcon iconSize={36} sourceType={ValidSources.Slack} />}
title="New Slack Bot"
/>
<CardSection>
<div className="p-4">
<SlackTokensForm
isUpdate={false}
initialValues={formValues}
router={router}
/>
</div>
</CardSection>
</div>
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={SvgSlack} title="New Slack Bot" separator />
<SettingsLayouts.Body>
<CardSection>
<div className="p-4">
<SlackTokensForm
isUpdate={false}
initialValues={formValues}
router={router}
/>
</div>
</CardSection>
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
};

View File

@@ -1,15 +1,10 @@
import { AdminPageTitle } from "@/components/admin/Title";
import { SourceIcon } from "@/components/SourceIcon";
import { SlackChannelConfigCreationForm } from "../SlackChannelConfigCreationForm";
import { fetchSS } from "@/lib/utilsSS";
import { ErrorCallout } from "@/components/ErrorCallout";
import {
DocumentSetSummary,
SlackChannelConfig,
ValidSources,
} from "@/lib/types";
import BackButton from "@/refresh-components/buttons/BackButton";
import { DocumentSetSummary, SlackChannelConfig } from "@/lib/types";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { SvgSlack } from "@opal/icons";
import { FetchAgentsResponse, fetchAgentsSS } from "@/lib/agentsSS";
import { getStandardAnswerCategoriesIfEE } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
@@ -77,27 +72,28 @@ async function EditslackChannelConfigPage(props: {
}
return (
<div className="max-w-4xl container">
<SettingsLayouts.Root>
<InstantSSRAutoRefresh />
<BackButton />
<AdminPageTitle
icon={<SourceIcon sourceType={ValidSources.Slack} iconSize={32} />}
<SettingsLayouts.Header
icon={SvgSlack}
title={
slackChannelConfig.is_default
? "Edit Default Slack Config"
: "Edit Slack Channel Config"
}
separator
backButton
/>
<SlackChannelConfigCreationForm
slack_bot_id={slackChannelConfig.slack_bot_id}
documentSets={documentSets}
personas={assistants}
standardAnswerCategoryResponse={eeStandardAnswerCategoryResponse}
existingSlackChannelConfig={slackChannelConfig}
/>
</div>
<SettingsLayouts.Body>
<SlackChannelConfigCreationForm
slack_bot_id={slackChannelConfig.slack_bot_id}
documentSets={documentSets}
personas={assistants}
standardAnswerCategoryResponse={eeStandardAnswerCategoryResponse}
existingSlackChannelConfig={slackChannelConfig}
/>
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,13 +1,12 @@
import { AdminPageTitle } from "@/components/admin/Title";
import { SlackChannelConfigCreationForm } from "../SlackChannelConfigCreationForm";
import { fetchSS } from "@/lib/utilsSS";
import { ErrorCallout } from "@/components/ErrorCallout";
import { DocumentSetSummary, ValidSources } from "@/lib/types";
import BackButton from "@/refresh-components/buttons/BackButton";
import { DocumentSetSummary } from "@/lib/types";
import { fetchAgentsSS } from "@/lib/agentsSS";
import { getStandardAnswerCategoriesIfEE } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
import { redirect } from "next/navigation";
import { SourceIcon } from "@/components/SourceIcon";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { SvgSlack } from "@opal/icons";
async function NewChannelConfigPage(props: {
params: Promise<{ "bot-id": string }>;
@@ -50,20 +49,22 @@ async function NewChannelConfigPage(props: {
}
return (
<>
<BackButton />
<AdminPageTitle
icon={<SourceIcon iconSize={32} sourceType={ValidSources.Slack} />}
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgSlack}
title="Configure OnyxBot for Slack Channel"
separator
backButton
/>
<SlackChannelConfigCreationForm
slack_bot_id={slack_bot_id}
documentSets={documentSets}
personas={agentsResponse[0]}
standardAnswerCategoryResponse={standardAnswerCategoryResponse}
/>
</>
<SettingsLayouts.Body>
<SlackChannelConfigCreationForm
slack_bot_id={slack_bot_id}
documentSets={documentSets}
personas={agentsResponse[0]}
standardAnswerCategoryResponse={standardAnswerCategoryResponse}
/>
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -3,15 +3,14 @@
import { ErrorCallout } from "@/components/ErrorCallout";
import { ThreeDotsLoader } from "@/components/Loading";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import { AdminPageTitle } from "@/components/admin/Title";
import { SourceIcon } from "@/components/SourceIcon";
import { SlackBotTable } from "./SlackBotTable";
import { useSlackBots } from "./[bot-id]/hooks";
import { ValidSources } from "@/lib/types";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { DOCS_ADMINS_PATH } from "@/lib/constants";
const Main = () => {
function Main() {
const {
data: slackBots,
isLoading: isSlackBotsLoading,
@@ -73,20 +72,18 @@ const Main = () => {
<SlackBotTable slackBots={slackBots} />
</div>
);
};
}
export default function Page() {
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.SLACK_BOTS]!;
const Page = () => {
return (
<>
<AdminPageTitle
icon={<SourceIcon iconSize={36} sourceType={ValidSources.Slack} />}
title="Slack Bots"
/>
<InstantSSRAutoRefresh />
<Main />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
<InstantSSRAutoRefresh />
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
};
export default Page;
}

View File

@@ -4,13 +4,15 @@ import { useState } from "react";
import CardSection from "@/components/admin/CardSection";
import Button from "@/refresh-components/buttons/Button";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import { DocumentIcon2 } from "@/components/icons/icons";
import useSWR from "swr";
import { ThreeDotsLoader } from "@/components/Loading";
import { AdminPageTitle } from "@/components/admin/Title";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Text from "@/refresh-components/texts/Text";
import { cn } from "@/lib/utils";
import { SvgLock } from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DOCUMENT_PROCESSING]!;
function Main() {
const {
@@ -149,12 +151,11 @@ function Main() {
export default function Page() {
return (
<>
<AdminPageTitle
title="Document Processing"
icon={<DocumentIcon2 size={32} className="my-auto" />}
/>
<Main />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,15 +1,17 @@
"use client";
import { SvgImage } from "@opal/icons";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import ImageGenerationContent from "./ImageGenerationContent";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.IMAGE_GENERATION]!;
export default function Page() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgImage}
title="Image Generation"
icon={route.icon}
title={route.title}
description="Settings for in-chat image generation."
/>
<SettingsLayouts.Body>

View File

@@ -1,8 +1,8 @@
"use client";
import { ThreeDotsLoader } from "@/components/Loading";
import { AdminPageTitle } from "@/components/admin/Title";
import { errorHandlingFetcher } from "@/lib/fetcher";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Text from "@/components/ui/text";
import Title from "@/components/ui/title";
import Button from "@/refresh-components/buttons/Button";
@@ -19,7 +19,10 @@ import { SettingsContext } from "@/providers/SettingsProvider";
import CardSection from "@/components/admin/CardSection";
import { ErrorCallout } from "@/components/ErrorCallout";
import { useToastFromQuery } from "@/hooks/useToast";
import { SvgSearch } from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.SEARCH_SETTINGS]!;
export interface EmbeddingDetails {
api_key: string;
custom_config: any;
@@ -141,9 +144,11 @@ function Main() {
export default function Page() {
return (
<>
<AdminPageTitle title="Search Settings" icon={SvgSearch} />
<Main />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header title={route.title} icon={route.icon} separator />
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -22,7 +22,10 @@ import {
SvgOnyxLogo,
SvgX,
} from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.WEB_SEARCH]!;
import {
SEARCH_PROVIDERS_URL,
SEARCH_PROVIDER_DETAILS,
@@ -403,8 +406,8 @@ export default function Page() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgGlobe}
title="Web Search"
icon={route.icon}
title={route.title}
description="Search settings for external search across the internet."
separator
/>
@@ -426,8 +429,8 @@ export default function Page() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgGlobe}
title="Web Search"
icon={route.icon}
title={route.title}
description="Search settings for external search across the internet."
separator
/>
@@ -832,8 +835,8 @@ export default function Page() {
<>
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgGlobe}
title="Web Search"
icon={route.icon}
title={route.title}
description="Search settings for external search across the internet."
separator
/>

View File

@@ -1,8 +1,7 @@
"use client";
import { useState, useEffect } from "react";
import { AdminPageTitle } from "@/components/admin/Title";
import { FiDownload } from "react-icons/fi";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ThreeDotsLoader } from "@/components/Loading";
import {
Table,
@@ -17,6 +16,10 @@ import { Card } from "@/components/ui/card";
import Text from "@/components/ui/text";
import { Spinner } from "@/components/Spinner";
import { SvgDownloadCloud } from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DEBUG]!;
function Main() {
const [categories, setCategories] = useState<string[]>([]);
const [isLoading, setIsLoading] = useState(true);
@@ -114,13 +117,13 @@ function Main() {
);
}
const Page = () => {
export default function Page() {
return (
<>
<AdminPageTitle icon={<FiDownload size={32} />} title="Debug Logs" />
<Main />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
};
export default Page;
}

View File

@@ -19,7 +19,7 @@ import {
import { createGuildConfig } from "@/app/admin/discord-bot/lib";
import { DiscordGuildsTable } from "@/app/admin/discord-bot/DiscordGuildsTable";
import { BotConfigCard } from "@/app/admin/discord-bot/BotConfigCard";
import { SvgDiscordMono } from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
function DiscordBotContent() {
const { data: guilds, isLoading, error, refreshGuilds } = useDiscordGuilds();
@@ -118,11 +118,13 @@ function DiscordBotContent() {
}
export default function Page() {
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DISCORD_BOTS]!;
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgDiscordMono}
title="Discord Bots"
icon={route.icon}
title={route.title}
description="Connect Onyx to your Discord servers. Users can ask questions directly in Discord channels."
/>
<SettingsLayouts.Body>

View File

@@ -2,8 +2,11 @@
import { useState } from "react";
import useSWR from "swr";
import { SvgArrowExchange } from "@opal/icons";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.INDEX_MIGRATION]!;
import Card from "@/refresh-components/cards/Card";
import { Content, ContentAction } from "@opal/layouts";
import Text from "@/refresh-components/texts/Text";
@@ -213,8 +216,8 @@ export default function Page() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgArrowExchange}
title="Document Index Migration"
icon={route.icon}
title={route.title}
description="Monitor the migration from Vespa to OpenSearch and control the active retrieval source."
separator
/>

View File

@@ -0,0 +1,35 @@
"use client";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import { Explorer } from "./Explorer";
import { Connector } from "@/lib/connectors/connectors";
import { DocumentSetSummary } from "@/lib/types";
interface DocumentExplorerPageProps {
initialSearchValue: string | undefined;
connectors: Connector<any>[];
documentSets: DocumentSetSummary[];
}
export default function DocumentExplorerPage({
initialSearchValue,
connectors,
documentSets,
}: DocumentExplorerPageProps) {
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DOCUMENT_EXPLORER]!;
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
<Explorer
initialSearchValue={initialSearchValue}
connectors={connectors}
documentSets={documentSets}
/>
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,7 +1,6 @@
import { AdminPageTitle } from "@/components/admin/Title";
import { Explorer } from "./Explorer";
import { fetchValidFilterInfo } from "@/lib/search/utilsSS";
import { SvgZoomIn } from "@opal/icons";
import DocumentExplorerPage from "./DocumentExplorerPage";
export default async function Page(props: {
searchParams: Promise<{ [key: string]: string }>;
}) {
@@ -9,17 +8,10 @@ export default async function Page(props: {
const { connectors, documentSets } = await fetchValidFilterInfo();
return (
<>
<AdminPageTitle
icon={<SvgZoomIn className="stroke-text-04 h-8 w-8" />}
title="Document Explorer"
/>
<Explorer
initialSearchValue={searchParams.query}
connectors={connectors}
documentSets={documentSets}
/>
</>
<DocumentExplorerPage
initialSearchValue={searchParams.query}
connectors={connectors}
documentSets={documentSets}
/>
);
}

View File

@@ -4,10 +4,11 @@ import { LoadingAnimation } from "@/components/Loading";
import { useMostReactedToDocuments } from "@/lib/hooks";
import { DocumentFeedbackTable } from "./DocumentFeedbackTable";
import { numPages, numToDisplay } from "./constants";
import { AdminPageTitle } from "@/components/admin/Title";
import Title from "@/components/ui/title";
import { SvgThumbsUp } from "@opal/icons";
const Main = () => {
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
function Main() {
const {
data: mostLikedDocuments,
isLoading: isMostLikedDocumentsLoading,
@@ -57,16 +58,17 @@ const Main = () => {
/>
</div>
);
};
}
export default function Page() {
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DOCUMENT_FEEDBACK]!;
const Page = () => {
return (
<>
<AdminPageTitle icon={SvgThumbsUp} title="Document Feedback" />
<Main />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
};
export default Page;
}

View File

@@ -5,9 +5,8 @@ import { ErrorCallout } from "@/components/ErrorCallout";
import { refreshDocumentSets, useDocumentSets } from "../hooks";
import { useConnectorStatus, useUserGroups } from "@/lib/hooks";
import { ThreeDotsLoader } from "@/components/Loading";
import { AdminPageTitle } from "@/components/admin/Title";
import { BookmarkIcon } from "@/components/icons/icons";
import BackButton from "@/refresh-components/buttons/BackButton";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import CardSection from "@/components/admin/CardSection";
import { DocumentSetCreationForm } from "../DocumentSetCreationForm";
import { useRouter } from "next/navigation";
@@ -69,24 +68,17 @@ function Main({ documentSetId }: { documentSetId: number }) {
}
return (
<div>
<AdminPageTitle
icon={<BookmarkIcon size={32} />}
title={documentSet.name}
<CardSection>
<DocumentSetCreationForm
ccPairs={ccPairs}
userGroups={userGroups}
onClose={() => {
refreshDocumentSets();
router.push("/admin/documents/sets");
}}
existingDocumentSet={documentSet}
/>
<CardSection>
<DocumentSetCreationForm
ccPairs={ccPairs}
userGroups={userGroups}
onClose={() => {
refreshDocumentSets();
router.push("/admin/documents/sets");
}}
existingDocumentSet={documentSet}
/>
</CardSection>
</div>
</CardSection>
);
}
@@ -95,12 +87,19 @@ export default function Page(props: {
}) {
const params = use(props.params);
const documentSetId = parseInt(params.documentSetId);
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DOCUMENT_SETS]!;
return (
<>
<BackButton />
<Main documentSetId={documentSetId} />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={route.icon}
title="Edit Document Set"
separator
backButton
/>
<SettingsLayouts.Body>
<Main documentSetId={documentSetId} />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,11 +1,10 @@
"use client";
import { AdminPageTitle } from "@/components/admin/Title";
import { BookmarkIcon } from "@/components/icons/icons";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import { DocumentSetCreationForm } from "../DocumentSetCreationForm";
import { useConnectorStatus, useUserGroups } from "@/lib/hooks";
import { ThreeDotsLoader } from "@/components/Loading";
import BackButton from "@/refresh-components/buttons/BackButton";
import { ErrorCallout } from "@/components/ErrorCallout";
import { useRouter } from "next/navigation";
import { refreshDocumentSets } from "../hooks";
@@ -56,19 +55,20 @@ function Main() {
);
}
const Page = () => {
export default function Page() {
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DOCUMENT_SETS]!;
return (
<>
<BackButton />
<AdminPageTitle
icon={<BookmarkIcon size={32} />}
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={route.icon}
title="New Document Set"
separator
backButton
/>
<Main />
</>
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
};
export default Page;
}

View File

@@ -2,7 +2,7 @@
import { ThreeDotsLoader } from "@/components/Loading";
import { PageSelector } from "@/components/PageSelector";
import { BookmarkIcon, InfoIcon } from "@/components/icons/icons";
import { InfoIcon } from "@/components/icons/icons";
import {
Table,
TableHead,
@@ -19,7 +19,8 @@ import { useDocumentSets } from "./hooks";
import { ConnectorTitle } from "@/components/admin/connectors/ConnectorTitle";
import { deleteDocumentSet } from "./lib";
import { toast } from "@/hooks/useToast";
import { AdminPageTitle } from "@/components/admin/Title";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import {
FiAlertTriangle,
FiCheckCircle,
@@ -358,7 +359,7 @@ const DocumentSetTable = ({
);
};
const Main = () => {
function Main() {
const {
data: documentSets,
isLoading: isDocumentSetsLoading,
@@ -418,16 +419,17 @@ const Main = () => {
)}
</div>
);
};
}
export default function Page() {
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.DOCUMENT_SETS]!;
const Page = () => {
return (
<>
<AdminPageTitle icon={<BookmarkIcon size={32} />} title="Document Sets" />
<Main />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
};
export default Page;
}

View File

@@ -165,7 +165,7 @@ function ConnectorRow({
onClick={handleRowClick}
>
<TableCell className="">
<p className="lg:w-[200px] xl:w-[400px] inline-block ellipsis truncate">
<p className="max-w-[200px] xl:max-w-[400px] inline-block ellipsis truncate">
{ccPairsIndexingStatus.name}
</p>
</TableCell>
@@ -246,7 +246,7 @@ function FederatedConnectorRow({
onClick={handleRowClick}
>
<TableCell className="">
<p className="lg:w-[200px] xl:w-[400px] inline-block ellipsis truncate">
<p className="max-w-[200px] xl:max-w-[400px] inline-block ellipsis truncate">
{federatedConnector.name}
</p>
</TableCell>
@@ -293,7 +293,7 @@ export function CCPairIndexingStatusTable({
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
return (
<Table className="-mt-8">
<Table className="-mt-8 table-fixed">
<TableHeader>
<ConnectorRow
invisible

View File

@@ -1,10 +1,10 @@
"use client";
import { NotebookIcon } from "@/components/icons/icons";
import { CCPairIndexingStatusTable } from "./CCPairIndexingStatusTable";
import { SearchAndFilterControls } from "./SearchAndFilterControls";
import { AdminPageTitle } from "@/components/admin/Title";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Link from "next/link";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import Text from "@/components/ui/text";
import { useConnectorIndexingStatusWithPagination } from "@/lib/hooks";
import { useToastFromQuery } from "@/hooks/useToast";
@@ -201,6 +201,8 @@ function Main() {
}
export default function Status() {
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.INDEXING_STATUS]!;
useToastFromQuery({
"connector-created": {
message: "Connector created successfully",
@@ -213,16 +215,18 @@ export default function Status() {
});
return (
<>
<AdminPageTitle
icon={<NotebookIcon size={32} />}
title="Existing Connectors"
farRightElement={
<SettingsLayouts.Root width="full">
<SettingsLayouts.Header
icon={route.icon}
title={route.title}
rightChildren={
<Button href="/admin/add-connector">Add Connector</Button>
}
separator
/>
<Main />
</>
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,14 +1,13 @@
"use client";
import CardSection from "@/components/admin/CardSection";
import { AdminPageTitle } from "@/components/admin/Title";
import {
DatePickerField,
FieldLabel,
TextArrayField,
TextFormField,
} from "@/components/Field";
import { BrainIcon } from "@/components/icons/icons";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Modal from "@/refresh-components/Modal";
import Button from "@/refresh-components/buttons/Button";
import SwitchField from "@/refresh-components/form/SwitchField";
@@ -31,6 +30,9 @@ import KGEntityTypes from "@/app/admin/kg/KGEntityTypes";
import Text from "@/refresh-components/texts/Text";
import { cn } from "@/lib/utils";
import { SvgSettings } from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.KNOWLEDGE_GRAPH]!;
function createDomainField(
name: string,
@@ -324,12 +326,11 @@ export default function Page() {
}
return (
<>
<AdminPageTitle
title="Knowledge Graph"
icon={<BrainIcon size={32} className="my-auto" />}
/>
<Main />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,7 +1,7 @@
"use client";
import { AdminPageTitle } from "@/components/admin/Title";
import SimpleTabs from "@/refresh-components/SimpleTabs";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Text from "@/components/ui/text";
import { useState } from "react";
import {
@@ -16,8 +16,11 @@ import { toast } from "@/hooks/useToast";
import CreateRateLimitModal from "./CreateRateLimitModal";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { SvgGlobe, SvgShield, SvgUser, SvgUsers } from "@opal/icons";
import { SvgGlobe, SvgUser, SvgUsers } from "@opal/icons";
import { Section } from "@/layouts/general-layouts";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.TOKEN_RATE_LIMITS]!;
const BASE_URL = "/api/admin/token-rate-limits";
const GLOBAL_TOKEN_FETCH_URL = `${BASE_URL}/global`;
const USER_TOKEN_FETCH_URL = `${BASE_URL}/users`;
@@ -208,9 +211,11 @@ function Main() {
export default function Page() {
return (
<>
<AdminPageTitle title="Token Rate Limits" icon={SvgShield} />
<Main />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header title={route.title} icon={route.icon} separator />
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -5,11 +5,10 @@ import SimpleTabs from "@/refresh-components/SimpleTabs";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import InvitedUserTable from "@/components/admin/users/InvitedUserTable";
import SignedUpUserTable from "@/components/admin/users/SignedUpUserTable";
import Modal from "@/refresh-components/Modal";
import { ThreeDotsLoader } from "@/components/Loading";
import { AdminPageTitle } from "@/components/admin/Title";
import { toast } from "@/hooks/useToast";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { errorHandlingFetcher } from "@/lib/fetcher";
import useSWR, { mutate } from "swr";
import { ErrorCallout } from "@/components/ErrorCallout";
@@ -22,7 +21,11 @@ import CreateButton from "@/refresh-components/buttons/CreateButton";
import Button from "@/refresh-components/buttons/Button";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import { Spinner } from "@/components/Spinner";
import { SvgDownloadCloud, SvgUser, SvgUserPlus } from "@opal/icons";
import { SvgDownloadCloud, SvgUserPlus } from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.USERS]!;
interface CountDisplayProps {
label: string;
value: number | null;
@@ -48,7 +51,7 @@ function CountDisplay({ label, value, isLoading }: CountDisplayProps) {
);
}
const UsersTables = ({
function UsersTables({
q,
isDownloadingUsers,
setIsDownloadingUsers,
@@ -56,7 +59,7 @@ const UsersTables = ({
q: string;
isDownloadingUsers: boolean;
setIsDownloadingUsers: (loading: boolean) => void;
}) => {
}) {
const [currentUsersCount, setCurrentUsersCount] = useState<number | null>(
null
);
@@ -236,9 +239,9 @@ const UsersTables = ({
});
return <SimpleTabs tabs={tabs} defaultValue="current" />;
};
}
const SearchableTables = () => {
function SearchableTables() {
const [query, setQuery] = useState("");
const [isDownloadingUsers, setIsDownloadingUsers] = useState(false);
@@ -262,7 +265,7 @@ const SearchableTables = () => {
</div>
</div>
);
};
}
function AddUserButton() {
const [bulkAddUsersModal, setBulkAddUsersModal] = useState(false);
@@ -325,13 +328,13 @@ function AddUserButton() {
);
}
const Page = () => {
export default function Page() {
return (
<>
<AdminPageTitle title="Manage Users" icon={SvgUser} />
<SearchableTables />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header title={route.title} icon={route.icon} separator />
<SettingsLayouts.Body>
<SearchableTables />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
};
export default Page;
}

View File

@@ -1,6 +1,6 @@
import { AdminPageTitle } from "@/components/admin/Title";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import BillingInformationPage from "./BillingInformationPage";
import { MdOutlineCreditCard } from "react-icons/md";
import { SvgCreditCard } from "@opal/icons";
export interface BillingInformation {
stripe_subscription_id: string;
@@ -18,12 +18,15 @@ export interface BillingInformation {
export default function page() {
return (
<>
<AdminPageTitle
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgCreditCard}
title="Billing Information"
icon={<MdOutlineCreditCard size={32} className="my-auto" />}
separator
/>
<BillingInformationPage />
</>
<SettingsLayouts.Body>
<BillingInformationPage />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,25 +1,23 @@
"use client";
import { use } from "react";
import { use } from "react";
import { GroupDisplay } from "./GroupDisplay";
import { useSpecificUserGroup } from "./hook";
import { ThreeDotsLoader } from "@/components/Loading";
import { useConnectorStatus } from "@/lib/hooks";
import { useRouter } from "next/navigation";
import useUsers from "@/hooks/useUsers";
import BackButton from "@/refresh-components/buttons/BackButton";
import { AdminPageTitle } from "@/components/admin/Title";
import { SvgUsers } from "@opal/icons";
const Page = (props: { params: Promise<{ groupId: string }> }) => {
const params = use(props.params);
const router = useRouter();
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import * as SettingsLayouts from "@/layouts/settings-layouts";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.GROUPS]!;
function Main({ groupId }: { groupId: string }) {
const {
userGroup,
isLoading: userGroupIsLoading,
error: userGroupError,
refreshUserGroup,
} = useSpecificUserGroup(params.groupId);
} = useSpecificUserGroup(groupId);
const {
data: users,
isLoading: userIsLoading,
@@ -53,22 +51,31 @@ const Page = (props: { params: Promise<{ groupId: string }> }) => {
return (
<>
<BackButton />
<SettingsLayouts.Header
icon={route.icon}
title={userGroup.name || "Unknown"}
separator
backButton
/>
<AdminPageTitle title={userGroup.name || "Unknown"} icon={SvgUsers} />
{userGroup ? (
<SettingsLayouts.Body>
<GroupDisplay
users={users.accepted}
ccPairs={ccPairs}
userGroup={userGroup}
refreshUserGroup={refreshUserGroup}
/>
) : (
<div>Unable to fetch User Group :(</div>
)}
</SettingsLayouts.Body>
</>
);
};
}
export default Page;
export default function Page(props: { params: Promise<{ groupId: string }> }) {
const params = use(props.params);
return (
<SettingsLayouts.Root>
<Main groupId={params.groupId} />
</SettingsLayouts.Root>
);
}

View File

@@ -5,13 +5,15 @@ import UserGroupCreationForm from "./UserGroupCreationForm";
import { useState } from "react";
import { ThreeDotsLoader } from "@/components/Loading";
import { useConnectorStatus, useUserGroups } from "@/lib/hooks";
import { AdminPageTitle } from "@/components/admin/Title";
import useUsers from "@/hooks/useUsers";
import { useUser } from "@/providers/UserProvider";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { SvgUsers } from "@opal/icons";
const Main = () => {
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import * as SettingsLayouts from "@/layouts/settings-layouts";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.GROUPS]!;
function Main() {
const [showForm, setShowForm] = useState(false);
const { data, isLoading, error, refreshUserGroups } = useUserGroups();
@@ -70,16 +72,16 @@ const Main = () => {
)}
</>
);
};
}
const Page = () => {
export default function Page() {
return (
<>
<AdminPageTitle title="Manage User Groups" icon={SvgUsers} />
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<Main />
</>
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
};
export default Page;
}

View File

@@ -1,10 +1,12 @@
import { AdminPageTitle } from "@/components/admin/Title";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { CUSTOM_ANALYTICS_ENABLED } from "@/lib/constants";
import { Callout } from "@/components/ui/callout";
import { FiBarChart2 } from "react-icons/fi";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import Text from "@/components/ui/text";
import { CustomAnalyticsUpdateForm } from "./CustomAnalyticsUpdateForm";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.CUSTOM_ANALYTICS]!;
function Main() {
if (!CUSTOM_ANALYTICS_ENABLED) {
return (
@@ -35,13 +37,11 @@ function Main() {
export default function Page() {
return (
<main className="pt-4 mx-auto container">
<AdminPageTitle
title="Custom Analytics"
icon={<FiBarChart2 size={32} />}
/>
<Main />
</main>
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,14 +1,19 @@
"use client";
import { AdminPageTitle } from "@/components/admin/Title";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { QueryHistoryTable } from "@/app/ee/admin/performance/query-history/QueryHistoryTable";
import { SvgServer } from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.QUERY_HISTORY]!;
export default function QueryHistoryPage() {
return (
<>
<AdminPageTitle title="Query History" icon={SvgServer} />
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<QueryHistoryTable />
</>
<SettingsLayouts.Body>
<QueryHistoryTable />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -6,32 +6,36 @@ import { FeedbackChart } from "@/app/ee/admin/performance/usage/FeedbackChart";
import { QueryPerformanceChart } from "@/app/ee/admin/performance/usage/QueryPerformanceChart";
import { PersonaMessagesChart } from "@/app/ee/admin/performance/usage/PersonaMessagesChart";
import { useTimeRange } from "@/app/ee/admin/performance/lib";
import { AdminPageTitle } from "@/components/admin/Title";
import UsageReports from "@/app/ee/admin/performance/usage/UsageReports";
import Separator from "@/refresh-components/Separator";
import { useAdminPersonas } from "@/hooks/useAdminPersonas";
import { SvgActivity } from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import * as SettingsLayouts from "@/layouts/settings-layouts";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.USAGE]!;
export default function AnalyticsPage() {
const [timeRange, setTimeRange] = useTimeRange();
const { personas } = useAdminPersonas();
return (
<>
<AdminPageTitle title="Usage Statistics" icon={SvgActivity} />
<AdminDateRangeSelector
value={timeRange}
onValueChange={(value) => setTimeRange(value as any)}
/>
<QueryPerformanceChart timeRange={timeRange} />
<FeedbackChart timeRange={timeRange} />
<OnyxBotChart timeRange={timeRange} />
<PersonaMessagesChart
availablePersonas={personas}
timeRange={timeRange}
/>
<Separator />
<UsageReports />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
<AdminDateRangeSelector
value={timeRange}
onValueChange={(value) => setTimeRange(value as any)}
/>
<QueryPerformanceChart timeRange={timeRange} />
<FeedbackChart timeRange={timeRange} />
<OnyxBotChart timeRange={timeRange} />
<PersonaMessagesChart
availablePersonas={personas}
timeRange={timeRange}
/>
<Separator />
<UsageReports />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,13 +1,13 @@
import { AdminPageTitle } from "@/components/admin/Title";
import { StandardAnswerCreationForm } from "@/app/ee/admin/standard-answer/StandardAnswerCreationForm";
import { fetchSS } from "@/lib/utilsSS";
import { ErrorCallout } from "@/components/ErrorCallout";
import BackButton from "@/refresh-components/buttons/BackButton";
import { ClipboardIcon } from "@/components/icons/icons";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import { StandardAnswer, StandardAnswerCategory } from "@/lib/types";
async function Page(props: { params: Promise<{ id: string }> }) {
const params = await props.params;
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.STANDARD_ANSWERS]!;
async function Main({ id }: { id: string }) {
const tasks = [
fetchSS("/manage/admin/standard-answer"),
fetchSS(`/manage/admin/standard-answer/category`),
@@ -35,14 +35,14 @@ async function Page(props: { params: Promise<{ id: string }> }) {
const allStandardAnswers =
(await standardAnswersResponse.json()) as StandardAnswer[];
const standardAnswer = allStandardAnswers.find(
(answer) => answer.id.toString() === params.id
(answer) => answer.id.toString() === id
);
if (!standardAnswer) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Did not find standard answer with ID: ${params.id}`}
errorMsg={`Did not find standard answer with ID: ${id}`}
/>
);
}
@@ -67,20 +67,29 @@ async function Page(props: { params: Promise<{ id: string }> }) {
const standardAnswerCategories =
(await standardAnswerCategoriesResponse.json()) as StandardAnswerCategory[];
return (
<>
<BackButton />
<AdminPageTitle
title="Edit Standard Answer"
icon={<ClipboardIcon size={32} />}
/>
<StandardAnswerCreationForm
standardAnswerCategories={standardAnswerCategories}
existingStandardAnswer={standardAnswer}
/>
</>
return (
<StandardAnswerCreationForm
standardAnswerCategories={standardAnswerCategories}
existingStandardAnswer={standardAnswer}
/>
);
}
export default Page;
export default async function Page(props: { params: Promise<{ id: string }> }) {
const params = await props.params;
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={route.icon}
title="Edit Standard Answer"
backButton
separator
/>
<SettingsLayouts.Body>
<Main id={params.id} />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,11 +1,12 @@
import { AdminPageTitle } from "@/components/admin/Title";
import { StandardAnswerCreationForm } from "@/app/ee/admin/standard-answer/StandardAnswerCreationForm";
import { fetchSS } from "@/lib/utilsSS";
import { ErrorCallout } from "@/components/ErrorCallout";
import BackButton from "@/refresh-components/buttons/BackButton";
import { ClipboardIcon } from "@/components/icons/icons";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import { StandardAnswerCategory } from "@/lib/types";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.STANDARD_ANSWERS]!;
async function Page() {
const standardAnswerCategoriesResponse = await fetchSS(
"/manage/admin/standard-answer/category"
@@ -23,17 +24,19 @@ async function Page() {
(await standardAnswerCategoriesResponse.json()) as StandardAnswerCategory[];
return (
<>
<BackButton />
<AdminPageTitle
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={route.icon}
title="New Standard Answer"
icon={<ClipboardIcon size={32} />}
backButton
separator
/>
<StandardAnswerCreationForm
standardAnswerCategories={standardAnswerCategories}
/>
</>
<SettingsLayouts.Body>
<StandardAnswerCreationForm
standardAnswerCategories={standardAnswerCategories}
/>
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,7 +1,6 @@
"use client";
import { AdminPageTitle } from "@/components/admin/Title";
import { ClipboardIcon, EditIcon } from "@/components/icons/icons";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { toast } from "@/hooks/useToast";
import { useStandardAnswers, useStandardAnswerCategories } from "./hooks";
import { ThreeDotsLoader } from "@/components/Loading";
@@ -29,10 +28,13 @@ import { PageSelector } from "@/components/PageSelector";
import Text from "@/components/ui/text";
import { TableHeader } from "@/components/ui/table";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { SvgTrash } from "@opal/icons";
import { SvgEdit, SvgTrash } from "@opal/icons";
import { Button } from "@opal/components";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
const NUM_RESULTS_PER_PAGE = 10;
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.STANDARD_ANSWERS]!;
type Displayable = JSX.Element | string;
const RowTemplate = ({
@@ -113,7 +115,7 @@ const StandardAnswersTableRow = ({
key={`edit-${standardAnswer.id}`}
href={`/ee/admin/standard-answer/${standardAnswer.id}` as Route}
>
<EditIcon />
<SvgEdit size={16} />
</Link>,
<div key={`categories-${standardAnswer.id}`}>
{standardAnswer.categories.map((category) => (
@@ -344,7 +346,7 @@ const StandardAnswersTable = ({
);
};
const Main = () => {
function Main() {
const {
data: standardAnswers,
error: standardAnswersError,
@@ -413,18 +415,15 @@ const Main = () => {
</div>
</div>
);
};
}
const Page = () => {
export default function Page() {
return (
<>
<AdminPageTitle
icon={<ClipboardIcon size={32} />}
title="Standard Answers"
/>
<Main />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
<Main />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
};
export default Page;
}

View File

@@ -1,7 +1,7 @@
"use client";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { SvgPaintBrush } from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import Button from "@/refresh-components/buttons/Button";
import {
AppearanceThemeSettings,
@@ -15,6 +15,8 @@ import * as Yup from "yup";
import { EnterpriseSettings } from "@/interfaces/settings";
import { useRouter } from "next/navigation";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.THEME]!;
const CHAR_LIMITS = {
application_name: 50,
custom_greeting_message: 50,
@@ -211,9 +213,9 @@ export default function ThemePage() {
<Form className="w-full h-full">
<SettingsLayouts.Root>
<SettingsLayouts.Header
title="Appearance & Theming"
title={route.title}
description="Customize how the application appears to users across your organization."
icon={SvgPaintBrush}
icon={route.icon}
rightChildren={
<Button
type="button"

View File

@@ -6,6 +6,7 @@ import { useSettingsContext } from "@/providers/SettingsProvider";
import { ApplicationStatus } from "@/interfaces/settings";
import Button from "@/refresh-components/buttons/Button";
import { cn } from "@/lib/utils";
import { ADMIN_PATHS } from "@/lib/admin-routes";
export interface ClientLayoutProps {
children: React.ReactNode;
@@ -18,16 +19,32 @@ export interface ClientLayoutProps {
// the `py-10 px-4 md:px-12` padding below can be removed entirely and
// this prefix list can be deleted.
const SETTINGS_LAYOUT_PREFIXES = [
"/admin/configuration/chat-preferences",
"/admin/configuration/image-generation",
"/admin/configuration/web-search",
"/admin/actions/mcp",
"/admin/actions/open-api",
"/admin/billing",
"/admin/document-index-migration",
"/admin/discord-bot",
"/admin/theme",
"/admin/configuration/llm",
ADMIN_PATHS.CHAT_PREFERENCES,
ADMIN_PATHS.IMAGE_GENERATION,
ADMIN_PATHS.WEB_SEARCH,
ADMIN_PATHS.MCP_ACTIONS,
ADMIN_PATHS.OPENAPI_ACTIONS,
ADMIN_PATHS.BILLING,
ADMIN_PATHS.INDEX_MIGRATION,
ADMIN_PATHS.DISCORD_BOTS,
ADMIN_PATHS.THEME,
ADMIN_PATHS.LLM_MODELS,
ADMIN_PATHS.AGENTS,
ADMIN_PATHS.USERS,
ADMIN_PATHS.TOKEN_RATE_LIMITS,
ADMIN_PATHS.SEARCH_SETTINGS,
ADMIN_PATHS.DOCUMENT_PROCESSING,
ADMIN_PATHS.CODE_INTERPRETER,
ADMIN_PATHS.API_KEYS,
ADMIN_PATHS.ADD_CONNECTOR,
ADMIN_PATHS.INDEXING_STATUS,
ADMIN_PATHS.DOCUMENTS,
ADMIN_PATHS.DEBUG,
ADMIN_PATHS.KNOWLEDGE_GRAPH,
ADMIN_PATHS.SLACK_BOTS,
ADMIN_PATHS.STANDARD_ANSWERS,
ADMIN_PATHS.GROUPS,
ADMIN_PATHS.PERFORMANCE,
];
export function ClientLayout({

245
web/src/lib/admin-routes.ts Normal file
View File

@@ -0,0 +1,245 @@
import { IconFunctionComponent } from "@opal/types";
import {
SvgActions,
SvgActivity,
SvgArrowExchange,
SvgBarChart,
SvgBookOpen,
SvgBubbleText,
SvgClipboard,
SvgCpu,
SvgDiscordMono,
SvgDownload,
SvgFileText,
SvgFolder,
SvgGlobe,
SvgImage,
SvgKey,
SvgMcp,
SvgNetworkGraph,
SvgOnyxOctagon,
SvgPaintBrush,
SvgSearch,
SvgServer,
SvgShield,
SvgSlack,
SvgTerminal,
SvgThumbsUp,
SvgUploadCloud,
SvgUser,
SvgUsers,
SvgWallet,
SvgZoomIn,
} from "@opal/icons";
/**
 * Canonical path constants for every admin route.
 *
 * These strings are the keys of `ADMIN_ROUTE_CONFIG` and are also matched
 * against pathnames (e.g. in layout prefix checks), so they must stay
 * byte-stable — change a value here only together with the actual route.
 */
export const ADMIN_PATHS = {
  // Connectors
  INDEXING_STATUS: "/admin/indexing/status",
  ADD_CONNECTOR: "/admin/add-connector",
  // Document management
  DOCUMENT_SETS: "/admin/documents/sets",
  DOCUMENT_EXPLORER: "/admin/documents/explorer",
  DOCUMENT_FEEDBACK: "/admin/documents/feedback",
  // Agents & bots
  AGENTS: "/admin/agents",
  SLACK_BOTS: "/admin/bots",
  DISCORD_BOTS: "/admin/discord-bot",
  // Actions
  MCP_ACTIONS: "/admin/actions/mcp",
  OPENAPI_ACTIONS: "/admin/actions/open-api",
  STANDARD_ANSWERS: "/admin/standard-answer",
  GROUPS: "/admin/groups",
  // Configuration
  CHAT_PREFERENCES: "/admin/configuration/chat-preferences",
  LLM_MODELS: "/admin/configuration/llm",
  WEB_SEARCH: "/admin/configuration/web-search",
  IMAGE_GENERATION: "/admin/configuration/image-generation",
  CODE_INTERPRETER: "/admin/configuration/code-interpreter",
  SEARCH_SETTINGS: "/admin/configuration/search",
  DOCUMENT_PROCESSING: "/admin/configuration/document-processing",
  KNOWLEDGE_GRAPH: "/admin/kg",
  // User management
  USERS: "/admin/users",
  API_KEYS: "/admin/api-key",
  TOKEN_RATE_LIMITS: "/admin/token-rate-limits",
  // Performance
  USAGE: "/admin/performance/usage",
  QUERY_HISTORY: "/admin/performance/query-history",
  CUSTOM_ANALYTICS: "/admin/performance/custom-analytics",
  // Settings
  THEME: "/admin/theme",
  BILLING: "/admin/billing",
  INDEX_MIGRATION: "/admin/document-index-migration",
  DEBUG: "/admin/debug",
  // Prefix-only entries (used in SETTINGS_LAYOUT_PREFIXES but have no
  // single page header of their own)
  DOCUMENTS: "/admin/documents",
  PERFORMANCE: "/admin/performance",
} as const;
/** Display metadata for one admin route, looked up by path in `ADMIN_ROUTE_CONFIG`. */
interface AdminRouteConfig {
  /** Icon component rendered in the page header and sidebar entry. */
  icon: IconFunctionComponent;
  /** Title shown in the settings page header. */
  title: string;
  /** Label shown in the admin sidebar — may be shorter than `title`
   * (e.g. "Groups" vs "Manage User Groups"). */
  sidebarLabel: string;
}
/**
 * Single source of truth for icon, page-header title, and sidebar label
 * for every admin route. Keyed by path from `ADMIN_PATHS`.
 *
 * NOTE: the prefix-only paths (`ADMIN_PATHS.DOCUMENTS`,
 * `ADMIN_PATHS.PERFORMANCE`) deliberately have NO entry here — indexing
 * this record with them yields `undefined`, so callers must guard
 * (see `sidebarItem`).
 */
export const ADMIN_ROUTE_CONFIG: Record<string, AdminRouteConfig> = {
  [ADMIN_PATHS.INDEXING_STATUS]: {
    icon: SvgBookOpen,
    title: "Existing Connectors",
    sidebarLabel: "Existing Connectors",
  },
  [ADMIN_PATHS.ADD_CONNECTOR]: {
    icon: SvgUploadCloud,
    title: "Add Connector",
    sidebarLabel: "Add Connector",
  },
  [ADMIN_PATHS.DOCUMENT_SETS]: {
    icon: SvgFolder,
    title: "Document Sets",
    sidebarLabel: "Document Sets",
  },
  [ADMIN_PATHS.DOCUMENT_EXPLORER]: {
    icon: SvgZoomIn,
    title: "Document Explorer",
    sidebarLabel: "Explorer",
  },
  [ADMIN_PATHS.DOCUMENT_FEEDBACK]: {
    icon: SvgThumbsUp,
    title: "Document Feedback",
    sidebarLabel: "Feedback",
  },
  [ADMIN_PATHS.AGENTS]: {
    icon: SvgOnyxOctagon,
    title: "Agents",
    sidebarLabel: "Agents",
  },
  [ADMIN_PATHS.SLACK_BOTS]: {
    icon: SvgSlack,
    title: "Slack Bots",
    sidebarLabel: "Slack Bots",
  },
  [ADMIN_PATHS.DISCORD_BOTS]: {
    icon: SvgDiscordMono,
    title: "Discord Bots",
    sidebarLabel: "Discord Bots",
  },
  [ADMIN_PATHS.MCP_ACTIONS]: {
    icon: SvgMcp,
    title: "MCP Actions",
    sidebarLabel: "MCP Actions",
  },
  [ADMIN_PATHS.OPENAPI_ACTIONS]: {
    icon: SvgActions,
    title: "OpenAPI Actions",
    sidebarLabel: "OpenAPI Actions",
  },
  [ADMIN_PATHS.STANDARD_ANSWERS]: {
    icon: SvgClipboard,
    title: "Standard Answers",
    sidebarLabel: "Standard Answers",
  },
  [ADMIN_PATHS.GROUPS]: {
    icon: SvgUsers,
    title: "Manage User Groups",
    sidebarLabel: "Groups",
  },
  [ADMIN_PATHS.CHAT_PREFERENCES]: {
    icon: SvgBubbleText,
    title: "Chat Preferences",
    sidebarLabel: "Chat Preferences",
  },
  [ADMIN_PATHS.LLM_MODELS]: {
    icon: SvgCpu,
    title: "LLM Models",
    sidebarLabel: "LLM Models",
  },
  [ADMIN_PATHS.WEB_SEARCH]: {
    icon: SvgGlobe,
    title: "Web Search",
    sidebarLabel: "Web Search",
  },
  [ADMIN_PATHS.IMAGE_GENERATION]: {
    icon: SvgImage,
    title: "Image Generation",
    sidebarLabel: "Image Generation",
  },
  [ADMIN_PATHS.CODE_INTERPRETER]: {
    icon: SvgTerminal,
    title: "Code Interpreter",
    sidebarLabel: "Code Interpreter",
  },
  [ADMIN_PATHS.SEARCH_SETTINGS]: {
    icon: SvgSearch,
    title: "Search Settings",
    sidebarLabel: "Search Settings",
  },
  [ADMIN_PATHS.DOCUMENT_PROCESSING]: {
    icon: SvgFileText,
    title: "Document Processing",
    sidebarLabel: "Document Processing",
  },
  [ADMIN_PATHS.KNOWLEDGE_GRAPH]: {
    icon: SvgNetworkGraph,
    title: "Knowledge Graph",
    sidebarLabel: "Knowledge Graph",
  },
  [ADMIN_PATHS.USERS]: {
    icon: SvgUser,
    title: "Manage Users",
    sidebarLabel: "Users",
  },
  [ADMIN_PATHS.API_KEYS]: {
    icon: SvgKey,
    title: "API Keys",
    sidebarLabel: "API Keys",
  },
  [ADMIN_PATHS.TOKEN_RATE_LIMITS]: {
    icon: SvgShield,
    title: "Token Rate Limits",
    sidebarLabel: "Token Rate Limits",
  },
  [ADMIN_PATHS.USAGE]: {
    icon: SvgActivity,
    title: "Usage Statistics",
    sidebarLabel: "Usage Statistics",
  },
  [ADMIN_PATHS.QUERY_HISTORY]: {
    icon: SvgServer,
    title: "Query History",
    sidebarLabel: "Query History",
  },
  [ADMIN_PATHS.CUSTOM_ANALYTICS]: {
    icon: SvgBarChart,
    title: "Custom Analytics",
    sidebarLabel: "Custom Analytics",
  },
  [ADMIN_PATHS.THEME]: {
    icon: SvgPaintBrush,
    title: "Appearance & Theming",
    sidebarLabel: "Appearance & Theming",
  },
  [ADMIN_PATHS.BILLING]: {
    icon: SvgWallet,
    title: "Plans & Billing",
    sidebarLabel: "Plans & Billing",
  },
  [ADMIN_PATHS.INDEX_MIGRATION]: {
    icon: SvgArrowExchange,
    title: "Document Index Migration",
    sidebarLabel: "Document Index Migration",
  },
  [ADMIN_PATHS.DEBUG]: {
    icon: SvgDownload,
    title: "Debug Logs",
    sidebarLabel: "Debug Logs",
  },
};
/**
 * Helper that converts a route config entry into the `{ name, icon, link }`
 * shape expected by the sidebar. Extra fields (e.g. `error`) can be spread in.
 *
 * @param path - A path from `ADMIN_PATHS` that has an entry in
 *   `ADMIN_ROUTE_CONFIG`. Prefix-only paths (`ADMIN_PATHS.DOCUMENTS`,
 *   `ADMIN_PATHS.PERFORMANCE`) have no entry and are rejected.
 * @returns The sidebar item descriptor for the route.
 * @throws Error naming the offending path when no config entry exists.
 */
export function sidebarItem(path: string) {
  const config = ADMIN_ROUTE_CONFIG[path];
  if (!config) {
    // The previous non-null assertion (`!`) let `undefined` through and
    // crashed later with an opaque TypeError on `.sidebarLabel`; fail fast
    // with a message that identifies the missing route instead.
    throw new Error(`sidebarItem: no ADMIN_ROUTE_CONFIG entry for "${path}"`);
  }
  return { name: config.sidebarLabel, icon: config.icon, link: path };
}

View File

@@ -1,7 +1,8 @@
import { VisionProvider } from "@/interfaces/llm";
import { LLMProviderResponse, VisionProvider } from "@/interfaces/llm";
import { LLM_ADMIN_URL } from "@/lib/llmConfig/constants";
export async function fetchVisionProviders(): Promise<VisionProvider[]> {
const response = await fetch("/api/admin/llm/vision-providers", {
const response = await fetch(`${LLM_ADMIN_URL}/vision-providers`, {
headers: {
"Content-Type": "application/json",
},
@@ -11,24 +12,24 @@ export async function fetchVisionProviders(): Promise<VisionProvider[]> {
`Failed to fetch vision providers: ${await response.text()}`
);
}
return response.json();
const data = (await response.json()) as LLMProviderResponse<VisionProvider>;
return data.providers;
}
export async function setDefaultVisionProvider(
providerId: number,
visionModel: string
): Promise<void> {
const response = await fetch(
`/api/admin/llm/provider/${providerId}/default-vision?vision_model=${encodeURIComponent(
visionModel
)}`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
}
);
const response = await fetch(`${LLM_ADMIN_URL}/default-vision`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider_id: providerId,
model_name: visionModel,
}),
});
if (!response.ok) {
const errorMsg = await response.text();

View File

@@ -6,7 +6,10 @@ import {
ModelConfiguration,
WellKnownLLMProviderDescriptor,
} from "@/interfaces/llm";
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
import {
LLM_ADMIN_URL,
LLM_PROVIDERS_ADMIN_URL,
} from "@/lib/llmConfig/constants";
import { OnboardingActions, OnboardingState } from "../types";
import { APIFormFieldState } from "@/refresh-components/form/types";
import {
@@ -225,13 +228,33 @@ export function OnboardingFormWrapper<T extends Record<string, any>>({
try {
const newLlmProvider = await response.json();
if (newLlmProvider?.id != null) {
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{ method: "POST" }
);
if (!setDefaultResponse.ok) {
const err = await setDefaultResponse.json().catch(() => ({}));
console.error("Failed to set provider as default", err?.detail);
const defaultModelName =
(payload as Record<string, any>).default_model_name ??
(payload as Record<string, any>).model_configurations?.[0]?.name ??
"";
if (!defaultModelName) {
console.error(
"No model name available to set as default — skipping set-default call"
);
} else {
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_id: newLlmProvider.id,
model_name: defaultModelName,
}),
});
if (!setDefaultResponse.ok) {
const err = await setDefaultResponse.json().catch(() => ({}));
setErrorMessage(
err?.detail ?? "Failed to set provider as default"
);
setApiStatus("error");
setIsSubmitting(false);
return;
}
}
}
} catch (_e) {

View File

@@ -18,13 +18,13 @@ import InputTextAreaField from "@/refresh-components/form/InputTextAreaField";
import InputSelectField from "@/refresh-components/form/InputSelectField";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import {
SvgBubbleText,
SvgAddLines,
SvgActions,
SvgExpand,
SvgFold,
SvgExternalLink,
} from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import { Content } from "@opal/layouts";
import { useSettingsContext } from "@/providers/SettingsProvider";
import useCCPairs from "@/hooks/useCCPairs";
@@ -55,6 +55,8 @@ import useFilter from "@/hooks/useFilter";
import { MCPServer } from "@/lib/tools/interfaces";
import type { IconProps } from "@opal/types";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.CHAT_PREFERENCES]!;
interface DefaultAgentConfiguration {
tool_ids: number[];
system_prompt: string | null;
@@ -323,8 +325,8 @@ function ChatPreferencesForm() {
<>
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgBubbleText}
title="Chat Preferences"
icon={route.icon}
title={route.title}
description="Organization-wide chat settings and defaults. Users can override some of these in their personal settings."
separator
/>

View File

@@ -11,6 +11,7 @@ import {
SvgUnplug,
SvgXOctagon,
} from "@opal/icons";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import { Section } from "@/layouts/general-layouts";
import { Button } from "@opal/components";
import Text from "@/refresh-components/texts/Text";
@@ -21,6 +22,8 @@ import { updateCodeInterpreter } from "@/lib/admin/code-interpreter/svc";
import { ContentAction } from "@opal/layouts";
import { toast } from "@/hooks/useToast";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.CODE_INTERPRETER]!;
interface CodeInterpreterCardProps {
variant?: CardProps["variant"];
title: string;
@@ -161,8 +164,8 @@ export default function CodeInterpreterPage() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgTerminal}
title="Code Interpreter"
icon={route.icon}
title={route.title}
description="Safe and sandboxed Python runtime available to your LLM. See docs for more details."
separator
/>

View File

@@ -11,8 +11,9 @@ import { ThreeDotsLoader } from "@/components/Loading";
import { Content, ContentAction } from "@opal/layouts";
import { Button } from "@opal/components";
import { Hoverable } from "@opal/core";
import { SvgCpu, SvgArrowExchange, SvgSettings, SvgTrash } from "@opal/icons";
import { SvgArrowExchange, SvgSettings, SvgTrash } from "@opal/icons";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTE_CONFIG, ADMIN_PATHS } from "@/lib/admin-routes";
import * as GeneralLayouts from "@/layouts/general-layouts";
import {
getProviderDisplayName,
@@ -44,6 +45,8 @@ import { OpenRouterModal } from "@/sections/modals/llmConfig/OpenRouterModal";
import { CustomModal } from "@/sections/modals/llmConfig/CustomModal";
import { Section } from "@/layouts/general-layouts";
const route = ADMIN_ROUTE_CONFIG[ADMIN_PATHS.LLM_MODELS]!;
// ============================================================================
// Provider form mapping (keyed by provider name from the API)
// ============================================================================
@@ -344,7 +347,7 @@ export default function LLMConfigurationPage() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={SvgCpu} title="LLM Models" separator />
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
{hasProviders ? (

View File

@@ -412,7 +412,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
}),
} as Response);
// Mock POST /api/admin/llm/provider/5/default
// Mock POST /api/admin/llm/default
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({}),
@@ -431,14 +431,20 @@ describe("Custom LLM Provider Configuration Workflow", () => {
const submitButton = screen.getByRole("button", { name: /enable/i });
await user.click(submitButton);
// Verify set as default API was called
// Verify set as default API was called with correct endpoint and body
await waitFor(() => {
expect(fetchSpy).toHaveBeenCalledWith(
"/api/admin/llm/provider/5/default",
expect.objectContaining({
method: "POST",
})
const defaultCall = fetchSpy.mock.calls.find(
([url]) => url === "/api/admin/llm/default"
);
expect(defaultCall).toBeDefined();
const [, options] = defaultCall!;
expect(options.method).toBe("POST");
expect(options.headers).toEqual({ "Content-Type": "application/json" });
const body = JSON.parse(options.body);
expect(body.provider_id).toBe(5);
expect(body).toHaveProperty("model_name");
});
});

View File

@@ -16,119 +16,41 @@ import {
hasActiveSubscription,
} from "@/lib/billing";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import {
ClipboardIcon,
NotebookIconSkeleton,
SlackIconSkeleton,
BrainIcon,
} from "@/components/icons/icons";
import { CombinedSettings } from "@/interfaces/settings";
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
import SidebarBody from "@/sections/sidebar/SidebarBody";
import {
SvgActions,
SvgActivity,
SvgArrowUpCircle,
SvgBarChart,
SvgBubbleText,
SvgCpu,
SvgFileText,
SvgFolder,
SvgGlobe,
SvgArrowExchange,
SvgImage,
SvgKey,
SvgOnyxOctagon,
SvgSearch,
SvgServer,
SvgShield,
SvgThumbsUp,
SvgUploadCloud,
SvgUser,
SvgUsers,
SvgZoomIn,
SvgPaintBrush,
SvgDiscordMono,
SvgWallet,
SvgTerminal,
} from "@opal/icons";
import SvgMcp from "@opal/icons/mcp";
import { SvgArrowUpCircle } from "@opal/icons";
import { ADMIN_PATHS, sidebarItem } from "@/lib/admin-routes";
import UserAvatarPopover from "@/sections/sidebar/UserAvatarPopover";
const connectors_items = () => [
{
name: "Existing Connectors",
icon: NotebookIconSkeleton,
link: "/admin/indexing/status",
},
{
name: "Add Connector",
icon: SvgUploadCloud,
link: "/admin/add-connector",
},
sidebarItem(ADMIN_PATHS.INDEXING_STATUS),
sidebarItem(ADMIN_PATHS.ADD_CONNECTOR),
];
const document_management_items = () => [
{
name: "Document Sets",
icon: SvgFolder,
link: "/admin/documents/sets",
},
{
name: "Explorer",
icon: SvgZoomIn,
link: "/admin/documents/explorer",
},
{
name: "Feedback",
icon: SvgThumbsUp,
link: "/admin/documents/feedback",
},
sidebarItem(ADMIN_PATHS.DOCUMENT_SETS),
sidebarItem(ADMIN_PATHS.DOCUMENT_EXPLORER),
sidebarItem(ADMIN_PATHS.DOCUMENT_FEEDBACK),
];
const custom_agents_items = (isCurator: boolean, enableEnterprise: boolean) => {
const items = [
{
name: "Agents",
icon: SvgOnyxOctagon,
link: "/admin/agents",
},
];
const items = [sidebarItem(ADMIN_PATHS.AGENTS)];
if (!isCurator) {
items.push(
{
name: "Slack Bots",
icon: SlackIconSkeleton,
link: "/admin/bots",
},
{
name: "Discord Bots",
icon: SvgDiscordMono,
link: "/admin/discord-bot",
}
sidebarItem(ADMIN_PATHS.SLACK_BOTS),
sidebarItem(ADMIN_PATHS.DISCORD_BOTS)
);
}
items.push(
{
name: "MCP Actions",
icon: SvgMcp,
link: "/admin/actions/mcp",
},
{
name: "OpenAPI Actions",
icon: SvgActions,
link: "/admin/actions/open-api",
}
sidebarItem(ADMIN_PATHS.MCP_ACTIONS),
sidebarItem(ADMIN_PATHS.OPENAPI_ACTIONS)
);
if (enableEnterprise) {
items.push({
name: "Standard Answers",
icon: ClipboardIcon,
link: "/admin/standard-answer",
});
items.push(sidebarItem(ADMIN_PATHS.STANDARD_ANSWERS));
}
return items;
@@ -170,13 +92,7 @@ const collections = (
? [
{
name: "User Management",
items: [
{
name: "Groups",
icon: SvgUsers,
link: "/admin/groups",
},
],
items: [sidebarItem(ADMIN_PATHS.GROUPS)],
},
]
: []),
@@ -185,84 +101,30 @@ const collections = (
{
name: "Configuration",
items: [
{
name: "Chat Preferences",
icon: SvgBubbleText,
link: "/admin/configuration/chat-preferences",
},
{
name: "LLM Models",
icon: SvgCpu,
link: "/admin/configuration/llm",
},
{
name: "Web Search",
icon: SvgGlobe,
link: "/admin/configuration/web-search",
},
{
name: "Image Generation",
icon: SvgImage,
link: "/admin/configuration/image-generation",
},
{
name: "Code Interpreter",
icon: SvgTerminal,
link: "/admin/configuration/code-interpreter",
},
sidebarItem(ADMIN_PATHS.CHAT_PREFERENCES),
sidebarItem(ADMIN_PATHS.LLM_MODELS),
sidebarItem(ADMIN_PATHS.WEB_SEARCH),
sidebarItem(ADMIN_PATHS.IMAGE_GENERATION),
sidebarItem(ADMIN_PATHS.CODE_INTERPRETER),
...(!enableCloud && vectorDbEnabled
? [
{
...sidebarItem(ADMIN_PATHS.SEARCH_SETTINGS),
error: settings?.settings.needs_reindexing,
name: "Search Settings",
icon: SvgSearch,
link: "/admin/configuration/search",
},
]
: []),
{
name: "Document Processing",
icon: SvgFileText,
link: "/admin/configuration/document-processing",
},
...(kgExposed
? [
{
name: "Knowledge Graph",
icon: BrainIcon,
link: "/admin/kg",
},
]
: []),
sidebarItem(ADMIN_PATHS.DOCUMENT_PROCESSING),
...(kgExposed ? [sidebarItem(ADMIN_PATHS.KNOWLEDGE_GRAPH)] : []),
],
},
{
name: "User Management",
items: [
{
name: "Users",
icon: SvgUser,
link: "/admin/users",
},
...(enableEnterprise
? [
{
name: "Groups",
icon: SvgUsers,
link: "/admin/groups",
},
]
: []),
{
name: "API Keys",
icon: SvgKey,
link: "/admin/api-key",
},
{
name: "Token Rate Limits",
icon: SvgShield,
link: "/admin/token-rate-limits",
},
sidebarItem(ADMIN_PATHS.USERS),
...(enableEnterprise ? [sidebarItem(ADMIN_PATHS.GROUPS)] : []),
sidebarItem(ADMIN_PATHS.API_KEYS),
sidebarItem(ADMIN_PATHS.TOKEN_RATE_LIMITS),
],
},
...(enableEnterprise
@@ -270,28 +132,12 @@ const collections = (
{
name: "Performance",
items: [
{
name: "Usage Statistics",
icon: SvgActivity,
link: "/admin/performance/usage",
},
sidebarItem(ADMIN_PATHS.USAGE),
...(settings?.settings.query_history_type !== "disabled"
? [
{
name: "Query History",
icon: SvgServer,
link: "/admin/performance/query-history",
},
]
? [sidebarItem(ADMIN_PATHS.QUERY_HISTORY)]
: []),
...(!enableCloud && customAnalyticsEnabled
? [
{
name: "Custom Analytics",
icon: SvgBarChart,
link: "/admin/performance/custom-analytics",
},
]
? [sidebarItem(ADMIN_PATHS.CUSTOM_ANALYTICS)]
: []),
],
},
@@ -300,29 +146,16 @@ const collections = (
{
name: "Settings",
items: [
...(enableEnterprise
? [
{
name: "Appearance & Theming",
icon: SvgPaintBrush,
link: "/admin/theme",
},
]
: []),
...(enableEnterprise ? [sidebarItem(ADMIN_PATHS.THEME)] : []),
// Always show billing/upgrade - community users need access to upgrade
{
name: hasSubscription ? "Plans & Billing" : "Upgrade Plan",
icon: hasSubscription ? SvgWallet : SvgArrowUpCircle,
link: "/admin/billing",
...sidebarItem(ADMIN_PATHS.BILLING),
...(hasSubscription
? {}
: { name: "Upgrade Plan", icon: SvgArrowUpCircle }),
},
...(settings?.settings.opensearch_indexing_enabled
? [
{
name: "Document Index Migration",
icon: SvgArrowExchange,
link: "/admin/document-index-migration",
},
]
? [sidebarItem(ADMIN_PATHS.INDEX_MIGRATION)]
: []),
],
},

View File

@@ -20,7 +20,7 @@ async function deleteAllProviders(client: OnyxApiClient): Promise<void> {
const providers = await client.listLlmProviders();
for (const provider of providers) {
try {
await client.deleteProvider(provider.id);
await client.deleteProvider(provider.id, { force: true });
} catch (error) {
console.warn(
`Failed to delete provider ${provider.id}: ${String(error)}`

View File

@@ -526,8 +526,14 @@ export class OnyxApiClient {
*
* @param providerId - The provider ID to delete
*/
async deleteProvider(providerId: number): Promise<void> {
const response = await this.delete(`/admin/llm/provider/${providerId}`);
async deleteProvider(
providerId: number,
{ force = false }: { force?: boolean } = {}
): Promise<void> {
const query = force ? "?force=true" : "";
const response = await this.delete(
`/admin/llm/provider/${providerId}${query}`
);
await this.handleResponseSoft(
response,