Compare commits

..

3 Commits

Author SHA1 Message Date
Dane Urban
fdccc04b4a . 2026-03-02 22:16:25 -08:00
Dane Urban
b862cc4bd7 Add integration test 2026-03-02 17:51:51 -08:00
Dane Urban
e37382e40c Block deleting default provider 2026-03-02 17:33:31 -08:00
31 changed files with 374 additions and 1179 deletions

View File

@@ -1,37 +0,0 @@
"""add cache_store table
Revision ID: 2664261bfaab
Revises: 4a1e4b1c89d2
Create Date: 2026-02-27 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2664261bfaab"
down_revision = "4a1e4b1c89d2"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
    """Create the ``cache_store`` table and a partial index on its expiry column."""
    table_name = "cache_store"
    op.create_table(
        table_name,
        sa.Column("key", sa.String(), nullable=False),
        sa.Column("value", sa.LargeBinary(), nullable=True),
        sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
        sa.PrimaryKeyConstraint("key"),
    )
    # Partial index: rows with a NULL ``expires_at`` are left out of it.
    only_expiring_rows = sa.text("expires_at IS NOT NULL")
    op.create_index(
        "ix_cache_store_expires",
        table_name,
        ["expires_at"],
        postgresql_where=only_expiring_rows,
    )
def downgrade() -> None:
    """Undo ``upgrade``: drop the expiry index first, then the table."""
    table_name = "cache_store"
    op.drop_index("ix_cache_store_expires", table_name=table_name)
    op.drop_table(table_name)

View File

@@ -1,34 +0,0 @@
"""make scim_user_mapping.external_id nullable
Revision ID: a3b8d9e2f1c4
Revises: 2664261bfaab
Create Date: 2026-03-02
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "a3b8d9e2f1c4"
down_revision = "2664261bfaab"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Allow NULL in ``scim_user_mapping.external_id``."""
    table_name, column_name = "scim_user_mapping", "external_id"
    op.alter_column(table_name, column_name, nullable=True)
def downgrade() -> None:
    """Restore the NOT NULL constraint on ``scim_user_mapping.external_id``.

    This is destructive: rows whose ``external_id`` is NULL are deleted first,
    because they would otherwise violate the constraint being re-applied.
    """
    table_name, column_name = "scim_user_mapping", "external_id"
    op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL")
    op.alter_column(table_name, column_name, nullable=False)

View File

@@ -126,16 +126,12 @@ class ScimDAL(DAL):
def create_user_mapping(
self,
external_id: str | None,
external_id: str,
user_id: UUID,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserMapping:
"""Create a SCIM mapping for a user.
``external_id`` may be ``None`` when the IdP omits it (RFC 7643
allows this). The mapping still marks the user as SCIM-managed.
"""
"""Create a mapping between a SCIM externalId and an Onyx user."""
f = fields or ScimMappingFields()
mapping = ScimUserMapping(
external_id=external_id,
@@ -274,13 +270,8 @@ class ScimDAL(DAL):
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
# Inner-join with ScimUserMapping so only SCIM-managed users appear.
# Pre-existing system accounts (anonymous, admin, etc.) are excluded
# unless they were explicitly linked via SCIM provisioning.
query = (
select(User)
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
query = select(User).where(
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
)
if scim_filter:
@@ -330,37 +321,34 @@ class ScimDAL(DAL):
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> None:
"""Sync the SCIM mapping for a user.
If a mapping already exists, its fields are updated (including
setting ``external_id`` to ``None`` when the IdP omits it).
If no mapping exists and ``new_external_id`` is provided, a new
mapping is created. A mapping is never deleted here — SCIM-managed
users must retain their mapping to remain visible in ``GET /Users``.
"""Create, update, or delete the external ID mapping for a user.
When *fields* is provided, all mapping fields are written
unconditionally — including ``None`` values — so that a caller can
clear a previously-set field (e.g. removing a department).
"""
mapping = self.get_user_mapping_by_user_id(user_id)
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
elif new_external_id:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
else:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
elif mapping:
self.delete_user_mapping(mapping.id)
def _get_user_mappings_batch(
self, user_ids: list[UUID]

View File

@@ -423,63 +423,15 @@ def create_user(
email = user_resource.userName.strip()
# Check for existing user — if they exist but aren't SCIM-managed yet,
# link them to the IdP rather than rejecting with 409.
external_id: str | None = user_resource.externalId
scim_username: str = user_resource.userName.strip()
fields: ScimMappingFields = _fields_from_resource(user_resource)
existing_user = dal.get_user_by_email(email)
if existing_user:
existing_mapping = dal.get_user_mapping_by_user_id(existing_user.id)
if existing_mapping:
return _scim_error_response(409, f"User with email {email} already exists")
# Adopt pre-existing user into SCIM management.
# Reactivating a deactivated user consumes a seat, so enforce the
# seat limit the same way replace_user does.
if user_resource.active and not existing_user.is_active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
personal_name = _scim_name_to_str(user_resource.name)
dal.update_user(
existing_user,
is_active=user_resource.active,
**({"personal_name": personal_name} if personal_name else {}),
)
try:
dal.create_user_mapping(
external_id=external_id,
user_id=existing_user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
return _scim_resource_response(
provider.build_user_resource(
existing_user,
external_id,
scim_username=scim_username,
fields=fields,
),
status_code=201,
)
# Only enforce seat limit for net-new users — adopting a pre-existing
# user doesn't consume a new seat.
# Enforce seat limit
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
# Check for existing user
if dal.get_user_by_email(email):
return _scim_error_response(409, f"User with email {email} already exists")
# Create user with a random password (SCIM users authenticate via IdP)
personal_name = _scim_name_to_str(user_resource.name)
user = User(
@@ -497,21 +449,21 @@ def create_user(
dal.rollback()
return _scim_error_response(409, f"User with email {email} already exists")
# Always create a SCIM mapping so that the user is marked as
# SCIM-managed. externalId may be None (RFC 7643 says it's optional).
try:
# Create SCIM mapping when externalId is provided — this is how the IdP
# correlates this user on subsequent requests. Per RFC 7643, externalId
# is optional and assigned by the provisioning client.
external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
if external_id:
dal.create_user_mapping(
external_id=external_id,
user_id=user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(

View File

@@ -170,10 +170,7 @@ class ScimProvider(ABC):
formatted=user.personal_name or "",
)
if not user.personal_name:
# Derive a reasonable name from the email so that SCIM spec tests
# see non-empty givenName / familyName for every user resource.
local = user.email.split("@")[0] if user.email else ""
return ScimName(givenName=local, familyName="", formatted=local)
return ScimName(givenName="", familyName="", formatted="")
parts = user.personal_name.split(" ", 1)
return ScimName(
givenName=parts[0],

View File

@@ -59,12 +59,6 @@ def _run_auto_llm_update() -> None:
sync_llm_models_from_github(db_session)
def _run_cache_cleanup() -> None:
    """Run one pass of expired-cache-row cleanup.

    The backend module is imported inside the function, as in the other
    ``_run_*`` task functions visible in this file.
    """
    from onyx.cache import postgres_backend

    postgres_backend.cleanup_expired_cache_entries()
def _run_scheduled_eval() -> None:
from onyx.configs.app_configs import BRAINTRUST_API_KEY
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
@@ -106,26 +100,12 @@ def _run_scheduled_eval() -> None:
)
_CACHE_CLEANUP_INTERVAL_SECONDS = 300
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
from onyx.cache.interface import CacheBackendType
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import CACHE_BACKEND
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
tasks: list[_PeriodicTaskDef] = []
if CACHE_BACKEND == CacheBackendType.POSTGRES:
tasks.append(
_PeriodicTaskDef(
name="cache-cleanup",
interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS,
lock_id=PERIODIC_TASK_LOCK_BASE + 2,
run_fn=_run_cache_cleanup,
)
)
if AUTO_LLM_CONFIG_URL:
tasks.append(
_PeriodicTaskDef(

View File

@@ -75,41 +75,31 @@ def _claim_next_processing_file(db_session: Session) -> UUID | None:
return file_id
def _claim_next_deleting_file(
db_session: Session,
exclude_ids: set[UUID] | None = None,
) -> UUID | None:
def _claim_next_deleting_file(db_session: Session) -> UUID | None:
"""Claim the next DELETING file.
No status transition needed — the impl deletes the row on success.
The short-lived FOR UPDATE lock prevents concurrent claims.
*exclude_ids* prevents re-processing the same file if the impl fails.
"""
stmt = (
file_id = db_session.execute(
select(UserFile.id)
.where(UserFile.status == UserFileStatus.DELETING)
.order_by(UserFile.created_at)
.limit(1)
.with_for_update(skip_locked=True)
)
if exclude_ids:
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
file_id = db_session.execute(stmt).scalar_one_or_none()
).scalar_one_or_none()
# Commit to release the row lock promptly.
db_session.commit()
return file_id
def _claim_next_sync_file(
db_session: Session,
exclude_ids: set[UUID] | None = None,
) -> UUID | None:
def _claim_next_sync_file(db_session: Session) -> UUID | None:
"""Claim the next file needing project/persona sync.
No status transition needed — the impl clears the sync flags on
success. The short-lived FOR UPDATE lock prevents concurrent claims.
*exclude_ids* prevents re-processing the same file if the impl fails.
"""
stmt = (
file_id = db_session.execute(
select(UserFile.id)
.where(
sa.and_(
@@ -123,10 +113,7 @@ def _claim_next_sync_file(
.order_by(UserFile.created_at)
.limit(1)
.with_for_update(skip_locked=True)
)
if exclude_ids:
stmt = stmt.where(UserFile.id.notin_(exclude_ids))
file_id = db_session.execute(stmt).scalar_one_or_none()
).scalar_one_or_none()
db_session.commit()
return file_id
@@ -162,13 +149,11 @@ def drain_delete_loop(tenant_id: str) -> None:
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
seen: set[UUID] = set()
while True:
with get_session_with_current_tenant() as session:
file_id = _claim_next_deleting_file(session, exclude_ids=seen)
file_id = _claim_next_deleting_file(session)
if file_id is None:
break
seen.add(file_id)
delete_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
@@ -183,13 +168,11 @@ def drain_project_sync_loop(tenant_id: str) -> None:
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
seen: set[UUID] = set()
while True:
with get_session_with_current_tenant() as session:
file_id = _claim_next_sync_file(session, exclude_ids=seen)
file_id = _claim_next_sync_file(session)
if file_id is None:
break
seen.add(file_id)
project_sync_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,

View File

@@ -12,15 +12,9 @@ def _build_redis_backend(tenant_id: str) -> CacheBackend:
return RedisCacheBackend(redis_pool.get_client(tenant_id))
def _build_postgres_backend(tenant_id: str) -> CacheBackend:
    """Build a ``PostgresCacheBackend`` bound to *tenant_id*.

    The backend module is imported at call time rather than at module import.
    """
    from onyx.cache import postgres_backend

    return postgres_backend.PostgresCacheBackend(tenant_id)
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
CacheBackendType.REDIS: _build_redis_backend,
CacheBackendType.POSTGRES: _build_postgres_backend,
# CacheBackendType.POSTGRES will be added in a follow-up PR.
}

View File

@@ -1,9 +1,6 @@
import abc
from enum import Enum
TTL_KEY_NOT_FOUND = -2
TTL_NO_EXPIRY = -1
class CacheBackendType(str, Enum):
REDIS = "redis"
@@ -29,14 +26,6 @@ class CacheLock(abc.ABC):
def owned(self) -> bool:
raise NotImplementedError
def __enter__(self) -> "CacheLock":
    """Acquire the lock on entering a ``with`` block.

    Raises:
        RuntimeError: If ``acquire()`` reports failure.
    """
    got_lock = self.acquire()
    if got_lock:
        return self
    raise RuntimeError("Failed to acquire lock")
def __exit__(self, *args: object) -> None:
    """Release the lock on leaving the ``with`` block.

    The exception details passed in ``args`` are ignored, and ``None`` is
    returned so any in-flight exception keeps propagating.
    """
    self.release()
class CacheBackend(abc.ABC):
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
@@ -76,11 +65,7 @@ class CacheBackend(abc.ABC):
@abc.abstractmethod
def ttl(self, key: str) -> int:
"""Return remaining TTL in seconds.
Returns ``TTL_NO_EXPIRY`` (-1) if key exists without expiry,
``TTL_KEY_NOT_FOUND`` (-2) if key is missing or expired.
"""
"""Return remaining TTL in seconds. -1 if no expiry, -2 if key missing."""
raise NotImplementedError
# -- distributed lock --------------------------------------------------

View File

@@ -1,323 +0,0 @@
"""PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments.
Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks
for distributed locking, and a polling loop for the BLPOP pattern.
"""
import hashlib
import struct
import time
import uuid
from contextlib import AbstractContextManager
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheLock
from onyx.cache.interface import TTL_KEY_NOT_FOUND
from onyx.cache.interface import TTL_NO_EXPIRY
from onyx.db.models import CacheStore
_LIST_KEY_PREFIX = "_q:"
# ASCII: ':' (0x3A) < ';' (0x3B). Using ';' as the exclusive upper bound makes the
# half-open key range [prefix + ':', prefix + ';') capture every list-item key
# (e.g. _q:mylist:123:uuid) without matching other lists whose names merely
# share a prefix (e.g. _q:mylist2:...).
_LIST_KEY_RANGE_TERMINATOR = ";"
_LIST_ITEM_TTL_SECONDS = 3600
_LOCK_POLL_INTERVAL = 0.1
_BLPOP_POLL_INTERVAL = 0.25
def _list_item_key(key: str) -> str:
    """Build a unique storage key for one item pushed onto list *key*.

    The key has the form ``<prefix><key>:<time_ns>:<uuid hex>``. The nanosecond
    timestamp gives FIFO ordering when keys are sorted; the UUID suffix keeps
    keys distinct when two pushes share the same timestamp.
    """
    pushed_at_ns = time.time_ns()
    unique_suffix = uuid.uuid4().hex
    return f"{_LIST_KEY_PREFIX}{key}:{pushed_at_ns}:{unique_suffix}"
def _to_bytes(value: str | bytes | int | float) -> bytes:
if isinstance(value, bytes):
return value
return str(value).encode()
# ------------------------------------------------------------------
# Lock
# ------------------------------------------------------------------
class PostgresCacheLock(CacheLock):
    """Advisory-lock-based distributed lock.

    Uses ``get_session_with_tenant`` for connection lifecycle. The lock is tied
    to the session's connection; releasing or closing the session frees it.

    NOTE: Unlike Redis locks, advisory locks do not auto-expire after
    ``timeout`` seconds. They are released when ``release()`` is called or
    when the session is closed. ``timeout`` is only used as the default wait
    bound for a blocking ``acquire()``.
    """

    def __init__(self, lock_id: int, timeout: float | None, tenant_id: str) -> None:
        self._lock_id = lock_id
        self._timeout = timeout
        self._tenant_id = tenant_id
        # Session (and its context manager) that owns the advisory lock while held.
        self._session_cm: AbstractContextManager[Session] | None = None
        self._session: Session | None = None
        self._acquired = False

    def acquire(
        self,
        blocking: bool = True,
        blocking_timeout: float | None = None,
    ) -> bool:
        """Try to take the advisory lock.

        Args:
            blocking: When ``False``, make a single attempt and return.
            blocking_timeout: Maximum seconds to keep polling. ``None`` falls
                back to the ``timeout`` given at construction; if that is also
                ``None`` the call polls until the lock is obtained. ``0``
                means "try once".

        Returns:
            ``True`` if the lock was obtained, ``False`` otherwise. Returns
            ``False`` immediately when this object already holds the lock
            (the lock object is not re-entrant).
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        if self._acquired:
            # Opening a second session here would overwrite (and leak) the
            # session that owns the advisory lock, and a later release() would
            # then unlock on the wrong connection.
            return False

        self._session_cm = get_session_with_tenant(tenant_id=self._tenant_id)
        self._session = self._session_cm.__enter__()
        try:
            if not blocking:
                return self._try_lock()

            # Compare against None explicitly so that an explicit 0 is honoured
            # instead of being treated as "no timeout supplied".
            effective_timeout = (
                blocking_timeout if blocking_timeout is not None else self._timeout
            )
            deadline = (
                time.monotonic() + effective_timeout
                if effective_timeout is not None
                else None
            )
            while True:
                if self._try_lock():
                    return True
                if deadline is not None and time.monotonic() >= deadline:
                    return False
                time.sleep(_LOCK_POLL_INTERVAL)
        finally:
            # Keep the session only while it actually owns the lock.
            if not self._acquired:
                self._close_session()

    def release(self) -> None:
        """Unlock and close the owning session; a no-op if the lock is not held."""
        if not self._acquired or self._session is None:
            return
        try:
            self._session.execute(select(func.pg_advisory_unlock(self._lock_id)))
        finally:
            # Always drop local state and the session, even if the unlock fails.
            self._acquired = False
            self._close_session()

    def owned(self) -> bool:
        """Return whether this object currently records the lock as held."""
        return self._acquired

    def _close_session(self) -> None:
        """Exit the session context manager (if any) and clear both references."""
        if self._session_cm is not None:
            try:
                self._session_cm.__exit__(None, None, None)
            finally:
                self._session_cm = None
                self._session = None

    def _try_lock(self) -> bool:
        """Make one non-blocking ``pg_try_advisory_lock`` attempt on the session."""
        assert self._session is not None
        result = self._session.execute(
            select(func.pg_try_advisory_lock(self._lock_id))
        ).scalar()
        if result:
            self._acquired = True
            return True
        return False
# ------------------------------------------------------------------
# Backend
# ------------------------------------------------------------------
class PostgresCacheBackend(CacheBackend):
    """``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL.

    Each operation opens and closes its own database session so the backend
    is safe to share across threads. Tenant isolation is handled by
    SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``).

    A row whose ``expires_at`` is in the past is treated as missing by the
    read paths even before the periodic cleanup physically deletes it.
    """

    def __init__(self, tenant_id: str) -> None:
        self._tenant_id = tenant_id

    # -- basic key/value ---------------------------------------------------

    def get(self, key: str) -> bytes | None:
        """Return the stored bytes for *key*, or ``None`` if missing/expired."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        stmt = select(CacheStore.value).where(
            CacheStore.key == key,
            or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()),
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            value = session.execute(stmt).scalar_one_or_none()
        if value is None:
            return None
        return bytes(value)

    def set(
        self,
        key: str,
        value: str | bytes | int | float,
        ex: int | None = None,
    ) -> None:
        """Upsert *key* with *value*; ``ex`` is an optional TTL in seconds.

        Overwriting an existing key also replaces its expiry (``ex=None``
        clears it).
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        value_bytes = _to_bytes(value)
        expires_at = (
            datetime.now(timezone.utc) + timedelta(seconds=ex)
            if ex is not None
            else None
        )
        stmt = (
            pg_insert(CacheStore)
            .values(key=key, value=value_bytes, expires_at=expires_at)
            .on_conflict_do_update(
                index_elements=[CacheStore.key],
                set_={"value": value_bytes, "expires_at": expires_at},
            )
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            session.execute(stmt)
            session.commit()

    def delete(self, key: str) -> None:
        """Delete the row for *key* if one exists."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            session.execute(delete(CacheStore).where(CacheStore.key == key))
            session.commit()

    def exists(self, key: str) -> bool:
        """Return ``True`` if *key* has a row that has not expired."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        stmt = (
            select(CacheStore.key)
            .where(
                CacheStore.key == key,
                or_(
                    CacheStore.expires_at.is_(None),
                    CacheStore.expires_at > func.now(),
                ),
            )
            .limit(1)
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            return session.execute(stmt).first() is not None

    # -- TTL ---------------------------------------------------------------

    def expire(self, key: str, seconds: int) -> None:
        """Set the TTL of a live key to *seconds* from now.

        Rows that have already expired but not yet been garbage-collected are
        left untouched, so a key that ``get``/``exists`` report as missing is
        never brought back to life with its stale value.
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds)
        stmt = (
            update(CacheStore)
            .where(
                CacheStore.key == key,
                # Same liveness predicate as get()/exists().
                or_(
                    CacheStore.expires_at.is_(None),
                    CacheStore.expires_at > func.now(),
                ),
            )
            .values(expires_at=new_exp)
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            session.execute(stmt)
            session.commit()

    def ttl(self, key: str) -> int:
        """Return the remaining TTL of *key* in whole seconds.

        Returns ``TTL_KEY_NOT_FOUND`` if the key is missing or expired and
        ``TTL_NO_EXPIRY`` if it exists without an expiry. The remaining time
        is computed against the application clock, not the database clock.
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        stmt = select(CacheStore.expires_at).where(CacheStore.key == key)
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            result = session.execute(stmt).first()
        if result is None:
            return TTL_KEY_NOT_FOUND
        expires_at: datetime | None = result[0]
        if expires_at is None:
            return TTL_NO_EXPIRY
        remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
        if remaining <= 0:
            return TTL_KEY_NOT_FOUND
        return int(remaining)

    # -- distributed lock --------------------------------------------------

    def lock(self, name: str, timeout: float | None = None) -> CacheLock:
        """Return an unacquired advisory lock for *name* scoped to this tenant."""
        return PostgresCacheLock(
            self._lock_id_for(name), timeout, tenant_id=self._tenant_id
        )

    # -- blocking list (MCP OAuth BLPOP pattern) ---------------------------

    def rpush(self, key: str, value: str | bytes) -> None:
        """Append *value* to list *key* as its own row with a fixed TTL."""
        self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS)

    def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
        """Pop the oldest item from the first non-empty list in *keys*.

        Polls every ``_BLPOP_POLL_INTERVAL`` seconds until an item is found or
        *timeout* seconds elapse.

        Returns:
            ``(list_name_bytes, value_bytes)`` or ``None`` on timeout.

        Raises:
            ValueError: If ``timeout <= 0`` (unbounded blocking is refused).
        """
        if timeout <= 0:
            raise ValueError(
                "PostgresCacheBackend.blpop requires timeout > 0. "
                "timeout=0 would block the calling thread indefinitely "
                "with no way to interrupt short of process termination."
            )
        from onyx.db.engine.sql_engine import get_session_with_tenant

        deadline = time.monotonic() + timeout
        while True:
            for key in keys:
                # NOTE(review): these range bounds assume the database orders
                # strings byte-wise for ':' and ';' — confirm the collation.
                lower = f"{_LIST_KEY_PREFIX}{key}:"
                upper = f"{_LIST_KEY_PREFIX}{key}{_LIST_KEY_RANGE_TERMINATOR}"
                stmt = (
                    select(CacheStore)
                    .where(
                        CacheStore.key >= lower,
                        CacheStore.key < upper,
                        or_(
                            CacheStore.expires_at.is_(None),
                            CacheStore.expires_at > func.now(),
                        ),
                    )
                    .order_by(CacheStore.key)
                    .limit(1)
                    # Row lock + SKIP LOCKED so concurrent poppers never take
                    # the same item.
                    .with_for_update(skip_locked=True)
                )
                with get_session_with_tenant(tenant_id=self._tenant_id) as session:
                    row = session.execute(stmt).scalars().first()
                    if row is not None:
                        value = bytes(row.value) if row.value else b""
                        session.delete(row)
                        session.commit()
                        return (key.encode(), value)
            if time.monotonic() >= deadline:
                return None
            time.sleep(_BLPOP_POLL_INTERVAL)

    # -- helpers -----------------------------------------------------------

    def _lock_id_for(self, name: str) -> int:
        """Map *name* to a 64-bit signed int for ``pg_advisory_lock``.

        The id is the first 8 bytes of ``md5("<tenant_id>:<name>")`` read as a
        little-endian signed integer. The byte order is fixed explicitly so
        every host derives the same id regardless of its native endianness
        (the result is unchanged on little-endian machines).
        """
        h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
        return struct.unpack("<q", h[:8])[0]
# ------------------------------------------------------------------
# Periodic cleanup
# ------------------------------------------------------------------
def cleanup_expired_cache_entries() -> None:
    """Physically remove cache rows whose ``expires_at`` has passed.

    Rows without an expiry are never touched. Intended to be run by the
    periodic poller (every 5 minutes). Operates on the current tenant's
    session.
    """
    from onyx.db.engine.sql_engine import get_session_with_current_tenant

    purge_expired = delete(CacheStore).where(
        CacheStore.expires_at.is_not(None),
        CacheStore.expires_at < func.now(),
    )
    with get_session_with_current_tenant() as db:
        db.execute(purge_expired)
        db.commit()

View File

@@ -4926,9 +4926,7 @@ class ScimUserMapping(Base):
__tablename__ = "scim_user_mapping"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
external_id: Mapped[str | None] = mapped_column(
String, unique=True, index=True, nullable=True
)
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
@@ -4985,25 +4983,3 @@ class CodeInterpreterServer(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
class CacheStore(Base):
    """Backing table for ``PostgresCacheBackend``.

    Stands in for Redis (simple KV caching, locks, list operations) when
    ``CACHE_BACKEND=postgres`` is configured (NO_VECTOR_DB deployments).

    Kept deliberately distinct from ``KVStore``:
    - Values are raw bytes (LargeBinary) rather than JSONB, mirroring Redis.
    - Rows carry an ``expires_at`` TTL and are garbage-collected periodically.
    - Contents are ephemeral (tokens, stop signals, lock state), not
      persistent application config, so aggressive cleanup is acceptable.
    """

    __tablename__ = "cache_store"

    # Cache key; doubles as the primary key.
    key: Mapped[str] = mapped_column(String, primary_key=True)
    # Raw payload bytes; NULL is allowed.
    value: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    # Timezone-aware expiry timestamp; NULL is allowed.
    expires_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
View File

@@ -67,18 +67,6 @@ Status checked against LiteLLM v1.81.6-nightly (2026-02-02):
STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set
usage as a dict with chat completion format instead of keeping it as
ResponseAPIUsage. Our patch creates a deep copy before modification.
7. Responses API metadata=None TypeError (_patch_responses_metadata_none):
- LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {})
to check for router calls, but when metadata is explicitly None (key exists with
value None), the default {} is not used
- This causes "argument of type 'NoneType' is not iterable" TypeError which swallows
the real exception (e.g. AuthenticationError for wrong API key)
- Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is
not iterable
STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard
against metadata being explicitly None. Triggered when Responses API bridge
passes **litellm_params containing metadata=None.
"""
import time
@@ -737,44 +725,6 @@ def _patch_logging_assembled_streaming_response() -> None:
LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response # type: ignore[method-assign]
def _patch_responses_metadata_none() -> None:
    """
    Wrap ``litellm.responses`` so a ``metadata`` of ``None`` becomes ``{}``.

    Background: LiteLLM's @client decorator wrapper in utils.py (line 1721) runs
        _is_litellm_router_call = "model_group" in kwargs.get("metadata", {})
    If ``metadata`` is present in kwargs with the value None, ``.get`` returns
    that None (the default only applies to a missing key), which raises:
        TypeError: argument of type 'NoneType' is not iterable
    The TypeError hides the real failure (e.g. AuthenticationError) and shows up as:
        APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable

    The trigger is the Responses API bridge calling litellm.responses() with
    **litellm_params, which may include metadata=None.

    The wrapper is installed at most once: it is tagged with ``_metadata_patched``
    and a second call to this function returns early.

    STATUS: STILL NEEDED - litellm/utils.py wrapper function uses kwargs.get("metadata", {})
            which does not guard against metadata being explicitly None. Same pattern exists
            on line 1407 for async path.
    """
    import litellm as _litellm
    from functools import wraps

    unpatched = _litellm.responses
    already_patched = getattr(unpatched, "_metadata_patched", False)
    if already_patched:
        return

    @wraps(unpatched)
    def _patched_responses(*args: Any, **kwargs: Any) -> Any:
        # Covers both an explicit None and a missing key.
        metadata = kwargs.get("metadata")
        if metadata is None:
            kwargs["metadata"] = {}
        return unpatched(*args, **kwargs)

    _patched_responses._metadata_patched = True  # type: ignore[attr-defined]
    _litellm.responses = _patched_responses
def apply_monkey_patches() -> None:
"""
Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -786,7 +736,6 @@ def apply_monkey_patches() -> None:
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
- Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths
- Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response
- Patching litellm.responses to fix metadata=None causing TypeError in error handling
"""
_patch_ollama_chunk_parser()
_patch_openai_responses_parallel_tool_calls()
@@ -794,4 +743,3 @@ def apply_monkey_patches() -> None:
_patch_azure_responses_should_fake_stream()
_patch_responses_api_usage_format()
_patch_logging_assembled_streaming_response()
_patch_responses_metadata_none()

View File

@@ -479,10 +479,20 @@ def put_llm_provider(
@admin_router.delete("/provider/{provider_id}")
def delete_llm_provider(
    provider_id: int,
    force: bool = Query(False),
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Delete an LLM provider (admin only).

    Unless ``force`` is set, the request is rejected when the current default
    LLM model belongs to this provider.

    Raises:
        HTTPException: 400 if the provider owns the default model and
            ``force`` is false; 404 if a ``ValueError`` is raised while
            looking up or removing the provider.
    """
    try:
        # With force=True the default-model lookup is skipped entirely.
        default_model = None if force else fetch_default_llm_model(db_session)
        if default_model and default_model.llm_provider_id == provider_id:
            raise HTTPException(
                status_code=400,
                detail="Cannot delete the default LLM provider",
            )
        remove_llm_provider(db_session, provider_id)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
View File

@@ -1,57 +0,0 @@
"""Fixtures for cache backend tests.
Requires a running PostgreSQL instance (and Redis for parity tests).
Run with::
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/cache/
"""
from collections.abc import Generator
import pytest
from onyx.cache.interface import CacheBackend
from onyx.cache.postgres_backend import PostgresCacheBackend
from onyx.cache.redis_backend import RedisCacheBackend
from onyx.db.engine.sql_engine import SqlEngine
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from tests.external_dependency_unit.constants import TEST_TENANT_ID
@pytest.fixture(scope="session", autouse=True)
def _init_db() -> Generator[None, None, None]:
    """Initialise the SQL engine once per test session.

    Assumes Postgres already has migrations applied (e.g. via docker compose).
    No teardown is performed after the ``yield``.
    """
    SqlEngine.init_engine(max_overflow=2, pool_size=5)
    yield None
@pytest.fixture(autouse=True)
def _tenant_context() -> Generator[None, None, None]:
    """Set the current-tenant context variable to the test tenant for each test.

    The previous value is restored afterwards, even if the test fails.
    """
    reset_token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
    try:
        yield None
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(reset_token)
@pytest.fixture
def pg_cache() -> PostgresCacheBackend:
    """Provide a PostgreSQL cache backend bound to the test tenant."""
    backend = PostgresCacheBackend(TEST_TENANT_ID)
    return backend
@pytest.fixture
def redis_cache() -> RedisCacheBackend:
    """Provide a Redis cache backend using the test tenant's pooled client."""
    from onyx.redis.redis_pool import redis_pool

    tenant_client = redis_pool.get_client(TEST_TENANT_ID)
    return RedisCacheBackend(tenant_client)
@pytest.fixture(params=["postgres", "redis"], ids=["postgres", "redis"])
def cache(
    request: pytest.FixtureRequest,
    pg_cache: PostgresCacheBackend,
    redis_cache: RedisCacheBackend,
) -> CacheBackend:
    """Parametrised backend fixture: each dependent test runs once per backend.

    ``"postgres"`` selects the PostgreSQL backend; any other parameter value
    selects Redis.
    """
    use_postgres = request.param == "postgres"
    return pg_cache if use_postgres else redis_cache

View File

@@ -1,100 +0,0 @@
"""Parameterized tests that run the same CacheBackend operations against
both Redis and PostgreSQL, asserting identical return values.
Each test runs twice (once per backend) via the ``cache`` fixture defined
in conftest.py.
"""
import time
from uuid import uuid4
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import TTL_KEY_NOT_FOUND
from onyx.cache.interface import TTL_NO_EXPIRY
def _key() -> str:
    """Return a fresh cache key: ``parity_`` plus 12 random hex characters."""
    random_suffix = uuid4().hex[:12]
    return "parity_" + random_suffix
class TestKVParity:
    """Basic key/value operations; each test runs once per cache backend."""

    def test_get_missing(self, cache: CacheBackend) -> None:
        """A key that was never set reads back as ``None``."""
        fetched = cache.get(_key())
        assert fetched is None

    def test_get_set(self, cache: CacheBackend) -> None:
        """A stored bytes value is returned unchanged."""
        name = _key()
        cache.set(name, b"value")
        assert cache.get(name) == b"value"

    def test_overwrite(self, cache: CacheBackend) -> None:
        """A second ``set`` on the same key replaces the first value."""
        name = _key()
        cache.set(name, b"a")
        cache.set(name, b"b")
        assert cache.get(name) == b"b"

    def test_set_string(self, cache: CacheBackend) -> None:
        """A ``str`` value is read back as its encoded bytes."""
        name = _key()
        cache.set(name, "hello")
        assert cache.get(name) == b"hello"

    def test_set_int(self, cache: CacheBackend) -> None:
        """An ``int`` value is read back as its decimal text in bytes."""
        name = _key()
        cache.set(name, 42)
        assert cache.get(name) == b"42"

    def test_delete(self, cache: CacheBackend) -> None:
        """A deleted key reads back as ``None``."""
        name = _key()
        cache.set(name, b"x")
        cache.delete(name)
        assert cache.get(name) is None

    def test_exists(self, cache: CacheBackend) -> None:
        """``exists`` is falsy before the key is set and truthy afterwards."""
        name = _key()
        assert not cache.exists(name)
        cache.set(name, b"x")
        assert cache.exists(name)
class TestTTLParity:
    """TTL reporting and expiry; each test runs once per cache backend."""

    def test_ttl_missing(self, cache: CacheBackend) -> None:
        """An unknown key reports ``TTL_KEY_NOT_FOUND``."""
        reported = cache.ttl(_key())
        assert reported == TTL_KEY_NOT_FOUND

    def test_ttl_no_expiry(self, cache: CacheBackend) -> None:
        """A key stored without ``ex`` reports ``TTL_NO_EXPIRY``."""
        name = _key()
        cache.set(name, b"x")
        assert cache.ttl(name) == TTL_NO_EXPIRY

    def test_ttl_remaining(self, cache: CacheBackend) -> None:
        """Right after ``set(..., ex=10)`` the TTL is between 8 and 10 seconds."""
        name = _key()
        cache.set(name, b"x", ex=10)
        seconds_left = cache.ttl(name)
        assert seconds_left >= 8 and seconds_left <= 10

    def test_set_with_ttl_expires(self, cache: CacheBackend) -> None:
        """A key with a 1-second TTL is readable at once and gone 1.5 s later."""
        name = _key()
        cache.set(name, b"x", ex=1)
        assert cache.get(name) == b"x"
        time.sleep(1.5)
        assert cache.get(name) is None
class TestLockParity:
    """Lock acquire/release exercised against the backend given by ``cache``."""

    def test_acquire_release(self, cache: CacheBackend) -> None:
        """A non-blocking acquire succeeds, is owned, and is not owned after release."""
        lock = cache.lock(f"parity_lock_{uuid4().hex[:8]}")
        assert lock.acquire(blocking=False)
        try:
            assert lock.owned()
        finally:
            # Release even if the ownership assertion fails, so a failing
            # test does not leave the lock held in the backend.
            lock.release()
        assert not lock.owned()
class TestListParity:
    """List push/pop operations exercised against the backend given by ``cache``."""

    def test_rpush_blpop(self, cache: CacheBackend) -> None:
        """An item pushed with ``rpush`` comes back as the second element of ``blpop``."""
        queue = f"parity_list_{uuid4().hex[:8]}"
        cache.rpush(queue, b"item")
        popped = cache.blpop([queue], timeout=1)
        assert popped is not None
        assert popped[1] == b"item"

    def test_blpop_timeout(self, cache: CacheBackend) -> None:
        """``blpop`` on a list nothing was pushed to returns ``None`` after the timeout."""
        empty_queue = f"parity_empty_{uuid4().hex[:8]}"
        popped = cache.blpop([empty_queue], timeout=1)
        assert popped is None

View File

@@ -1,229 +0,0 @@
"""Tests for PostgresCacheBackend against real PostgreSQL.
Covers every method on the backend: KV CRUD, TTL behaviour, advisory
locks (acquire / release / contention), list operations (rpush / blpop),
and the periodic cleanup function.
"""
import time
from uuid import uuid4
from sqlalchemy import select
from onyx.cache.interface import TTL_KEY_NOT_FOUND
from onyx.cache.interface import TTL_NO_EXPIRY
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
from onyx.cache.postgres_backend import PostgresCacheBackend
from onyx.db.models import CacheStore
def _key() -> str:
return f"test_{uuid4().hex[:12]}"
# ------------------------------------------------------------------
# Basic KV
# ------------------------------------------------------------------
class TestKV:
    """Basic get/set/delete/exists behaviour of ``PostgresCacheBackend``."""

    def test_get_set(self, pg_cache: PostgresCacheBackend) -> None:
        """A bytes value round-trips unchanged."""
        key = _key()
        pg_cache.set(key, b"hello")
        stored = pg_cache.get(key)
        assert stored == b"hello"

    def test_get_missing(self, pg_cache: PostgresCacheBackend) -> None:
        """A key that was never written reads back as ``None``."""
        never_written = _key()
        assert pg_cache.get(never_written) is None

    def test_set_overwrite(self, pg_cache: PostgresCacheBackend) -> None:
        """A second ``set`` on the same key replaces the first value."""
        key = _key()
        for payload in (b"first", b"second"):
            pg_cache.set(key, payload)
        assert pg_cache.get(key) == b"second"

    def test_set_string_value(self, pg_cache: PostgresCacheBackend) -> None:
        """A ``str`` value is read back as bytes."""
        key = _key()
        pg_cache.set(key, "string_val")
        stored = pg_cache.get(key)
        assert stored == b"string_val"

    def test_set_int_value(self, pg_cache: PostgresCacheBackend) -> None:
        """An ``int`` value is read back as its decimal text in bytes."""
        key = _key()
        pg_cache.set(key, 42)
        stored = pg_cache.get(key)
        assert stored == b"42"

    def test_delete(self, pg_cache: PostgresCacheBackend) -> None:
        """A deleted key reads back as ``None``."""
        key = _key()
        pg_cache.set(key, b"to_delete")
        pg_cache.delete(key)
        assert pg_cache.get(key) is None

    def test_delete_missing_is_noop(self, pg_cache: PostgresCacheBackend) -> None:
        """Deleting an unknown key must simply not raise."""
        never_written = _key()
        pg_cache.delete(never_written)

    def test_exists(self, pg_cache: PostgresCacheBackend) -> None:
        """``exists`` is falsy before the first write and truthy after it."""
        key = _key()
        assert not pg_cache.exists(key)
        pg_cache.set(key, b"x")
        assert pg_cache.exists(key)
# ------------------------------------------------------------------
# TTL
# ------------------------------------------------------------------
class TestTTL:
    """Expiry and TTL reporting behaviour of ``PostgresCacheBackend``."""

    def test_set_with_ttl_expires(self, pg_cache: PostgresCacheBackend) -> None:
        """A key with a 1-second TTL is readable at once and gone after 1.5s."""
        key = _key()
        pg_cache.set(key, b"ephemeral", ex=1)
        before_expiry = pg_cache.get(key)
        assert before_expiry == b"ephemeral"
        time.sleep(1.5)
        assert pg_cache.get(key) is None

    def test_ttl_no_expiry(self, pg_cache: PostgresCacheBackend) -> None:
        """A key set without ``ex`` reports ``TTL_NO_EXPIRY``."""
        key = _key()
        pg_cache.set(key, b"forever")
        assert pg_cache.ttl(key) == TTL_NO_EXPIRY

    def test_ttl_missing_key(self, pg_cache: PostgresCacheBackend) -> None:
        """``ttl`` on an unknown key returns ``TTL_KEY_NOT_FOUND``."""
        never_written = _key()
        assert pg_cache.ttl(never_written) == TTL_KEY_NOT_FOUND

    def test_ttl_remaining(self, pg_cache: PostgresCacheBackend) -> None:
        """A 10-second TTL is reported within a small tolerance window."""
        key = _key()
        pg_cache.set(key, b"x", ex=10)
        seconds_left = pg_cache.ttl(key)
        # Lower bound of 8 leaves slack for elapsed time and rounding.
        assert 8 <= seconds_left <= 10

    def test_ttl_expired_key(self, pg_cache: PostgresCacheBackend) -> None:
        """Once the TTL has elapsed, ``ttl`` reports the key as not found."""
        key = _key()
        pg_cache.set(key, b"x", ex=1)
        time.sleep(1.5)
        assert pg_cache.ttl(key) == TTL_KEY_NOT_FOUND

    def test_expire_adds_ttl(self, pg_cache: PostgresCacheBackend) -> None:
        """``expire`` attaches a TTL to a key that previously had none."""
        key = _key()
        pg_cache.set(key, b"x")
        assert pg_cache.ttl(key) == TTL_NO_EXPIRY
        pg_cache.expire(key, 10)
        seconds_left = pg_cache.ttl(key)
        assert 8 <= seconds_left <= 10

    def test_exists_respects_ttl(self, pg_cache: PostgresCacheBackend) -> None:
        """``exists`` turns falsy once the key's TTL has elapsed."""
        key = _key()
        pg_cache.set(key, b"x", ex=1)
        assert pg_cache.exists(key)
        time.sleep(1.5)
        assert not pg_cache.exists(key)
# ------------------------------------------------------------------
# Locks
# ------------------------------------------------------------------
class TestLock:
    """Acquire/release, contention, context-manager and timeout behaviour of locks.

    Every explicitly acquired lock is released in a ``finally`` block so that
    a failing assertion does not leave the lock held in the backend.
    """

    def test_acquire_release(self, pg_cache: PostgresCacheBackend) -> None:
        """A non-blocking acquire succeeds, is owned, and is not owned after release."""
        lock = pg_cache.lock(f"lock_{uuid4().hex[:8]}")
        assert lock.acquire(blocking=False)
        try:
            assert lock.owned()
        finally:
            lock.release()
        assert not lock.owned()

    def test_contention(self, pg_cache: PostgresCacheBackend) -> None:
        """A second lock of the same name cannot be taken until the first is released."""
        name = f"contention_{uuid4().hex[:8]}"
        lock1 = pg_cache.lock(name)
        lock2 = pg_cache.lock(name)
        assert lock1.acquire(blocking=False)
        try:
            assert not lock2.acquire(blocking=False)
        finally:
            lock1.release()
        assert lock2.acquire(blocking=False)
        lock2.release()

    def test_context_manager(self, pg_cache: PostgresCacheBackend) -> None:
        """The lock is owned inside the ``with`` body and not owned after it."""
        with pg_cache.lock(f"ctx_{uuid4().hex[:8]}") as lock:
            assert lock.owned()
        assert not lock.owned()

    def test_blocking_timeout(self, pg_cache: PostgresCacheBackend) -> None:
        """A blocking acquire on a held lock gives up after roughly its timeout."""
        name = f"timeout_{uuid4().hex[:8]}"
        holder = pg_cache.lock(name)
        # The test only means something if the holder really owns the lock,
        # so make that precondition explicit.
        assert holder.acquire(blocking=False)
        try:
            waiter = pg_cache.lock(name, timeout=0.3)
            start = time.monotonic()
            assert not waiter.acquire(blocking=True, blocking_timeout=0.3)
            elapsed = time.monotonic() - start
            # Slightly below the 0.3s timeout to tolerate timer granularity.
            assert elapsed >= 0.25
        finally:
            holder.release()
# ------------------------------------------------------------------
# List (rpush / blpop)
# ------------------------------------------------------------------
class TestList:
    """``rpush`` / ``blpop`` behaviour of ``PostgresCacheBackend``."""

    def test_rpush_blpop(self, pg_cache: PostgresCacheBackend) -> None:
        """``blpop`` returns an ``(encoded key, value)`` pair for a pushed item."""
        queue = f"list_{uuid4().hex[:8]}"
        pg_cache.rpush(queue, b"item1")
        popped = pg_cache.blpop([queue], timeout=1)
        assert popped is not None
        assert popped == (queue.encode(), b"item1")

    def test_blpop_timeout(self, pg_cache: PostgresCacheBackend) -> None:
        """``blpop`` on a list nothing was pushed to returns ``None`` after the timeout."""
        empty_queue = f"empty_{uuid4().hex[:8]}"
        popped = pg_cache.blpop([empty_queue], timeout=1)
        assert popped is None

    def test_fifo_order(self, pg_cache: PostgresCacheBackend) -> None:
        """Items pop in the order they were pushed."""
        queue = f"fifo_{uuid4().hex[:8]}"
        pg_cache.rpush(queue, b"first")
        # Short pause between the two pushes, kept from the original test.
        time.sleep(0.01)
        pg_cache.rpush(queue, b"second")
        head = pg_cache.blpop([queue], timeout=1)
        tail = pg_cache.blpop([queue], timeout=1)
        assert head is not None and head[1] == b"first"
        assert tail is not None and tail[1] == b"second"

    def test_multiple_keys(self, pg_cache: PostgresCacheBackend) -> None:
        """With several keys given, ``blpop`` returns from the one that has data."""
        empty_queue = f"mk1_{uuid4().hex[:8]}"
        filled_queue = f"mk2_{uuid4().hex[:8]}"
        pg_cache.rpush(filled_queue, b"from_k2")
        popped = pg_cache.blpop([empty_queue, filled_queue], timeout=1)
        assert popped is not None
        assert popped == (filled_queue.encode(), b"from_k2")
# ------------------------------------------------------------------
# Cleanup
# ------------------------------------------------------------------
class TestCleanup:
    """Behaviour of ``cleanup_expired_cache_entries`` on expired and live rows."""

    def test_removes_expired_rows(self, pg_cache: PostgresCacheBackend) -> None:
        """An expired entry is physically removed from the ``CacheStore`` table."""
        from onyx.db.engine.sql_engine import get_session_with_current_tenant

        key = _key()
        pg_cache.set(key, b"stale", ex=1)
        time.sleep(1.5)
        cleanup_expired_cache_entries()

        # Query the table directly instead of going through the backend, so
        # the check is about the stored row rather than backend reads.
        lookup = select(CacheStore.key).where(CacheStore.key == key)
        with get_session_with_current_tenant() as session:
            leftover = session.execute(lookup).first()
            assert leftover is None, "expired row should be physically deleted"

    def test_preserves_unexpired_rows(self, pg_cache: PostgresCacheBackend) -> None:
        """An entry whose TTL has not elapsed survives cleanup."""
        key = _key()
        pg_cache.set(key, b"fresh", ex=300)
        cleanup_expired_cache_entries()
        survivor = pg_cache.get(key)
        assert survivor == b"fresh"

    def test_preserves_no_ttl_rows(self, pg_cache: PostgresCacheBackend) -> None:
        """An entry stored without a TTL survives cleanup."""
        key = _key()
        pg_cache.set(key, b"permanent")
        cleanup_expired_cache_entries()
        survivor = pg_cache.get(key)
        assert survivor == b"permanent"

View File

@@ -386,6 +386,261 @@ def test_delete_llm_provider(
assert provider_data is None
def test_delete_default_llm_provider_rejected(reset: None) -> None:  # noqa: ARG001
    """A plain DELETE on the default LLM provider is refused with HTTP 400."""
    admin_user = UserManager.create(name="admin_user")

    # Register the provider that will become the default.
    creation_payload = {
        "name": "test-provider-default-delete",
        "provider": LlmProviderNames.OPENAI,
        "api_key": "sk-000000000000000000000000000000000000000000000000",
        "model_configurations": [
            ModelConfigurationUpsertRequest(
                name="gpt-4o-mini", is_visible=True
            ).model_dump()
        ],
        "is_public": True,
        "groups": [],
    }
    create_response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        headers=admin_user.headers,
        json=creation_payload,
    )
    assert create_response.status_code == 200
    provider = create_response.json()

    # Promote it to default.
    default_payload = {
        "provider_id": provider["id"],
        "model_name": "gpt-4o-mini",
    }
    promote_response = requests.post(
        f"{API_SERVER_URL}/admin/llm/default",
        headers=admin_user.headers,
        json=default_payload,
    )
    assert promote_response.status_code == 200

    # Deleting the default provider must be rejected.
    rejection = requests.delete(
        f"{API_SERVER_URL}/admin/llm/provider/{provider['id']}",
        headers=admin_user.headers,
    )
    assert rejection.status_code == 400
    assert "Cannot delete the default LLM provider" in rejection.json()["detail"]

    # The provider must still be retrievable afterwards.
    still_there = _get_provider_by_id(admin_user, provider["id"])
    assert still_there is not None
def test_delete_non_default_llm_provider_with_default_set(
    reset: None,  # noqa: ARG001
) -> None:
    """A non-default provider can be deleted while another provider is the default."""
    admin_user = UserManager.create(name="admin_user")
    creation_url = f"{API_SERVER_URL}/admin/llm/provider?is_creation=true"

    # First provider: the one that will be marked as default.
    default_payload = {
        "name": "default-provider",
        "provider": LlmProviderNames.OPENAI,
        "api_key": "sk-000000000000000000000000000000000000000000000000",
        "model_configurations": [
            ModelConfigurationUpsertRequest(
                name="gpt-4o-mini", is_visible=True
            ).model_dump()
        ],
        "is_public": True,
        "groups": [],
    }
    default_creation = requests.put(
        creation_url,
        headers=admin_user.headers,
        json=default_payload,
    )
    assert default_creation.status_code == 200
    default_provider = default_creation.json()

    # Second provider: the one that will be deleted.
    other_payload = {
        "name": "other-provider",
        "provider": LlmProviderNames.OPENAI,
        "api_key": "sk-000000000000000000000000000000000000000000000000",
        "model_configurations": [
            ModelConfigurationUpsertRequest(
                name="gpt-4o", is_visible=True
            ).model_dump()
        ],
        "is_public": True,
        "groups": [],
    }
    other_creation = requests.put(
        creation_url,
        headers=admin_user.headers,
        json=other_payload,
    )
    assert other_creation.status_code == 200
    other_provider = other_creation.json()

    # Mark the first provider as the default.
    promote_response = requests.post(
        f"{API_SERVER_URL}/admin/llm/default",
        headers=admin_user.headers,
        json={
            "provider_id": default_provider["id"],
            "model_name": "gpt-4o-mini",
        },
    )
    assert promote_response.status_code == 200

    # Deleting the provider that is not the default must succeed.
    deletion = requests.delete(
        f"{API_SERVER_URL}/admin/llm/provider/{other_provider['id']}",
        headers=admin_user.headers,
    )
    assert deletion.status_code == 200

    # The deleted provider is gone; the default one remains.
    deleted_lookup = _get_provider_by_id(admin_user, other_provider["id"])
    assert deleted_lookup is None
    default_lookup = _get_provider_by_id(admin_user, default_provider["id"])
    assert default_lookup is not None
def test_force_delete_default_llm_provider(
    reset: None,  # noqa: ARG001
) -> None:
    """The default provider is refused on plain DELETE but removed with ``force=true``."""
    admin_user = UserManager.create(name="admin_user")

    # Register the provider that will become the default.
    creation_payload = {
        "name": "test-provider-force-delete",
        "provider": LlmProviderNames.OPENAI,
        "api_key": "sk-000000000000000000000000000000000000000000000000",
        "model_configurations": [
            ModelConfigurationUpsertRequest(
                name="gpt-4o-mini", is_visible=True
            ).model_dump()
        ],
        "is_public": True,
        "groups": [],
    }
    create_response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        headers=admin_user.headers,
        json=creation_payload,
    )
    assert create_response.status_code == 200
    provider = create_response.json()

    # Promote it to default.
    promote_response = requests.post(
        f"{API_SERVER_URL}/admin/llm/default",
        headers=admin_user.headers,
        json={
            "provider_id": provider["id"],
            "model_name": "gpt-4o-mini",
        },
    )
    assert promote_response.status_code == 200

    # Without the force flag the deletion is rejected.
    plain_delete = requests.delete(
        f"{API_SERVER_URL}/admin/llm/provider/{provider['id']}",
        headers=admin_user.headers,
    )
    assert plain_delete.status_code == 400

    # With force=true the same deletion goes through.
    forced_delete = requests.delete(
        f"{API_SERVER_URL}/admin/llm/provider/{provider['id']}?force=true",
        headers=admin_user.headers,
    )
    assert forced_delete.status_code == 200

    # The provider can no longer be looked up.
    lookup = _get_provider_by_id(admin_user, provider["id"])
    assert lookup is None
def test_delete_default_vision_provider_clears_vision_default(
    reset: None,  # noqa: ARG001
) -> None:
    """Deleting the default vision provider succeeds and leaves no vision default.

    The text default, held by a different provider, must be unaffected.
    """
    admin_user = UserManager.create(name="admin_user")
    creation_url = f"{API_SERVER_URL}/admin/llm/provider?is_creation=true"

    # Text provider, set as the default text provider.
    text_payload = {
        "name": "text-provider",
        "provider": LlmProviderNames.OPENAI,
        "api_key": "sk-000000000000000000000000000000000000000000000001",
        "model_configurations": [
            ModelConfigurationUpsertRequest(
                name="gpt-4o-mini", is_visible=True
            ).model_dump()
        ],
        "is_public": True,
        "groups": [],
    }
    text_creation = requests.put(
        creation_url,
        headers=admin_user.headers,
        json=text_payload,
    )
    assert text_creation.status_code == 200
    text_provider = text_creation.json()
    _set_default_provider(admin_user, text_provider["id"], "gpt-4o-mini")

    # Vision provider, set as the default vision provider.
    vision_payload = {
        "name": "vision-provider",
        "provider": LlmProviderNames.OPENAI,
        "api_key": "sk-000000000000000000000000000000000000000000000002",
        "model_configurations": [
            ModelConfigurationUpsertRequest(
                name="gpt-4o",
                is_visible=True,
                supports_image_input=True,
            ).model_dump()
        ],
        "is_public": True,
        "groups": [],
    }
    vision_creation = requests.put(
        creation_url,
        headers=admin_user.headers,
        json=vision_payload,
    )
    assert vision_creation.status_code == 200
    vision_provider = vision_creation.json()
    _set_default_vision_provider(admin_user, vision_provider["id"], "gpt-4o")

    # Precondition: the vision default points at the vision provider.
    before = _get_providers_admin(admin_user)
    assert before is not None
    _, _, vision_default_before = _unpack_data(before)
    assert vision_default_before is not None
    assert vision_default_before["provider_id"] == vision_provider["id"]

    # Deleting the vision provider is allowed (only text default is protected).
    deletion = requests.delete(
        f"{API_SERVER_URL}/admin/llm/provider/{vision_provider['id']}",
        headers=admin_user.headers,
    )
    assert deletion.status_code == 200

    # The vision provider can no longer be looked up.
    lookup = _get_provider_by_id(admin_user, vision_provider["id"])
    assert lookup is None

    # No vision default remains, and the text default is still intact.
    after = _get_providers_admin(admin_user)
    assert after is not None
    _, text_default_after, vision_default_after = _unpack_data(after)
    assert vision_default_after is None
    assert text_default_after is not None
    assert text_default_after["provider_id"] == text_provider["id"]
def test_duplicate_provider_name_rejected(reset: None) -> None: # noqa: ARG001
"""Creating a provider with a name that already exists should return 400."""
admin_user = UserManager.create(name="admin_user")

View File

@@ -117,10 +117,7 @@ class TestOktaProvider:
user = _make_mock_user(personal_name=None)
result = provider.build_user_resource(user, None)
# Falls back to deriving name from email local part
assert result.name == ScimName(
givenName="test", familyName="", formatted="test"
)
assert result.name == ScimName(givenName="", familyName="", formatted="")
assert result.displayName is None
def test_build_user_resource_scim_username_preserves_case(self) -> None:

View File

@@ -215,7 +215,7 @@ class TestCreateUser:
mock_dal.commit.assert_called_once()
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_missing_external_id_still_creates_mapping(
def test_missing_external_id_creates_user_without_mapping(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
@@ -223,7 +223,6 @@ class TestCreateUser:
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
"""Mapping is always created to mark user as SCIM-managed."""
mock_dal.get_user_by_email.return_value = None
resource = make_scim_user(externalId=None)
@@ -237,11 +236,11 @@ class TestCreateUser:
parsed = parse_scim_user(result, status=201)
assert parsed.userName is not None
mock_dal.add_user.assert_called_once()
mock_dal.create_user_mapping.assert_called_once()
mock_dal.create_user_mapping.assert_not_called()
mock_dal.commit.assert_called_once()
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_duplicate_scim_managed_email_returns_409(
def test_duplicate_email_returns_409(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
@@ -249,12 +248,7 @@ class TestCreateUser:
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
"""409 only when the existing user already has a SCIM mapping."""
existing = make_db_user()
mock_dal.get_user_by_email.return_value = existing
mock_dal.get_user_mapping_by_user_id.return_value = make_user_mapping(
user_id=existing.id
)
mock_dal.get_user_by_email.return_value = make_db_user()
resource = make_scim_user()
result = create_user(
@@ -266,40 +260,6 @@ class TestCreateUser:
assert_scim_error(result, 409)
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_existing_user_without_mapping_gets_linked(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
provider: ScimProvider,
) -> None:
"""Pre-existing user without SCIM mapping gets adopted (linked)."""
existing = make_db_user(email="admin@example.com", personal_name=None)
mock_dal.get_user_by_email.return_value = existing
mock_dal.get_user_mapping_by_user_id.return_value = None
resource = make_scim_user(userName="admin@example.com", externalId="ext-admin")
result = create_user(
user_resource=resource,
_token=mock_token,
provider=provider,
db_session=mock_db_session,
)
parsed = parse_scim_user(result, status=201)
assert parsed.userName == "admin@example.com"
# Should NOT create a new user — reuse existing
mock_dal.add_user.assert_not_called()
# Should sync is_active and personal_name from the SCIM request
mock_dal.update_user.assert_called_once_with(
existing, is_active=True, personal_name="Test User"
)
# Should create a SCIM mapping for the existing user
mock_dal.create_user_mapping.assert_called_once()
mock_dal.commit.assert_called_once()
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_integrity_error_returns_409(
self,

View File

@@ -126,9 +126,7 @@ Resources:
- Effect: Allow
Action:
- secretsmanager:GetSecretValue
Resource:
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret-*
Resource: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
Outputs:
OutputEcsCluster:

View File

@@ -167,12 +167,10 @@ Resources:
- ImportedNamespace: !ImportValue
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
- Name: AUTH_TYPE
Value: basic
Value: disabled
Secrets:
- Name: POSTGRES_PASSWORD
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
- Name: USER_AUTH_SECRET
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
VolumesFrom: []
SystemControls: []

View File

@@ -166,11 +166,9 @@ Resources:
- ImportedNamespace: !ImportValue
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
- Name: AUTH_TYPE
Value: basic
Value: disabled
Secrets:
- Name: POSTGRES_PASSWORD
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
- Name: USER_AUTH_SECRET
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
VolumesFrom: []
SystemControls: []

View File

@@ -78,16 +78,6 @@ const nextConfig = {
},
async rewrites() {
return [
{
source: "/ph_ingest/static/:path*",
destination: "https://us-assets.i.posthog.com/static/:path*",
},
{
source: "/ph_ingest/:path*",
destination: `${
process.env.NEXT_PUBLIC_POSTHOG_HOST || "https://us.i.posthog.com"
}/:path*`,
},
{
source: "/api/docs/:path*", // catch /api/docs and /api/docs/...
destination: `${

View File

@@ -1,10 +1,4 @@
:root {
--app-page-main-content-width: 52.5rem;
--block-width-form-input-min: 10rem;
--container-sm: 42rem;
--container-sm-md: 47rem;
--container-md: 54.5rem;
--container-lg: 62rem;
--container-full: 100%;
}

View File

@@ -3,7 +3,9 @@ import posthog from "posthog-js";
import { PostHogProvider } from "posthog-js/react";
import { useEffect } from "react";
const isPostHogEnabled = !!process.env.NEXT_PUBLIC_POSTHOG_KEY;
const isPostHogEnabled = !!(
process.env.NEXT_PUBLIC_POSTHOG_KEY && process.env.NEXT_PUBLIC_POSTHOG_HOST
);
type PHProviderProps = { children: React.ReactNode };
@@ -11,9 +13,7 @@ export function PHProvider({ children }: PHProviderProps) {
useEffect(() => {
if (isPostHogEnabled) {
posthog.init(process.env.NEXT_PUBLIC_POSTHOG_KEY!, {
api_host: "/ph_ingest",
ui_host:
process.env.NEXT_PUBLIC_POSTHOG_HOST || "https://us.posthog.com",
api_host: process.env.NEXT_PUBLIC_POSTHOG_HOST!,
person_profiles: "identified_only",
capture_pageview: false,
session_recording: {

View File

@@ -43,11 +43,8 @@ import { Content } from "@opal/layouts";
import Spacer from "@/refresh-components/Spacer";
const widthClasses = {
sm: "w-[min(var(--container-sm),100%)]",
"sm-md": "w-[min(var(--container-sm-md),100%)]",
md: "w-[min(var(--container-md),100%)]",
lg: "w-[min(var(--container-lg),100%)]",
full: "w-[var(--container-full)]",
md: "w-[min(50rem,100%)]",
lg: "w-[min(60rem,100%)]",
};
/**
@@ -60,19 +57,18 @@ const widthClasses = {
* - Full height container with centered content
* - Automatic overflow-y scrolling
* - Contains the scroll container ID that Settings.Header uses for shadow detection
* - Configurable width via CSS variables defined in sizes.css:
* "sm" (672px), "sm-md" (752px), "md" (872px, default), "lg" (992px), "full" (100%)
* - Configurable width: "md" (50rem max) or "full" (full width with 4rem padding)
*
* @example
* ```tsx
* // Default medium width (872px max)
* // Default medium width (50rem max)
* <SettingsLayouts.Root>
* <SettingsLayouts.Header {...} />
* <SettingsLayouts.Body>...</SettingsLayouts.Body>
* </SettingsLayouts.Root>
*
* // Large width (992px max)
* <SettingsLayouts.Root width="lg">
* // Full width with padding
* <SettingsLayouts.Root width="full">
* <SettingsLayouts.Header {...} />
* <SettingsLayouts.Body>...</SettingsLayouts.Body>
* </SettingsLayouts.Root>

View File

@@ -1,8 +1,7 @@
import { LLMProviderResponse, VisionProvider } from "@/interfaces/llm";
import { LLM_ADMIN_URL } from "@/lib/llmConfig/constants";
import { VisionProvider } from "@/interfaces/llm";
export async function fetchVisionProviders(): Promise<VisionProvider[]> {
const response = await fetch(`${LLM_ADMIN_URL}/vision-providers`, {
const response = await fetch("/api/admin/llm/vision-providers", {
headers: {
"Content-Type": "application/json",
},
@@ -12,24 +11,24 @@ export async function fetchVisionProviders(): Promise<VisionProvider[]> {
`Failed to fetch vision providers: ${await response.text()}`
);
}
const data = (await response.json()) as LLMProviderResponse<VisionProvider>;
return data.providers;
return response.json();
}
export async function setDefaultVisionProvider(
providerId: number,
visionModel: string
): Promise<void> {
const response = await fetch(`${LLM_ADMIN_URL}/default-vision`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider_id: providerId,
model_name: visionModel,
}),
});
const response = await fetch(
`/api/admin/llm/provider/${providerId}/default-vision?vision_model=${encodeURIComponent(
visionModel
)}`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
}
);
if (!response.ok) {
const errorMsg = await response.text();

View File

@@ -6,10 +6,7 @@ import {
ModelConfiguration,
WellKnownLLMProviderDescriptor,
} from "@/interfaces/llm";
import {
LLM_ADMIN_URL,
LLM_PROVIDERS_ADMIN_URL,
} from "@/lib/llmConfig/constants";
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
import { OnboardingActions, OnboardingState } from "../types";
import { APIFormFieldState } from "@/refresh-components/form/types";
import {
@@ -228,33 +225,13 @@ export function OnboardingFormWrapper<T extends Record<string, any>>({
try {
const newLlmProvider = await response.json();
if (newLlmProvider?.id != null) {
const defaultModelName =
(payload as Record<string, any>).default_model_name ??
(payload as Record<string, any>).model_configurations?.[0]?.name ??
"";
if (!defaultModelName) {
console.error(
"No model name available to set as default — skipping set-default call"
);
} else {
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_id: newLlmProvider.id,
model_name: defaultModelName,
}),
});
if (!setDefaultResponse.ok) {
const err = await setDefaultResponse.json().catch(() => ({}));
setErrorMessage(
err?.detail ?? "Failed to set provider as default"
);
setApiStatus("error");
setIsSubmitting(false);
return;
}
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{ method: "POST" }
);
if (!setDefaultResponse.ok) {
const err = await setDefaultResponse.json().catch(() => ({}));
console.error("Failed to set provider as default", err?.detail);
}
}
} catch (_e) {

View File

@@ -412,7 +412,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
}),
} as Response);
// Mock POST /api/admin/llm/default
// Mock POST /api/admin/llm/provider/5/default
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({}),
@@ -431,20 +431,14 @@ describe("Custom LLM Provider Configuration Workflow", () => {
const submitButton = screen.getByRole("button", { name: /enable/i });
await user.click(submitButton);
// Verify set as default API was called with correct endpoint and body
// Verify set as default API was called
await waitFor(() => {
const defaultCall = fetchSpy.mock.calls.find(
([url]) => url === "/api/admin/llm/default"
expect(fetchSpy).toHaveBeenCalledWith(
"/api/admin/llm/provider/5/default",
expect.objectContaining({
method: "POST",
})
);
expect(defaultCall).toBeDefined();
const [, options] = defaultCall!;
expect(options.method).toBe("POST");
expect(options.headers).toEqual({ "Content-Type": "application/json" });
const body = JSON.parse(options.body);
expect(body.provider_id).toBe(5);
expect(body).toHaveProperty("model_name");
});
});

View File

@@ -20,7 +20,7 @@ async function deleteAllProviders(client: OnyxApiClient): Promise<void> {
const providers = await client.listLlmProviders();
for (const provider of providers) {
try {
await client.deleteProvider(provider.id);
await client.deleteProvider(provider.id, { force: true });
} catch (error) {
console.warn(
`Failed to delete provider ${provider.id}: ${String(error)}`

View File

@@ -526,8 +526,14 @@ export class OnyxApiClient {
*
* @param providerId - The provider ID to delete
*/
async deleteProvider(providerId: number): Promise<void> {
const response = await this.delete(`/admin/llm/provider/${providerId}`);
async deleteProvider(
providerId: number,
{ force = false }: { force?: boolean } = {}
): Promise<void> {
const query = force ? "?force=true" : "";
const response = await this.delete(
`/admin/llm/provider/${providerId}${query}`
);
await this.handleResponseSoft(
response,