Compare commits

..

23 Commits

Author SHA1 Message Date
Evan Lohn
de7fc36fc5 test: no vector db user file processing (#8854) 2026-02-28 04:19:59 +00:00
Evan Lohn
7f9e37450d fix: non vector db tasks (#8849) 2026-02-28 03:51:57 +00:00
Evan Lohn
c7ef85b733 chore: narrow no_vector_db supported scope (#8847) 2026-02-28 02:54:15 +00:00
Danelegend
bd9319e592 feat: LLM Provider Rework (#8761)
Co-authored-by: Raunak Bhagat <r@rabh.io>
2026-02-28 01:29:49 +00:00
Nikolas Garza
db5955d6f2 fix(ee): show Access Restricted page when seat limit exceeded (#8877) 2026-02-28 01:26:00 +00:00
Raunak Bhagat
5e447440ea refactor(Suggestions): migrate to opal Interactive + Content (#8881) 2026-02-27 23:39:20 +00:00
Justin Tahara
78c6ca39b8 fix(minio): No cURL in minio container (#8876) 2026-02-27 22:37:42 +00:00
Raunak Bhagat
71a7cf09b3 refactor(opal): migrate LineItemLayout to Content/ContentAction (#8824) 2026-02-27 22:27:09 +00:00
dependabot[bot]
91d30a0156 chore(deps): bump actions/download-artifact from 4.2.1 to 7.0.0 (#8474)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-27 22:11:03 +00:00
dependabot[bot]
7b30752767 chore(deps): bump rollup from 4.52.5 to 4.59.0 in /web (#8782)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-27 21:57:10 +00:00
Justin Tahara
4450ecf07c fix(gong): Respecting Retry Timeout Header (#8866) 2026-02-27 21:45:31 +00:00
Danelegend
0e6b766996 feat: Add python tool as default for default persona (#8857) 2026-02-27 21:32:55 +00:00
dependabot[bot]
12c8cd338b chore(deps): bump werkzeug from 3.1.5 to 3.1.6 (#8615)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-27 21:08:33 +00:00
dependabot[bot]
ad5688bf65 chore(deps-dev): bump rollup from 4.55.1 to 4.59.0 in /widget (#8863)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-27 21:02:20 +00:00
Jamison Lahman
d2deefd1f1 chore(whitelabeling): always show sidebar icon without logo icon (#8860)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-27 20:36:11 +00:00
Jamison Lahman
18b90d405d chore(deps): upgrade fastapi: 0.128.0->0.133.1 (#8862) 2026-02-27 20:26:27 +00:00
Raunak Bhagat
8394e8837b feat(opal): extract widthVariant to shared and add to Content (#8859) 2026-02-27 19:50:32 +00:00
Jamison Lahman
f06df891c4 chore(fe): InputSelect has a min-width (#8858) 2026-02-27 19:20:37 +00:00
Wenxi
d6d5e72c18 feat(ods): whois utility to find tenant_ids and admin emails (#8855) 2026-02-27 18:21:29 +00:00
Danelegend
449f5d62f9 fix: Code output extending over thinking bounds (#8837) 2026-02-27 08:26:54 +00:00
Yuhong Sun
4d256c5666 chore: remove instance of Assistant from frontend (#8848)
Co-authored-by: Nik <nikolas.garza5@gmail.com>
2026-02-27 04:22:28 +00:00
Danelegend
2e53496f46 feat: Code interpreter admin page visuals (#8729) 2026-02-27 04:01:02 +00:00
acaprau
63a206706a docs(best practices): Add comment about import-time side effects and main.py files (#8820)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-27 01:29:56 +00:00
245 changed files with 4948 additions and 5304 deletions

View File

@@ -114,8 +114,10 @@ jobs:
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
env:
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
run: |
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
exit 1
notify-slack-on-cherry-pick-failure:

View File

@@ -603,7 +603,7 @@ jobs:
pull-requests: write
steps:
- name: Download visual diff summaries
uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
with:
pattern: screenshot-diff-summary-*
path: summaries/

View File

@@ -1,37 +0,0 @@
"""add cache_store table
Revision ID: 2664261bfaab
Revises: c0c937d5c9e5
Create Date: 2026-02-27 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2664261bfaab"
down_revision = "c0c937d5c9e5"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"cache_store",
sa.Column("key", sa.String(), nullable=False),
sa.Column("value", sa.LargeBinary(), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("key"),
)
op.create_index(
"ix_cache_store_expires",
"cache_store",
["expires_at"],
postgresql_where=sa.text("expires_at IS NOT NULL"),
)
def downgrade() -> None:
op.drop_index("ix_cache_store_expires", table_name="cache_store")
op.drop_table("cache_store")

View File

@@ -0,0 +1,69 @@
"""add python tool on default
Revision ID: 57122d037335
Revises: c0c937d5c9e5
Create Date: 2026-02-27 10:10:40.124925
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "57122d037335"
down_revision = "c0c937d5c9e5"
branch_labels = None
depends_on = None
PYTHON_TOOL_NAME = "python"
def upgrade() -> None:
conn = op.get_bind()
# Look up the PythonTool id
result = conn.execute(
sa.text("SELECT id FROM tool WHERE name = :name"),
{"name": PYTHON_TOOL_NAME},
).fetchone()
if not result:
return
tool_id = result[0]
# Attach to the default persona (id=0) if not already attached
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": tool_id},
)
def downgrade() -> None:
conn = op.get_bind()
result = conn.execute(
sa.text("SELECT id FROM tool WHERE name = :name"),
{"name": PYTHON_TOOL_NAME},
).fetchone()
if not result:
return
conn.execute(
sa.text(
"""
DELETE FROM persona__tool
WHERE persona_id = 0 AND tool_id = :tool_id
"""
),
{"tool_id": result[0]},
)

View File

@@ -20,6 +20,7 @@ from ee.onyx.server.enterprise_settings.store import (
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
@@ -117,15 +118,38 @@ def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None
def _seed_llms(
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
if llm_upsert_requests:
logger.notice("Seeding LLMs")
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
update_default_provider(
provider_id=seeded_providers[0].id, db_session=db_session
)
if not llm_upsert_requests:
return
logger.notice("Seeding LLMs")
for request in llm_upsert_requests:
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
if existing:
request.id = existing.id
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
default_provider = next(
(p for p in seeded_providers if p.model_configurations), None
)
if not default_provider:
return
visible_configs = [
mc for mc in default_provider.model_configurations if mc.is_visible
]
default_config = (
visible_configs[0]
if visible_configs
else default_provider.model_configurations[0]
)
update_default_provider(
provider_id=default_provider.id,
model_name=default_config.name,
db_session=db_session,
)
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:

View File

@@ -109,6 +109,12 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
if metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
settings.ee_features_enabled = False
elif metadata.used_seats > metadata.seats:
# License is valid but seat limit exceeded
settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED
settings.seat_count = metadata.seats
settings.used_seats = metadata.used_seats
settings.ee_features_enabled = True
else:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
settings.ee_features_enabled = True

View File

@@ -33,6 +33,7 @@ from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
@@ -302,12 +303,17 @@ def configure_default_api_keys(db_session: Session) -> None:
has_set_default_provider = False
def _upsert(request: LLMProviderUpsertRequest) -> None:
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
nonlocal has_set_default_provider
try:
existing = fetch_existing_llm_provider(
name=request.name, db_session=db_session
)
if existing:
request.id = existing.id
provider = upsert_llm_provider(request, db_session)
if not has_set_default_provider:
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, default_model, db_session)
has_set_default_provider = True
except Exception as e:
logger.error(f"Failed to configure {request.provider} provider: {e}")
@@ -325,14 +331,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
OPENAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openai_provider)
_upsert(openai_provider, default_model_name)
# Create default image generation config using the OpenAI API key
try:
@@ -361,14 +366,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
ANTHROPIC_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(anthropic_provider)
_upsert(anthropic_provider, default_model_name)
else:
logger.info(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
@@ -393,14 +397,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Google Vertex AI",
provider=VERTEXAI_PROVIDER_NAME,
custom_config=custom_config,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
VERTEXAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(vertexai_provider)
_upsert(vertexai_provider, default_model_name)
else:
logger.info(
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
@@ -432,12 +435,11 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenRouter",
provider=OPENROUTER_PROVIDER_NAME,
api_key=OPENROUTER_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=model_configurations,
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openrouter_provider)
_upsert(openrouter_provider, default_model_name)
else:
logger.info(
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"

View File

@@ -434,7 +434,7 @@ def _process_user_file_impl(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
if file_lock is not None and not file_lock.acquire(blocking=False):
task_logger.info(
f"_process_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
)
@@ -602,7 +602,7 @@ def _delete_user_file_impl(
_user_file_delete_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
if file_lock is not None and not file_lock.acquire(blocking=False):
task_logger.info(
f"_delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
)
@@ -778,7 +778,7 @@ def _project_sync_user_file_impl(
user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
if file_lock is not None and not file_lock.acquire(blocking=False):
task_logger.info(
f"_project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
)

View File

@@ -1,276 +0,0 @@
"""Periodic poller for NO_VECTOR_DB deployments.
Replaces Celery Beat and background workers with a lightweight daemon thread
that runs from the API server process. Two responsibilities:
1. Recovery polling (every 30 s): re-processes user files stuck in
PROCESSING / DELETING / needs_sync states via the drain loops defined
in ``task_utils.py``.
2. Periodic task execution (configurable intervals): runs LLM model updates
and scheduled evals at their configured cadences, with Postgres advisory
lock deduplication across multiple API server instances.
"""
import threading
import time
from collections.abc import Callable
from dataclasses import dataclass
from dataclasses import field
from sqlalchemy import text
from onyx.utils.logger import setup_logger
logger = setup_logger()
RECOVERY_INTERVAL_SECONDS = 30
PERIODIC_TASK_LOCK_BASE = 20_000
# ------------------------------------------------------------------
# Periodic task definitions
# ------------------------------------------------------------------
@dataclass
class _PeriodicTaskDef:
name: str
interval_seconds: float
lock_id: int
run_fn: Callable[[], None]
last_run_at: float = field(default=0.0)
def _run_auto_llm_update() -> None:
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
if not AUTO_LLM_CONFIG_URL:
return
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.well_known_providers.auto_update_service import (
sync_llm_models_from_github,
)
with get_session_with_current_tenant() as db_session:
sync_llm_models_from_github(db_session)
def _run_cache_cleanup() -> None:
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
cleanup_expired_cache_entries()
def _run_scheduled_eval() -> None:
from onyx.configs.app_configs import BRAINTRUST_API_KEY
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
if not all(
[
BRAINTRUST_API_KEY,
SCHEDULED_EVAL_PROJECT,
SCHEDULED_EVAL_DATASET_NAMES,
SCHEDULED_EVAL_PERMISSIONS_EMAIL,
]
):
return
from datetime import datetime
from datetime import timezone
from onyx.evals.eval import run_eval
from onyx.evals.models import EvalConfigurationOptions
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
for dataset_name in SCHEDULED_EVAL_DATASET_NAMES:
try:
run_eval(
configuration=EvalConfigurationOptions(
search_permissions_email=SCHEDULED_EVAL_PERMISSIONS_EMAIL,
dataset_name=dataset_name,
no_send_logs=False,
braintrust_project=SCHEDULED_EVAL_PROJECT,
experiment_name=f"{dataset_name} - {run_timestamp}",
),
remote_dataset_name=dataset_name,
)
except Exception:
logger.exception(
f"Periodic poller - Failed scheduled eval for dataset {dataset_name}"
)
_CACHE_CLEANUP_INTERVAL_SECONDS = 300
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
from onyx.cache.interface import CacheBackendType
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import CACHE_BACKEND
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
tasks: list[_PeriodicTaskDef] = []
if CACHE_BACKEND == CacheBackendType.POSTGRES:
tasks.append(
_PeriodicTaskDef(
name="cache-cleanup",
interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS,
lock_id=PERIODIC_TASK_LOCK_BASE + 2,
run_fn=_run_cache_cleanup,
)
)
if AUTO_LLM_CONFIG_URL:
tasks.append(
_PeriodicTaskDef(
name="auto-llm-update",
interval_seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS,
lock_id=PERIODIC_TASK_LOCK_BASE,
run_fn=_run_auto_llm_update,
)
)
if SCHEDULED_EVAL_DATASET_NAMES:
tasks.append(
_PeriodicTaskDef(
name="scheduled-eval",
interval_seconds=7 * 24 * 3600,
lock_id=PERIODIC_TASK_LOCK_BASE + 1,
run_fn=_run_scheduled_eval,
)
)
return tasks
# ------------------------------------------------------------------
# Periodic task runner with advisory lock dedup
# ------------------------------------------------------------------
def _try_run_periodic_task(task_def: _PeriodicTaskDef) -> None:
"""Run *task_def* if its interval has elapsed and no peer holds the lock."""
now = time.monotonic()
if now - task_def.last_run_at < task_def.interval_seconds:
return
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as db_session:
acquired = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
{"id": task_def.lock_id},
).scalar()
if not acquired:
return
try:
task_def.run_fn()
task_def.last_run_at = now
except Exception:
logger.exception(
f"Periodic poller - Error running periodic task {task_def.name}"
)
finally:
db_session.execute(
text("SELECT pg_advisory_unlock(:id)"),
{"id": task_def.lock_id},
)
# ------------------------------------------------------------------
# Recovery / drain loop runner
# ------------------------------------------------------------------
def _run_drain_loops(tenant_id: str) -> None:
from onyx.background.task_utils import drain_delete_loop
from onyx.background.task_utils import drain_processing_loop
from onyx.background.task_utils import drain_project_sync_loop
drain_processing_loop(tenant_id)
drain_delete_loop(tenant_id)
drain_project_sync_loop(tenant_id)
# ------------------------------------------------------------------
# Startup recovery (10g)
# ------------------------------------------------------------------
def recover_stuck_user_files(tenant_id: str) -> None:
"""Run all drain loops once to re-process files left in intermediate states.
Called from ``lifespan()`` on startup when ``DISABLE_VECTOR_DB`` is set.
"""
logger.info("recover_stuck_user_files - Checking for stuck user files")
try:
_run_drain_loops(tenant_id)
except Exception:
logger.exception("recover_stuck_user_files - Error during recovery")
# ------------------------------------------------------------------
# Daemon thread (10f)
# ------------------------------------------------------------------
_shutdown_event = threading.Event()
_poller_thread: threading.Thread | None = None
def _poller_loop(tenant_id: str) -> None:
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
periodic_tasks = _build_periodic_tasks()
logger.info(
f"Periodic poller started with {len(periodic_tasks)} periodic task(s): "
f"{[t.name for t in periodic_tasks]}"
)
while not _shutdown_event.is_set():
try:
_run_drain_loops(tenant_id)
except Exception:
logger.exception("Periodic poller - Error in recovery polling")
for task_def in periodic_tasks:
try:
_try_run_periodic_task(task_def)
except Exception:
logger.exception(
f"Periodic poller - Unhandled error checking task {task_def.name}"
)
_shutdown_event.wait(RECOVERY_INTERVAL_SECONDS)
def start_periodic_poller(tenant_id: str) -> None:
"""Start the periodic poller daemon thread."""
global _poller_thread # noqa: PLW0603
_shutdown_event.clear()
_poller_thread = threading.Thread(
target=_poller_loop,
args=(tenant_id,),
daemon=True,
name="no-vectordb-periodic-poller",
)
_poller_thread.start()
logger.info("Periodic poller thread started")
def stop_periodic_poller() -> None:
"""Signal the periodic poller to stop and wait for it to exit."""
global _poller_thread # noqa: PLW0603
if _poller_thread is None:
return
_shutdown_event.set()
_poller_thread.join(timeout=10)
if _poller_thread.is_alive():
logger.warning("Periodic poller thread did not stop within timeout")
_poller_thread = None
logger.info("Periodic poller thread stopped")

View File

@@ -1,35 +1,3 @@
"""Background task utilities.
Contains query-history report helpers (used by all deployment modes) and
in-process background task execution helpers for NO_VECTOR_DB mode:
- Postgres advisory lock-based concurrency semaphore (10d)
- Drain loops that process all pending user file work (10e)
- Entry points wired to FastAPI BackgroundTasks (10c)
Advisory locks are session-level: they persist until explicitly released via
``pg_advisory_unlock`` or until the DB connection closes. The semaphore
session is kept open for the entire drain loop so the slot stays held, and
released in a ``finally`` block before the connection returns to the pool.
"""
from uuid import UUID
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.utils.logger import setup_logger
logger = setup_logger()
# ------------------------------------------------------------------
# Query-history report helpers (pre-existing, used by all modes)
# ------------------------------------------------------------------
QUERY_REPORT_NAME_PREFIX = "query-history"
@@ -41,173 +9,3 @@ def construct_query_history_report_name(
def extract_task_id_from_query_history_report_name(name: str) -> str:
return name.removeprefix(f"{QUERY_REPORT_NAME_PREFIX}-").removesuffix(".csv")
# ------------------------------------------------------------------
# Postgres advisory lock semaphore (NO_VECTOR_DB mode)
# ------------------------------------------------------------------
BACKGROUND_TASK_SLOT_BASE = 10_000
BACKGROUND_TASK_MAX_CONCURRENCY = 4
def try_acquire_semaphore_slot(db_session: Session) -> int | None:
"""Try to acquire one of N advisory lock slots.
Returns the slot number (0-based) if acquired, ``None`` if all slots are
taken. ``pg_try_advisory_lock`` is non-blocking — returns ``false``
immediately when the lock is held by another session.
"""
for slot in range(BACKGROUND_TASK_MAX_CONCURRENCY):
lock_id = BACKGROUND_TASK_SLOT_BASE + slot
acquired = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
{"id": lock_id},
).scalar()
if acquired:
return slot
return None
def release_semaphore_slot(db_session: Session, slot: int) -> None:
"""Release a previously acquired advisory lock slot."""
lock_id = BACKGROUND_TASK_SLOT_BASE + slot
db_session.execute(
text("SELECT pg_advisory_unlock(:id)"),
{"id": lock_id},
)
# ------------------------------------------------------------------
# Work-claiming helpers (FOR UPDATE SKIP LOCKED)
# ------------------------------------------------------------------
def _claim_next_processing_file(db_session: Session) -> UUID | None:
"""Claim the next file in PROCESSING status."""
return db_session.execute(
select(UserFile.id)
.where(UserFile.status == UserFileStatus.PROCESSING)
.order_by(UserFile.created_at)
.limit(1)
.with_for_update(skip_locked=True)
).scalar_one_or_none()
def _claim_next_deleting_file(db_session: Session) -> UUID | None:
"""Claim the next file in DELETING status."""
return db_session.execute(
select(UserFile.id)
.where(UserFile.status == UserFileStatus.DELETING)
.order_by(UserFile.created_at)
.limit(1)
.with_for_update(skip_locked=True)
).scalar_one_or_none()
def _claim_next_sync_file(db_session: Session) -> UUID | None:
"""Claim the next file needing project/persona sync."""
return db_session.execute(
select(UserFile.id)
.where(
sa.and_(
sa.or_(
UserFile.needs_project_sync.is_(True),
UserFile.needs_persona_sync.is_(True),
),
UserFile.status == UserFileStatus.COMPLETED,
)
)
.order_by(UserFile.created_at)
.limit(1)
.with_for_update(skip_locked=True)
).scalar_one_or_none()
# ------------------------------------------------------------------
# Drain loops — acquire a semaphore slot then process *all* pending work
# ------------------------------------------------------------------
def drain_processing_loop(tenant_id: str) -> None:
"""Process all pending PROCESSING user files."""
from onyx.background.celery.tasks.user_file_processing.tasks import (
_process_user_file_impl,
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as sem_session:
slot = try_acquire_semaphore_slot(sem_session)
if slot is None:
logger.info("drain_processing_loop - All semaphore slots taken, skipping")
return
try:
while True:
with get_session_with_current_tenant() as claim_session:
file_id = _claim_next_processing_file(claim_session)
if file_id is None:
break
_process_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
finally:
release_semaphore_slot(sem_session, slot)
def drain_delete_loop(tenant_id: str) -> None:
"""Delete all pending DELETING user files."""
from onyx.background.celery.tasks.user_file_processing.tasks import (
_delete_user_file_impl,
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as sem_session:
slot = try_acquire_semaphore_slot(sem_session)
if slot is None:
logger.info("drain_delete_loop - All semaphore slots taken, skipping")
return
try:
while True:
with get_session_with_current_tenant() as claim_session:
file_id = _claim_next_deleting_file(claim_session)
if file_id is None:
break
_delete_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
finally:
release_semaphore_slot(sem_session, slot)
def drain_project_sync_loop(tenant_id: str) -> None:
"""Sync all pending project/persona metadata for user files."""
from onyx.background.celery.tasks.user_file_processing.tasks import (
_project_sync_user_file_impl,
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as sem_session:
slot = try_acquire_semaphore_slot(sem_session)
if slot is None:
logger.info("drain_project_sync_loop - All semaphore slots taken, skipping")
return
try:
while True:
with get_session_with_current_tenant() as claim_session:
file_id = _claim_next_sync_file(claim_session)
if file_id is None:
break
_project_sync_user_file_impl(
user_file_id=str(file_id),
tenant_id=tenant_id,
redis_locking=False,
)
finally:
release_semaphore_slot(sem_session, slot)

View File

@@ -1,51 +0,0 @@
from collections.abc import Callable
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheBackendType
from onyx.configs.app_configs import CACHE_BACKEND
def _build_redis_backend(tenant_id: str) -> CacheBackend:
from onyx.cache.redis_backend import RedisCacheBackend
from onyx.redis.redis_pool import redis_pool
return RedisCacheBackend(redis_pool.get_client(tenant_id))
def _build_postgres_backend(tenant_id: str) -> CacheBackend:
from onyx.cache.postgres_backend import PostgresCacheBackend
return PostgresCacheBackend(tenant_id)
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
CacheBackendType.REDIS: _build_redis_backend,
CacheBackendType.POSTGRES: _build_postgres_backend,
}
def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend:
"""Return a tenant-aware ``CacheBackend``.
If *tenant_id* is ``None``, the current tenant is read from the
thread-local context variable (same behaviour as ``get_redis_client``).
"""
if tenant_id is None:
from shared_configs.contextvars import get_current_tenant_id
tenant_id = get_current_tenant_id()
builder = _BACKEND_BUILDERS.get(CACHE_BACKEND)
if builder is None:
raise ValueError(
f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. "
f"Supported values: {[t.value for t in CacheBackendType]}"
)
return builder(tenant_id)
def get_shared_cache_backend() -> CacheBackend:
"""Return a ``CacheBackend`` in the shared (cross-tenant) namespace."""
from shared_configs.configs import DEFAULT_REDIS_PREFIX
return get_cache_backend(tenant_id=DEFAULT_REDIS_PREFIX)

View File

@@ -1,96 +0,0 @@
import abc
from enum import Enum
class CacheBackendType(str, Enum):
REDIS = "redis"
POSTGRES = "postgres"
class CacheLock(abc.ABC):
"""Abstract distributed lock returned by CacheBackend.lock()."""
@abc.abstractmethod
def acquire(
self,
blocking: bool = True,
blocking_timeout: float | None = None,
) -> bool:
raise NotImplementedError
@abc.abstractmethod
def release(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def owned(self) -> bool:
raise NotImplementedError
def __enter__(self) -> "CacheLock":
self.acquire()
return self
def __exit__(self, *args: object) -> None:
self.release()
class CacheBackend(abc.ABC):
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
Covers the subset of Redis operations used outside of Celery. When
CACHE_BACKEND=postgres, a PostgreSQL-backed implementation is used instead.
"""
# -- basic key/value ---------------------------------------------------
@abc.abstractmethod
def get(self, key: str) -> bytes | None:
raise NotImplementedError
@abc.abstractmethod
def set(
self,
key: str,
value: str | bytes | int | float,
ex: int | None = None,
) -> None:
raise NotImplementedError
@abc.abstractmethod
def delete(self, key: str) -> None:
raise NotImplementedError
@abc.abstractmethod
def exists(self, key: str) -> bool:
raise NotImplementedError
# -- TTL ---------------------------------------------------------------
@abc.abstractmethod
def expire(self, key: str, seconds: int) -> None:
raise NotImplementedError
@abc.abstractmethod
def ttl(self, key: str) -> int:
"""Return remaining TTL in seconds. -1 if no expiry, -2 if key missing."""
raise NotImplementedError
# -- distributed lock --------------------------------------------------
@abc.abstractmethod
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
raise NotImplementedError
# -- blocking list (used by MCP OAuth BLPOP pattern) -------------------
@abc.abstractmethod
def rpush(self, key: str, value: str | bytes) -> None:
raise NotImplementedError
@abc.abstractmethod
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
"""Block until a value is available on one of *keys*, or *timeout* expires.
Returns ``(key, value)`` or ``None`` on timeout.
"""
raise NotImplementedError

View File

@@ -1,293 +0,0 @@
"""PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments.
Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks
for distributed locking, and a polling loop for the BLPOP pattern.
"""
import hashlib
import struct
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from sqlalchemy import Connection
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheLock
from onyx.db.models import CacheStore
_LIST_KEY_PREFIX = "_q:"
_LIST_ITEM_TTL_SECONDS = 3600
_BLPOP_POLL_INTERVAL = 0.25
def _list_item_key(key: str) -> str:
return f"{_LIST_KEY_PREFIX}{key}:{time.time_ns()}"
def _to_bytes(value: str | bytes | int | float) -> bytes:
if isinstance(value, bytes):
return value
return str(value).encode()
# ------------------------------------------------------------------
# Lock
# ------------------------------------------------------------------
class PostgresCacheLock(CacheLock):
"""Advisory-lock-based distributed lock.
The lock is tied to a dedicated database connection. Releasing
the lock (or closing the connection) frees it.
NOTE: Unlike Redis locks, advisory locks do not auto-expire after
``timeout`` seconds. They are released when ``release()`` is
called or when the underlying connection is closed.
"""
def __init__(self, lock_id: int, timeout: float | None) -> None:
self._lock_id = lock_id
self._timeout = timeout
self._conn: Connection | None = None
self._acquired = False
def acquire(
self,
blocking: bool = True,
blocking_timeout: float | None = None,
) -> bool:
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
self._conn = get_sqlalchemy_engine().connect()
if not blocking:
if self._try_lock():
return True
self._conn.close()
self._conn = None
return False
effective_timeout = blocking_timeout or self._timeout
deadline = (time.monotonic() + effective_timeout) if effective_timeout else None
while True:
if self._try_lock():
return True
if deadline is not None and time.monotonic() >= deadline:
self._conn.close()
self._conn = None
return False
time.sleep(0.1)
def release(self) -> None:
if not self._acquired or self._conn is None:
return
try:
self._conn.execute(
text("SELECT pg_advisory_unlock(:id)"), {"id": self._lock_id}
)
finally:
self._acquired = False
self._conn.close()
self._conn = None
def owned(self) -> bool:
return self._acquired
def _try_lock(self) -> bool:
assert self._conn is not None
result = self._conn.execute(
text("SELECT pg_try_advisory_lock(:id)"), {"id": self._lock_id}
).scalar()
if result:
self._acquired = True
return True
return False
# ------------------------------------------------------------------
# Backend
# ------------------------------------------------------------------
class PostgresCacheBackend(CacheBackend):
"""``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL.
Each operation opens and closes its own database session so the backend
is safe to share across threads. Tenant isolation is handled by
SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``).
"""
def __init__(self, tenant_id: str) -> None:
self._tenant_id = tenant_id
# -- basic key/value ---------------------------------------------------
def get(self, key: str) -> bytes | None:
from onyx.db.engine.sql_engine import get_session_with_tenant
stmt = select(CacheStore.value).where(
CacheStore.key == key,
or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()),
)
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
value = session.execute(stmt).scalar_one_or_none()
if value is None:
return None
return bytes(value)
def set(
self,
key: str,
value: str | bytes | int | float,
ex: int | None = None,
) -> None:
from onyx.db.engine.sql_engine import get_session_with_tenant
value_bytes = _to_bytes(value)
expires_at = datetime.now(timezone.utc) + timedelta(seconds=ex) if ex else None
stmt = (
pg_insert(CacheStore)
.values(key=key, value=value_bytes, expires_at=expires_at)
.on_conflict_do_update(
index_elements=[CacheStore.key],
set_={"value": value_bytes, "expires_at": expires_at},
)
)
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
session.execute(stmt)
session.commit()
def delete(self, key: str) -> None:
from onyx.db.engine.sql_engine import get_session_with_tenant
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
session.execute(delete(CacheStore).where(CacheStore.key == key))
session.commit()
def exists(self, key: str) -> bool:
from onyx.db.engine.sql_engine import get_session_with_tenant
stmt = (
select(CacheStore.key)
.where(
CacheStore.key == key,
or_(
CacheStore.expires_at.is_(None),
CacheStore.expires_at > func.now(),
),
)
.limit(1)
)
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
return session.execute(stmt).first() is not None
# -- TTL ---------------------------------------------------------------
def expire(self, key: str, seconds: int) -> None:
from onyx.db.engine.sql_engine import get_session_with_tenant
new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds)
stmt = (
update(CacheStore).where(CacheStore.key == key).values(expires_at=new_exp)
)
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
session.execute(stmt)
session.commit()
def ttl(self, key: str) -> int:
from onyx.db.engine.sql_engine import get_session_with_tenant
stmt = select(CacheStore.expires_at).where(CacheStore.key == key)
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
result = session.execute(stmt).first()
if result is None:
return -2
expires_at: datetime | None = result[0]
if expires_at is None:
return -1
remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
if remaining <= 0:
return -2
return int(remaining)
# -- distributed lock --------------------------------------------------
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
return PostgresCacheLock(self._lock_id_for(name), timeout)
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
def rpush(self, key: str, value: str | bytes) -> None:
self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS)
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
from onyx.db.engine.sql_engine import get_session_with_tenant
deadline = (time.monotonic() + timeout) if timeout > 0 else None
while True:
for key in keys:
lower = f"{_LIST_KEY_PREFIX}{key}:"
upper = f"{_LIST_KEY_PREFIX}{key};"
stmt = (
select(CacheStore)
.where(
CacheStore.key >= lower,
CacheStore.key < upper,
or_(
CacheStore.expires_at.is_(None),
CacheStore.expires_at > func.now(),
),
)
.order_by(CacheStore.key)
.limit(1)
.with_for_update(skip_locked=True)
)
with get_session_with_tenant(tenant_id=self._tenant_id) as session:
row = session.execute(stmt).scalars().first()
if row is not None:
value = bytes(row.value) if row.value else b""
session.delete(row)
session.commit()
return (key.encode(), value)
if deadline is not None and time.monotonic() >= deadline:
return None
time.sleep(_BLPOP_POLL_INTERVAL)
# -- helpers -----------------------------------------------------------
def _lock_id_for(self, name: str) -> int:
"""Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
return struct.unpack("q", h[:8])[0]
# ------------------------------------------------------------------
# Periodic cleanup
# ------------------------------------------------------------------
def cleanup_expired_cache_entries() -> None:
"""Delete rows whose ``expires_at`` is in the past.
Called by the periodic poller every 5 minutes.
"""
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as session:
session.execute(
delete(CacheStore).where(
CacheStore.expires_at.is_not(None),
CacheStore.expires_at < func.now(),
)
)
session.commit()

View File

@@ -1,92 +0,0 @@
from typing import cast
from redis.client import Redis
from redis.lock import Lock as RedisLock
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheLock
class RedisCacheLock(CacheLock):
"""Wraps ``redis.lock.Lock`` behind the ``CacheLock`` interface."""
def __init__(self, lock: RedisLock) -> None:
self._lock = lock
def acquire(
self,
blocking: bool = True,
blocking_timeout: float | None = None,
) -> bool:
return bool(
self._lock.acquire(
blocking=blocking,
blocking_timeout=blocking_timeout,
)
)
def release(self) -> None:
self._lock.release()
def owned(self) -> bool:
return bool(self._lock.owned())
class RedisCacheBackend(CacheBackend):
"""``CacheBackend`` implementation that delegates to a ``redis.Redis`` client.
This is a thin pass-through — every method maps 1-to-1 to the underlying
Redis command. ``TenantRedis`` key-prefixing is handled by the client
itself (provided by ``get_redis_client``).
"""
def __init__(self, redis_client: Redis) -> None:
self._r = redis_client
# -- basic key/value ---------------------------------------------------
def get(self, key: str) -> bytes | None:
val = self._r.get(key)
if val is None:
return None
if isinstance(val, bytes):
return val
return str(val).encode()
def set(
self,
key: str,
value: str | bytes | int | float,
ex: int | None = None,
) -> None:
self._r.set(key, value, ex=ex)
def delete(self, key: str) -> None:
self._r.delete(key)
def exists(self, key: str) -> bool:
return bool(self._r.exists(key))
# -- TTL ---------------------------------------------------------------
def expire(self, key: str, seconds: int) -> None:
self._r.expire(key, seconds)
def ttl(self, key: str) -> int:
return cast(int, self._r.ttl(key))
# -- distributed lock --------------------------------------------------
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
return RedisCacheLock(self._r.lock(name, timeout=timeout))
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
def rpush(self, key: str, value: str | bytes) -> None:
self._r.rpush(key, value)
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
result = cast(list[bytes] | None, self._r.blpop(keys, timeout=timeout))
if result is None:
return None
return (result[0], result[1])

View File

@@ -1,52 +1,57 @@
from uuid import UUID
from onyx.cache.interface import CacheBackend
from redis.client import Redis
# Redis key prefixes for chat message processing
PREFIX = "chatprocessing"
FENCE_PREFIX = f"{PREFIX}_fence"
FENCE_TTL = 30 * 60 # 30 minutes
def _get_fence_key(chat_session_id: UUID) -> str:
"""Generate the cache key for a chat session processing fence.
"""
Generate the Redis key for a chat session processing a message.
Args:
chat_session_id: The UUID of the chat session
Returns:
The fence key string. Tenant isolation is handled automatically
by the cache backend (Redis key-prefixing or Postgres schema routing).
The fence key string (tenant_id is automatically added by the Redis client)
"""
return f"{FENCE_PREFIX}_{chat_session_id}"
def set_processing_status(
chat_session_id: UUID, cache: CacheBackend, value: bool
chat_session_id: UUID, redis_client: Redis, value: bool
) -> None:
"""Set or clear the fence for a chat session processing a message.
"""
Set or clear the fence for a chat session processing a message.
If the key exists, a message is being processed.
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
Args:
chat_session_id: The UUID of the chat session
cache: Tenant-aware cache backend
redis_client: The Redis client to use
value: True to set the fence, False to clear it
"""
fence_key = _get_fence_key(chat_session_id)
if value:
cache.set(fence_key, 0, ex=FENCE_TTL)
redis_client.set(fence_key, 0, ex=FENCE_TTL)
else:
cache.delete(fence_key)
redis_client.delete(fence_key)
def is_chat_session_processing(chat_session_id: UUID, cache: CacheBackend) -> bool:
"""Check if the chat session is processing a message.
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
"""
Check if the chat session is processing a message.
Args:
chat_session_id: The UUID of the chat session
cache: Tenant-aware cache backend
redis_client: The Redis client to use
Returns:
True if the chat session is processing a message, False otherwise
"""
return cache.exists(_get_fence_key(chat_session_id))
fence_key = _get_fence_key(chat_session_id)
return bool(redis_client.exists(fence_key))

View File

@@ -11,10 +11,9 @@ from contextvars import Token
from uuid import UUID
from pydantic import BaseModel
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.cache.factory import get_cache_backend
from onyx.cache.interface import CacheBackend
from onyx.chat.chat_processing_checker import set_processing_status
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_state import run_chat_loop_with_state_containers
@@ -80,6 +79,7 @@ from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.redis.redis_pool import get_redis_client
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
@@ -448,7 +448,7 @@ def handle_stream_message_objects(
llm: LLM | None = None
chat_session: ChatSession | None = None
cache: CacheBackend | None = None
redis_client: Redis | None = None
user_id = user.id
if user.is_anonymous:
@@ -809,19 +809,19 @@ def handle_stream_message_objects(
)
simple_chat_history.insert(0, summary_simple)
cache = get_cache_backend()
redis_client = get_redis_client()
reset_cancel_status(
chat_session.id,
cache,
redis_client,
)
def check_is_connected() -> bool:
return check_stop_signal(chat_session.id, cache)
return check_stop_signal(chat_session.id, redis_client)
set_processing_status(
chat_session_id=chat_session.id,
cache=cache,
redis_client=redis_client,
value=True,
)
@@ -968,10 +968,10 @@ def handle_stream_message_objects(
reset_llm_mock_response(mock_response_token)
try:
if cache is not None and chat_session is not None:
if redis_client is not None and chat_session is not None:
set_processing_status(
chat_session_id=chat_session.id,
cache=cache,
redis_client=redis_client,
value=False,
)
except Exception:

View File

@@ -1,58 +1,65 @@
from uuid import UUID
from onyx.cache.interface import CacheBackend
from redis.client import Redis
# Redis key prefixes for chat session stop signals
PREFIX = "chatsessionstop"
FENCE_PREFIX = f"{PREFIX}_fence"
FENCE_TTL = 10 * 60 # 10 minutes
FENCE_TTL = 10 * 60 # 10 minutes - defensive TTL to prevent memory leaks
def _get_fence_key(chat_session_id: UUID) -> str:
"""Generate the cache key for a chat session stop signal fence.
"""
Generate the Redis key for a chat session stop signal fence.
Args:
chat_session_id: The UUID of the chat session
Returns:
The fence key string. Tenant isolation is handled automatically
by the cache backend (Redis key-prefixing or Postgres schema routing).
The fence key string (tenant_id is automatically added by the Redis client)
"""
return f"{FENCE_PREFIX}_{chat_session_id}"
def set_fence(chat_session_id: UUID, cache: CacheBackend, value: bool) -> None:
"""Set or clear the stop signal fence for a chat session.
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
"""
Set or clear the stop signal fence for a chat session.
Args:
chat_session_id: The UUID of the chat session
cache: Tenant-aware cache backend
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
value: True to set the fence (stop signal), False to clear it
"""
fence_key = _get_fence_key(chat_session_id)
if not value:
cache.delete(fence_key)
redis_client.delete(fence_key)
return
cache.set(fence_key, 0, ex=FENCE_TTL)
redis_client.set(fence_key, 0, ex=FENCE_TTL)
def is_connected(chat_session_id: UUID, cache: CacheBackend) -> bool:
"""Check if the chat session should continue (not stopped).
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
"""
Check if the chat session should continue (not stopped).
Args:
chat_session_id: The UUID of the chat session to check
cache: Tenant-aware cache backend
redis_client: Redis client to use for checking the stop signal (tenant-aware client that auto-prefixes keys)
Returns:
True if the session should continue, False if it should stop
"""
return not cache.exists(_get_fence_key(chat_session_id))
fence_key = _get_fence_key(chat_session_id)
return not bool(redis_client.exists(fence_key))
def reset_cancel_status(chat_session_id: UUID, cache: CacheBackend) -> None:
"""Clear the stop signal for a chat session.
def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
"""
Clear the stop signal for a chat session.
Args:
chat_session_id: The UUID of the chat session
cache: Tenant-aware cache backend
redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
"""
cache.delete(_get_fence_key(chat_session_id))
fence_key = _get_fence_key(chat_session_id)
redis_client.delete(fence_key)

View File

@@ -6,7 +6,6 @@ from datetime import timezone
from typing import cast
from onyx.auth.schemas import AuthBackend
from onyx.cache.interface import CacheBackendType
from onyx.configs.constants import AuthType
from onyx.configs.constants import QueryHistoryType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
@@ -55,12 +54,6 @@ DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() ==
# are disabled but core chat, tools, user file uploads, and Projects still work.
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
# Which backend to use for caching, locks, and ephemeral state.
# "redis" (default) or "postgres" (only valid when DISABLE_VECTOR_DB=true).
CACHE_BACKEND = CacheBackendType(
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
)
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
# Defaults to 100k tokens (or 10M when vector DB is disabled).
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000

View File

@@ -32,6 +32,8 @@ class GongConnector(LoadConnector, PollConnector):
BASE_URL = "https://api.gong.io"
MAX_CALL_DETAILS_ATTEMPTS = 6
CALL_DETAILS_DELAY = 30 # in seconds
# Gong API limit is 3 calls/sec — stay safely under it
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
def __init__(
self,
@@ -45,9 +47,13 @@ class GongConnector(LoadConnector, PollConnector):
self.continue_on_fail = continue_on_fail
self.auth_token_basic: str | None = None
self.hide_user_info = hide_user_info
self._last_request_time: float = 0.0
# urllib3 Retry already respects the Retry-After header by default
# (respect_retry_after_header=True), so on 429 it will sleep for the
# duration Gong specifies before retrying.
retry_strategy = Retry(
total=5,
total=10,
backoff_factor=2,
status_forcelist=[429, 500, 502, 503, 504],
)
@@ -61,8 +67,24 @@ class GongConnector(LoadConnector, PollConnector):
url = f"{GongConnector.BASE_URL}{endpoint}"
return url
def _throttled_request(
self, method: str, url: str, **kwargs: Any
) -> requests.Response:
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
now = time.monotonic()
elapsed = now - self._last_request_time
if elapsed < self.MIN_REQUEST_INTERVAL:
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
response = self._session.request(method, url, **kwargs)
self._last_request_time = time.monotonic()
return response
def _get_workspace_id_map(self) -> dict[str, str]:
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
response = self._throttled_request(
"GET", GongConnector.make_url("/v2/workspaces")
)
response.raise_for_status()
workspaces_details = response.json().get("workspaces")
@@ -106,8 +128,8 @@ class GongConnector(LoadConnector, PollConnector):
del body["filter"]["workspaceId"]
while True:
response = self._session.post(
GongConnector.make_url("/v2/calls/transcript"), json=body
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
)
# If no calls in the range, just break out
if response.status_code == 404:
@@ -142,8 +164,8 @@ class GongConnector(LoadConnector, PollConnector):
"contentSelector": {"exposedFields": {"parties": True}},
}
response = self._session.post(
GongConnector.make_url("/v2/calls/extensive"), json=body
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
)
response.raise_for_status()
@@ -194,7 +216,8 @@ class GongConnector(LoadConnector, PollConnector):
# There's a likely race condition in the API where a transcript will have a
# call id but the call to v2/calls/extensive will not return all of the id's
# retry with exponential backoff has been observed to mitigate this
# in ~2 minutes
# in ~2 minutes. After max attempts, proceed with whatever we have —
# the per-call loop below will skip missing IDs gracefully.
current_attempt = 0
while True:
current_attempt += 1
@@ -213,11 +236,14 @@ class GongConnector(LoadConnector, PollConnector):
f"missing_call_ids={missing_call_ids}"
)
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
raise RuntimeError(
f"Attempt count exceeded for _get_call_details_by_ids: "
f"missing_call_ids={missing_call_ids} "
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
logger.error(
f"Giving up on missing call id's after "
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
f"missing_call_ids={missing_call_ids}"
f"proceeding with {len(call_details_map)} of "
f"{len(transcript_call_ids)} calls"
)
break
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
logger.warning(

View File

@@ -202,7 +202,6 @@ def create_default_image_gen_config_from_api_key(
api_key=api_key,
api_base=None,
api_version=None,
default_model_name=model_name,
deployment_name=None,
is_public=True,
)

View File

@@ -213,11 +213,29 @@ def upsert_llm_provider(
llm_provider_upsert_request: LLMProviderUpsertRequest,
db_session: Session,
) -> LLMProviderView:
existing_llm_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
existing_llm_provider: LLMProviderModel | None = None
if llm_provider_upsert_request.id:
existing_llm_provider = fetch_existing_llm_provider_by_id(
id=llm_provider_upsert_request.id, db_session=db_session
)
if not existing_llm_provider:
raise ValueError(
f"LLM provider with id {llm_provider_upsert_request.id} not found"
)
if not existing_llm_provider:
if existing_llm_provider.name != llm_provider_upsert_request.name:
raise ValueError(
f"LLM provider with id {llm_provider_upsert_request.id} name change not allowed"
)
else:
existing_llm_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if existing_llm_provider:
raise ValueError(
f"LLM provider with name '{llm_provider_upsert_request.name}'"
" already exists"
)
existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name)
db_session.add(existing_llm_provider)
@@ -238,11 +256,7 @@ def upsert_llm_provider(
existing_llm_provider.api_base = api_base
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
existing_llm_provider.custom_config = custom_config
# TODO: Remove default model name on api change
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
existing_llm_provider.default_model_name = (
llm_provider_upsert_request.default_model_name
)
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
@@ -306,15 +320,6 @@ def upsert_llm_provider(
display_name=model_config.display_name,
)
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
_update_default_model(
db_session=db_session,
provider_id=existing_llm_provider.id,
model=existing_llm_provider.default_model_name,
flow_type=LLMModelFlowType.CHAT,
)
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
llm_provider_id=existing_llm_provider.id,
@@ -488,6 +493,22 @@ def fetch_existing_llm_provider(
return provider_model
def fetch_existing_llm_provider_by_id(
id: int, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.id == id)
.options(
selectinload(LLMProviderModel.model_configurations),
selectinload(LLMProviderModel.groups),
selectinload(LLMProviderModel.personas),
)
)
return provider_model
def fetch_embedding_provider(
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProviderModel | None:
@@ -604,22 +625,13 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
db_session.flush()
def update_default_provider(provider_id: int, db_session: Session) -> None:
# Attempt to get the default_model_name from the provider first
# TODO: Remove default_model_name check
provider = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.id == provider_id,
)
)
if provider is None:
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
def update_default_provider(
provider_id: int, model_name: str, db_session: Session
) -> None:
_update_default_model(
db_session,
provider_id,
provider.default_model_name, # type: ignore[arg-type]
model_name,
LLMModelFlowType.CHAT,
)
@@ -805,12 +817,6 @@ def sync_auto_mode_models(
)
changes += 1
# In Auto mode, default model is always set from GitHub config
default_model = llm_recommendations.get_default_model(provider.provider)
if default_model and provider.default_model_name != default_model.name:
provider.default_model_name = default_model.name
changes += 1
db_session.commit()
return changes

View File

@@ -5000,25 +5000,3 @@ class CodeInterpreterServer(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
class CacheStore(Base):
"""Key-value cache table used by ``PostgresCacheBackend``.
Replaces Redis for simple KV caching, locks, and list operations
when ``CACHE_BACKEND=postgres`` (NO_VECTOR_DB deployments).
Intentionally separate from ``KVStore``:
- Stores raw bytes (LargeBinary) vs JSONB, matching Redis semantics.
- Has ``expires_at`` for TTL; rows are periodically garbage-collected.
- Holds ephemeral data (tokens, stop signals, lock state) not
persistent application config, so cleanup can be aggressive.
"""
__tablename__ = "cache_store"
key: Mapped[str] = mapped_column(String, primary_key=True)
value: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
expires_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)

View File

@@ -1,7 +1,6 @@
import datetime
import uuid
from typing import List
from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import HTTPException
@@ -11,10 +10,7 @@ from pydantic import ConfigDict
from sqlalchemy import func
from sqlalchemy.orm import Session
if TYPE_CHECKING:
from starlette.background import BackgroundTasks
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -109,8 +105,8 @@ def upload_files_to_user_files_with_indexing(
user: User,
temp_id_map: dict[str, str] | None,
db_session: Session,
background_tasks: BackgroundTasks | None = None,
) -> CategorizedFilesResult:
# Validate project ownership if a project_id is provided
if project_id is not None and user is not None:
if not check_project_ownership(project_id, user.id, db_session):
raise HTTPException(status_code=404, detail="Project not found")
@@ -131,27 +127,16 @@ def upload_files_to_user_files_with_indexing(
logger.warning(
f"File {rejected_file.filename} rejected for {rejected_file.reason}"
)
if DISABLE_VECTOR_DB and background_tasks is not None:
from onyx.background.task_utils import drain_processing_loop
background_tasks.add_task(drain_processing_loop, tenant_id)
for user_file in user_files:
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
else:
from onyx.background.celery.versioned_apps.client import app as client_app
for user_file in user_files:
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
logger.info(
f"Triggered indexing for user_file_id={user_file.id} "
f"with task_id={task.id}"
)
for user_file in user_files:
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
logger.info(
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
)
return CategorizedFilesResult(
user_files=user_files,

View File

@@ -4,33 +4,39 @@ import base64
import json
import uuid
from typing import Any
from typing import cast
from typing import Dict
from typing import Optional
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Redis key prefix for OAuth state
OAUTH_STATE_PREFIX = "federated_oauth"
OAUTH_STATE_TTL = 300 # 5 minutes
# Default TTL for OAuth state (5 minutes)
OAUTH_STATE_TTL = 300
class OAuthSession:
"""Represents an OAuth session stored in the cache backend."""
"""Represents an OAuth session stored in Redis."""
def __init__(
self,
federated_connector_id: int,
user_id: str,
redirect_uri: str | None = None,
additional_data: dict[str, Any] | None = None,
redirect_uri: Optional[str] = None,
additional_data: Optional[Dict[str, Any]] = None,
):
self.federated_connector_id = federated_connector_id
self.user_id = user_id
self.redirect_uri = redirect_uri
self.additional_data = additional_data or {}
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for Redis storage."""
return {
"federated_connector_id": self.federated_connector_id,
"user_id": self.user_id,
@@ -39,7 +45,8 @@ class OAuthSession:
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "OAuthSession":
def from_dict(cls, data: Dict[str, Any]) -> "OAuthSession":
"""Create from dictionary retrieved from Redis."""
return cls(
federated_connector_id=data["federated_connector_id"],
user_id=data["user_id"],
@@ -51,27 +58,31 @@ class OAuthSession:
def generate_oauth_state(
federated_connector_id: int,
user_id: str,
redirect_uri: str | None = None,
additional_data: dict[str, Any] | None = None,
redirect_uri: Optional[str] = None,
additional_data: Optional[Dict[str, Any]] = None,
ttl: int = OAUTH_STATE_TTL,
) -> str:
"""
Generate a secure state parameter and store session data in the cache backend.
Generate a secure state parameter and store session data in Redis.
Args:
federated_connector_id: ID of the federated connector
user_id: ID of the user initiating OAuth
redirect_uri: Optional redirect URI after OAuth completion
additional_data: Any additional data to store with the session
ttl: Time-to-live in seconds for the cache key
ttl: Time-to-live in seconds for the Redis key
Returns:
Base64-encoded state parameter
"""
# Generate a random UUID for the state
state_uuid = uuid.uuid4()
state_b64 = base64.urlsafe_b64encode(state_uuid.bytes).decode("utf-8").rstrip("=")
# Convert UUID to base64 for URL-safe state parameter
state_bytes = state_uuid.bytes
state_b64 = base64.urlsafe_b64encode(state_bytes).decode("utf-8").rstrip("=")
# Create session object
session = OAuthSession(
federated_connector_id=federated_connector_id,
user_id=user_id,
@@ -79,9 +90,15 @@ def generate_oauth_state(
additional_data=additional_data,
)
cache = get_cache_backend()
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
cache.set(cache_key, json.dumps(session.to_dict()), ex=ttl)
# Store in Redis with TTL
redis_client = get_redis_client()
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
redis_client.set(
redis_key,
json.dumps(session.to_dict()),
ex=ttl,
)
logger.info(
f"Generated OAuth state for federated_connector_id={federated_connector_id}, "
@@ -108,15 +125,18 @@ def verify_oauth_state(state: str) -> OAuthSession:
state_bytes = base64.urlsafe_b64decode(padded_state)
state_uuid = uuid.UUID(bytes=state_bytes)
cache = get_cache_backend()
cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
# Look up in Redis
redis_client = get_redis_client()
redis_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}"
session_data = cache.get(cache_key)
session_data = cast(bytes, redis_client.get(redis_key))
if not session_data:
raise ValueError(f"OAuth state not found: {state}")
raise ValueError(f"OAuth state not found in Redis: {state}")
cache.delete(cache_key)
# Delete the key after retrieval (one-time use)
redis_client.delete(redis_key)
# Parse and return session
session_dict = json.loads(session_data)
return OAuthSession.from_dict(session_dict)

View File

@@ -1,11 +1,13 @@
import json
from typing import cast
from onyx.cache.interface import CacheBackend
from redis.client import Redis
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import KVStore
from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
@@ -18,27 +20,22 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self, cache: CacheBackend | None = None) -> None:
self._cache = cache
def _get_cache(self) -> CacheBackend:
if self._cache is None:
from onyx.cache.factory import get_cache_backend
self._cache = get_cache_backend()
return self._cache
def __init__(self, redis_client: Redis | None = None) -> None:
# If no redis_client is provided, fall back to the context var
if redis_client is not None:
self.redis_client = redis_client
else:
self.redis_client = get_redis_client()
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# Not encrypted in Cache backend (typically Redis), but encrypted in Postgres
# Not encrypted in Redis, but encrypted in Postgres
try:
self._get_cache().set(
self.redis_client.set(
REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
)
except Exception as e:
# Fallback gracefully to Postgres if Cache backend fails
logger.error(
f"Failed to set value in Cache backend for key '{key}': {str(e)}"
)
# Fallback gracefully to Postgres if Redis fails
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
@@ -56,12 +53,16 @@ class PgRedisKVStore(KeyValueStore):
def load(self, key: str, refresh_cache: bool = False) -> JSON_ro:
if not refresh_cache:
try:
cached = self._get_cache().get(REDIS_KEY_PREFIX + key)
if cached is not None:
return json.loads(cached.decode("utf-8"))
redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
if redis_value:
if not isinstance(redis_value, bytes):
raise ValueError(
f"Redis value for key '{key}' is not a bytes object"
)
return json.loads(redis_value.decode("utf-8"))
except Exception as e:
logger.error(
f"Failed to get value from cache for key '{key}': {str(e)}"
f"Failed to get value from Redis for key '{key}': {str(e)}"
)
with get_session_with_current_tenant() as db_session:
@@ -78,21 +79,21 @@ class PgRedisKVStore(KeyValueStore):
value = None
try:
self._get_cache().set(
self.redis_client.set(
REDIS_KEY_PREFIX + key,
json.dumps(value),
ex=KV_REDIS_KEY_EXPIRATION,
)
except Exception as e:
logger.error(f"Failed to set value in cache for key '{key}': {str(e)}")
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
return cast(JSON_ro, value)
def delete(self, key: str) -> None:
try:
self._get_cache().delete(REDIS_KEY_PREFIX + key)
self.redis_client.delete(REDIS_KEY_PREFIX + key)
except Exception as e:
logger.error(f"Failed to delete value from cache for key '{key}': {str(e)}")
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
with get_session_with_current_tenant() as db_session:
result = db_session.query(KVStore).filter_by(key=key).delete()

View File

@@ -13,38 +13,44 @@ from datetime import datetime
import httpx
from sqlalchemy.orm import Session
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.db.llm import fetch_auto_mode_providers
from onyx.db.llm import sync_auto_mode_models
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
_CACHE_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
_CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours
# Redis key for caching the last updated timestamp (per-tenant)
_REDIS_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at"
def _get_cached_last_updated_at() -> datetime | None:
"""Get the cached last_updated_at timestamp from Redis."""
try:
value = get_cache_backend().get(_CACHE_KEY_LAST_UPDATED_AT)
if value is not None:
redis_client = get_redis_client()
value = redis_client.get(_REDIS_KEY_LAST_UPDATED_AT)
if value and isinstance(value, bytes):
# Value is bytes, decode to string then parse as ISO format
return datetime.fromisoformat(value.decode("utf-8"))
except Exception as e:
logger.warning(f"Failed to get cached last_updated_at: {e}")
logger.warning(f"Failed to get cached last_updated_at from Redis: {e}")
return None
def _set_cached_last_updated_at(updated_at: datetime) -> None:
"""Set the cached last_updated_at timestamp in Redis."""
try:
get_cache_backend().set(
_CACHE_KEY_LAST_UPDATED_AT,
redis_client = get_redis_client()
# Store as ISO format string, with 24 hour expiration
redis_client.set(
_REDIS_KEY_LAST_UPDATED_AT,
updated_at.isoformat(),
ex=_CACHE_TTL_SECONDS,
ex=60 * 60 * 24, # 24 hours
)
except Exception as e:
logger.warning(f"Failed to set cached last_updated_at: {e}")
logger.warning(f"Failed to set cached last_updated_at in Redis: {e}")
def fetch_llm_recommendations_from_github(
@@ -142,8 +148,9 @@ def sync_llm_models_from_github(
def reset_cache() -> None:
"""Reset the cache timestamp. Useful for testing."""
"""Reset the cache timestamp in Redis. Useful for testing."""
try:
get_cache_backend().delete(_CACHE_KEY_LAST_UPDATED_AT)
redis_client = get_redis_client()
redis_client.delete(_REDIS_KEY_LAST_UPDATED_AT)
except Exception as e:
logger.warning(f"Failed to reset cache: {e}")
logger.warning(f"Failed to reset cache in Redis: {e}")

View File

@@ -32,13 +32,11 @@ from onyx.auth.schemas import UserUpdate
from onyx.auth.users import auth_backend
from onyx.auth.users import create_onyx_oauth_router
from onyx.auth.users import fastapi_users
from onyx.cache.interface import CacheBackendType
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.configs.app_configs import APP_HOST
from onyx.configs.app_configs import APP_PORT
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import CACHE_BACKEND
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
from onyx.configs.app_configs import OAUTH_CLIENT_ID
@@ -257,20 +255,6 @@ def include_auth_router_with_prefix(
)
def validate_cache_backend_settings() -> None:
"""Validate that CACHE_BACKEND=postgres is only used with DISABLE_VECTOR_DB.
The Postgres cache backend eliminates the Redis dependency, but only works
when Celery is not running (which requires DISABLE_VECTOR_DB=true).
"""
if CACHE_BACKEND == CacheBackendType.POSTGRES and not DISABLE_VECTOR_DB:
raise RuntimeError(
"CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true. "
"The Postgres cache backend is only supported in no-vector-DB "
"deployments where Celery is replaced by the in-process task runner."
)
def validate_no_vector_db_settings() -> None:
"""Validate that DISABLE_VECTOR_DB is not combined with incompatible settings.
@@ -302,7 +286,6 @@ def validate_no_vector_db_settings() -> None:
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
validate_no_vector_db_settings()
validate_cache_backend_settings()
# Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None:
@@ -372,20 +355,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
if AUTH_RATE_LIMITING_ENABLED:
await setup_auth_limiter()
if DISABLE_VECTOR_DB:
from onyx.background.periodic_poller import recover_stuck_user_files
from onyx.background.periodic_poller import start_periodic_poller
recover_stuck_user_files(POSTGRES_DEFAULT_SCHEMA)
start_periodic_poller(POSTGRES_DEFAULT_SCHEMA)
yield
if DISABLE_VECTOR_DB:
from onyx.background.periodic_poller import stop_periodic_poller
stop_periodic_poller()
SqlEngine.reset_engine()
if AUTH_RATE_LIMITING_ENABLED:

View File

@@ -2,7 +2,6 @@ import json
from uuid import UUID
from fastapi import APIRouter
from fastapi import BackgroundTasks
from fastapi import Depends
from fastapi import File
from fastapi import Form
@@ -13,7 +12,13 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.background.celery.tasks.user_file_processing.tasks import (
enqueue_user_file_project_sync_task,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
get_user_file_project_sync_queue_depth,
)
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -29,6 +34,7 @@ from onyx.db.models import UserProject
from onyx.db.persona import get_personas_by_ids
from onyx.db.projects import get_project_token_count
from onyx.db.projects import upload_files_to_user_files_with_indexing
from onyx.redis.redis_pool import get_redis_client
from onyx.server.features.projects.models import CategorizedFilesSnapshot
from onyx.server.features.projects.models import ChatSessionRequest
from onyx.server.features.projects.models import TokenCountResponse
@@ -49,27 +55,7 @@ class UserFileDeleteResult(BaseModel):
assistant_names: list[str] = []
def _trigger_user_file_project_sync(
user_file_id: UUID,
tenant_id: str,
background_tasks: BackgroundTasks | None = None,
) -> None:
if DISABLE_VECTOR_DB and background_tasks is not None:
from onyx.background.task_utils import drain_project_sync_loop
background_tasks.add_task(drain_project_sync_loop, tenant_id)
logger.info(f"Queued in-process project sync for user_file_id={user_file_id}")
return
from onyx.background.celery.tasks.user_file_processing.tasks import (
enqueue_user_file_project_sync_task,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
get_user_file_project_sync_queue_depth,
)
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.redis.redis_pool import get_redis_client
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
queue_depth = get_user_file_project_sync_queue_depth(client_app)
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
logger.warning(
@@ -125,7 +111,6 @@ def create_project(
@router.post("/file/upload", tags=PUBLIC_API_TAGS)
def upload_user_files(
bg_tasks: BackgroundTasks,
files: list[UploadFile] = File(...),
project_id: int | None = Form(None),
temp_id_map: str | None = Form(None), # JSON string mapping hashed key -> temp_id
@@ -152,12 +137,12 @@ def upload_user_files(
user=user,
temp_id_map=parsed_temp_id_map,
db_session=db_session,
background_tasks=bg_tasks if DISABLE_VECTOR_DB else None,
)
return CategorizedFilesSnapshot.from_result(categorized_files_result)
except Exception as e:
# Log error with type, message, and stack for easier debugging
logger.exception(f"Error uploading files - {type(e).__name__}: {str(e)}")
raise HTTPException(
status_code=500,
@@ -207,7 +192,6 @@ def get_files_in_project(
def unlink_user_file_from_project(
project_id: int,
file_id: UUID,
bg_tasks: BackgroundTasks,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Response:
@@ -224,6 +208,7 @@ def unlink_user_file_from_project(
if project is None:
raise HTTPException(status_code=404, detail="Project not found")
user_id = user.id
user_file = (
db_session.query(UserFile)
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
@@ -239,7 +224,7 @@ def unlink_user_file_from_project(
db_session.commit()
tenant_id = get_current_tenant_id()
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
_trigger_user_file_project_sync(user_file.id, tenant_id)
return Response(status_code=204)
@@ -252,7 +237,6 @@ def unlink_user_file_from_project(
def link_user_file_to_project(
project_id: int,
file_id: UUID,
bg_tasks: BackgroundTasks,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> UserFileSnapshot:
@@ -284,7 +268,7 @@ def link_user_file_to_project(
db_session.commit()
tenant_id = get_current_tenant_id()
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
_trigger_user_file_project_sync(user_file.id, tenant_id)
return UserFileSnapshot.from_model(user_file)
@@ -442,7 +426,6 @@ def delete_project(
@router.delete("/file/{file_id}", tags=PUBLIC_API_TAGS)
def delete_user_file(
file_id: UUID,
bg_tasks: BackgroundTasks,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> UserFileDeleteResult:
@@ -475,25 +458,15 @@ def delete_user_file(
db_session.commit()
tenant_id = get_current_tenant_id()
if DISABLE_VECTOR_DB:
from onyx.background.task_utils import drain_delete_loop
bg_tasks.add_task(drain_delete_loop, tenant_id)
logger.info(f"Queued in-process delete for user_file_id={user_file.id}")
else:
from onyx.background.celery.versioned_apps.client import app as client_app
task = client_app.send_task(
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_DELETE,
priority=OnyxCeleryPriority.HIGH,
)
logger.info(
f"Triggered delete for user_file_id={user_file.id} "
f"with task_id={task.id}"
)
task = client_app.send_task(
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_DELETE,
priority=OnyxCeleryPriority.HIGH,
)
logger.info(
f"Triggered delete for user_file_id={user_file.id} with task_id={task.id}"
)
return UserFileDeleteResult(
has_associations=False, project_names=[], assistant_names=[]
)

View File

@@ -8,10 +8,10 @@ import httpx
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.cache.factory import get_shared_cache_backend
from onyx.configs.app_configs import INSTANCE_TYPE
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.release_notes import create_release_notifications_for_versions
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.server.features.release_notes.constants import AUTO_REFRESH_THRESHOLD_SECONDS
from onyx.server.features.release_notes.constants import FETCH_TIMEOUT
from onyx.server.features.release_notes.constants import GITHUB_CHANGELOG_RAW_URL
@@ -113,46 +113,60 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
def get_cached_etag() -> str | None:
cache = get_shared_cache_backend()
"""Get the cached GitHub ETag from Redis."""
redis_client = get_shared_redis_client()
try:
etag = cache.get(REDIS_KEY_ETAG)
etag = redis_client.get(REDIS_KEY_ETAG)
if etag:
return etag.decode("utf-8")
return etag.decode("utf-8") if isinstance(etag, bytes) else str(etag)
return None
except Exception as e:
logger.error(f"Failed to get cached etag: {e}")
logger.error(f"Failed to get cached etag from Redis: {e}")
return None
def get_last_fetch_time() -> datetime | None:
cache = get_shared_cache_backend()
"""Get the last fetch timestamp from Redis."""
redis_client = get_shared_redis_client()
try:
raw = cache.get(REDIS_KEY_FETCHED_AT)
if not raw:
fetched_at_str = redis_client.get(REDIS_KEY_FETCHED_AT)
if not fetched_at_str:
return None
last_fetch = datetime.fromisoformat(raw.decode("utf-8"))
decoded = (
fetched_at_str.decode("utf-8")
if isinstance(fetched_at_str, bytes)
else str(fetched_at_str)
)
last_fetch = datetime.fromisoformat(decoded)
# Defensively ensure timezone awareness
# fromisoformat() returns naive datetime if input lacks timezone
if last_fetch.tzinfo is None:
# Assume UTC for naive datetimes
last_fetch = last_fetch.replace(tzinfo=timezone.utc)
else:
# Convert to UTC if timezone-aware
last_fetch = last_fetch.astimezone(timezone.utc)
return last_fetch
except Exception as e:
logger.error(f"Failed to get last fetch time from cache: {e}")
logger.error(f"Failed to get last fetch time from Redis: {e}")
return None
def save_fetch_metadata(etag: str | None) -> None:
cache = get_shared_cache_backend()
"""Save ETag and fetch timestamp to Redis."""
redis_client = get_shared_redis_client()
now = datetime.now(timezone.utc)
try:
cache.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
redis_client.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
if etag:
cache.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
redis_client.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
except Exception as e:
logger.error(f"Failed to save fetch metadata to cache: {e}")
logger.error(f"Failed to save fetch metadata to Redis: {e}")
def is_cache_stale() -> bool:
@@ -182,10 +196,11 @@ def ensure_release_notes_fresh_and_notify(db_session: Session) -> None:
if not is_cache_stale():
return
cache = get_shared_cache_backend()
lock = cache.lock(
# Acquire lock to prevent concurrent fetches
redis_client = get_shared_redis_client()
lock = redis_client.lock(
OnyxRedisLocks.RELEASE_NOTES_FETCH_LOCK,
timeout=90,
timeout=90, # 90 second timeout for the lock
)
# Non-blocking acquire - if we can't get the lock, another request is handling it

View File

@@ -97,7 +97,6 @@ def _build_llm_provider_request(
), # Only this from source
api_base=api_base, # From request
api_version=api_version, # From request
default_model_name=model_name,
deployment_name=deployment_name, # From request
is_public=True,
groups=[],
@@ -136,7 +135,6 @@ def _build_llm_provider_request(
api_key=api_key,
api_base=api_base,
api_version=api_version,
default_model_name=model_name,
deployment_name=deployment_name,
is_public=True,
groups=[],
@@ -168,7 +166,6 @@ def _create_image_gen_llm_provider__no_commit(
api_key=provider_request.api_key,
api_base=provider_request.api_base,
api_version=provider_request.api_version,
default_model_name=provider_request.default_model_name,
deployment_name=provider_request.deployment_name,
is_public=provider_request.is_public,
custom_config=provider_request.custom_config,

View File

@@ -22,7 +22,10 @@ from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_default_vision_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_provider_by_id
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_models
from onyx.db.llm import fetch_persona_with_groups
@@ -52,11 +55,12 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
@@ -233,12 +237,9 @@ def test_llm_configuration(
test_api_key = test_llm_request.api_key
test_custom_config = test_llm_request.custom_config
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider(
name=test_llm_request.name, db_session=db_session
if test_llm_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=test_llm_request.id, db_session=db_session
)
if existing_provider:
test_custom_config = _restore_masked_custom_config_values(
@@ -268,7 +269,7 @@ def test_llm_configuration(
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
model=test_llm_request.model,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
@@ -303,7 +304,7 @@ def list_llm_providers(
include_image_gen: bool = Query(False),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderView]:
) -> LLMProviderResponse[LLMProviderView]:
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch LLM providers")
@@ -328,7 +329,15 @@ def list_llm_providers(
duration = (end_time - start_time).total_seconds()
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
return llm_provider_list
return LLMProviderResponse[LLMProviderView].from_models(
providers=llm_provider_list,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
@admin_router.put("/provider")
@@ -344,18 +353,44 @@ def put_llm_provider(
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = fetch_existing_llm_provider(
existing_provider = None
if llm_provider_upsert_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=llm_provider_upsert_request.id, db_session=db_session
)
# Check name constraints
# TODO: Once port from name to id is complete, unique name will no longer be required
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
raise HTTPException(
status_code=400,
detail="Renaming providers is not currently supported",
)
found_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if found_provider is not None and found_provider is not existing_provider:
raise HTTPException(
status_code=400,
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} already exists"
),
)
elif not existing_provider and not is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} does not exist"
),
)
# SSRF Protection: Validate api_base and custom_config match stored values
@@ -393,22 +428,6 @@ def put_llm_provider(
deduplicated_personas.append(persona_id)
llm_provider_upsert_request.personas = deduplicated_personas
default_model_found = False
for model_configuration in llm_provider_upsert_request.model_configurations:
if model_configuration.name == llm_provider_upsert_request.default_model_name:
model_configuration.is_visible = True
default_model_found = True
# TODO: Remove this logic on api change
# Believed to be a dead pathway but we want to be safe for now
if not default_model_found:
llm_provider_upsert_request.model_configurations.append(
ModelConfigurationUpsertRequest(
name=llm_provider_upsert_request.default_model_name, is_visible=True
)
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider_upsert_request.api_key_changed:
@@ -438,8 +457,8 @@ def put_llm_provider(
config = fetch_llm_recommendations_from_github()
if config and llm_provider_upsert_request.provider in config.providers:
# Refetch the provider to get the updated model
updated_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
updated_provider = fetch_existing_llm_provider_by_id(
id=result.id, db_session=db_session
)
if updated_provider:
sync_auto_mode_models(
@@ -469,28 +488,29 @@ def delete_llm_provider(
raise HTTPException(status_code=404, detail=str(e))
@admin_router.post("/provider/{provider_id}/default")
@admin_router.post("/default")
def set_provider_as_default(
provider_id: int,
default_model_request: DefaultModel,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_provider(provider_id=provider_id, db_session=db_session)
update_default_provider(
provider_id=default_model_request.provider_id,
model_name=default_model_request.model_name,
db_session=db_session,
)
@admin_router.post("/provider/{provider_id}/default-vision")
@admin_router.post("/default-vision")
def set_provider_as_default_vision(
provider_id: int,
vision_model: str | None = Query(
None, description="The default vision model to use"
),
default_model: DefaultModel,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
if vision_model is None:
raise HTTPException(status_code=404, detail="Vision model not provided")
update_default_vision_provider(
provider_id=provider_id, vision_model=vision_model, db_session=db_session
provider_id=default_model.provider_id,
vision_model=default_model.model_name,
db_session=db_session,
)
@@ -516,7 +536,7 @@ def get_auto_config(
def get_vision_capable_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VisionProviderResponse]:
) -> LLMProviderResponse[VisionProviderResponse]:
"""Return a list of LLM providers and their models that support image input"""
vision_models = fetch_existing_models(
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
@@ -545,7 +565,13 @@ def get_vision_capable_providers(
]
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
return vision_provider_response
return LLMProviderResponse[VisionProviderResponse].from_models(
providers=vision_provider_response,
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
"""Endpoints for all"""
@@ -555,7 +581,7 @@ def get_vision_capable_providers(
def list_llm_provider_basics(
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
) -> LLMProviderResponse[LLMProviderDescriptor]:
"""Get LLM providers accessible to the current user.
Returns:
@@ -592,7 +618,15 @@ def list_llm_provider_basics(
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
)
return accessible_providers
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=accessible_providers,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
def get_valid_model_names_for_persona(
@@ -635,7 +669,7 @@ def list_llm_providers_for_persona(
persona_id: int,
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
) -> LLMProviderResponse[LLMProviderDescriptor]:
"""Get LLM providers for a specific persona.
Returns providers that the user can access when using this persona:
@@ -682,7 +716,51 @@ def list_llm_providers_for_persona(
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
)
return llm_provider_list
# Get the default model and vision model for the persona
# TODO: Port persona's over to use ID
persona_default_provider = persona.llm_model_provider_override
persona_default_model = persona.llm_model_version_override
default_text_model = fetch_default_llm_model(db_session)
default_vision_model = fetch_default_vision_model(db_session)
# Build default_text and default_vision using persona overrides when available,
# falling back to the global defaults.
default_text = DefaultModel.from_model_config(default_text_model)
default_vision = DefaultModel.from_model_config(default_vision_model)
if persona_default_provider:
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
if provider and can_user_access_llm_provider(
provider, user_group_ids, persona, is_admin=is_admin
):
if persona_default_model:
# Persona specifies both provider and model — use them directly
default_text = DefaultModel(
provider_id=provider.id,
model_name=persona_default_model,
)
else:
# Persona specifies only the provider — pick a visible (public) model,
# falling back to any model on this provider
visible_model = next(
(mc for mc in provider.model_configurations if mc.is_visible),
None,
)
fallback_model = visible_model or next(
iter(provider.model_configurations), None
)
if fallback_model:
default_text = DefaultModel(
provider_id=provider.id,
model_name=fallback_model.name,
)
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=llm_provider_list,
default_text=default_text,
default_vision=default_vision,
)
@admin_router.get("/provider-contextual-cost")

View File

@@ -1,5 +1,9 @@
from __future__ import annotations
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from pydantic import BaseModel
from pydantic import Field
@@ -21,50 +25,22 @@ if TYPE_CHECKING:
ModelConfiguration as ModelConfigurationModel,
)
# TODO: Clear this up on api refactor
# There is still logic that requires sending each providers default model name
# There is no logic that requires sending the providers default vision model name
# We only send for the one that is actually the default
def get_default_llm_model_name(llm_provider_model: "LLMProviderModel") -> str:
"""Find the default conversation model name for a provider.
Returns the model name if found, otherwise returns empty string.
"""
for model_config in llm_provider_model.model_configurations:
for flow in model_config.llm_model_flows:
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.CHAT:
return model_config.name
return ""
def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str | None:
"""Find the default vision model name for a provider.
Returns the model name if found, otherwise returns None.
"""
for model_config in llm_provider_model.model_configurations:
for flow in model_config.llm_model_flows:
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.VISION:
return model_config.name
return None
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView", "VisionProviderResponse")
class TestLLMRequest(BaseModel):
# provider level
name: str | None = None
id: int | None = None
provider: str
model: str
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
# model level
default_model_name: str
deployment_name: str | None = None
model_configurations: list["ModelConfigurationUpsertRequest"]
# if try and use the existing API/custom config key
api_key_changed: bool
custom_config_changed: bool
@@ -80,13 +56,10 @@ class LLMProviderDescriptor(BaseModel):
"""A descriptor for an LLM provider that can be safely viewed by
non-admin users. Used when giving a list of available LLMs."""
id: int
name: str
provider: str
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
default_model_name: str
is_default_provider: bool | None
is_default_vision_provider: bool | None
default_vision_model: str | None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -99,24 +72,12 @@ class LLMProviderDescriptor(BaseModel):
)
provider = llm_provider_model.provider
default_model_name = get_default_llm_model_name(llm_provider_model)
default_vision_model = get_default_vision_model_name(llm_provider_model)
is_default_provider = bool(default_model_name)
is_default_vision_provider = default_vision_model is not None
default_model_name = (
default_model_name or llm_provider_model.default_model_name or ""
)
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
provider=provider,
provider_display_name=get_provider_display_name(provider),
default_model_name=default_model_name,
is_default_provider=is_default_provider,
is_default_vision_provider=is_default_vision_provider,
default_vision_model=default_vision_model,
model_configurations=filter_model_configurations(
llm_provider_model.model_configurations, provider
),
@@ -130,18 +91,17 @@ class LLMProvider(BaseModel):
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
default_model_name: str
is_public: bool = True
is_auto_mode: bool = False
groups: list[int] = Field(default_factory=list)
personas: list[int] = Field(default_factory=list)
deployment_name: str | None = None
default_vision_model: str | None = None
class LLMProviderUpsertRequest(LLMProvider):
# should only be used for a "custom" provider
# for default providers, the built-in model names are used
id: int | None = None
api_key_changed: bool = False
custom_config_changed: bool = False
model_configurations: list["ModelConfigurationUpsertRequest"] = []
@@ -157,8 +117,6 @@ class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -180,16 +138,6 @@ class LLMProviderView(LLMProvider):
provider = llm_provider_model.provider
default_model_name = get_default_llm_model_name(llm_provider_model)
default_vision_model = get_default_vision_model_name(llm_provider_model)
is_default_provider = bool(default_model_name)
is_default_vision_provider = default_vision_model is not None
default_model_name = (
default_model_name or llm_provider_model.default_model_name or ""
)
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
@@ -202,10 +150,6 @@ class LLMProviderView(LLMProvider):
api_base=llm_provider_model.api_base,
api_version=llm_provider_model.api_version,
custom_config=llm_provider_model.custom_config,
default_model_name=default_model_name,
is_default_provider=is_default_provider,
is_default_vision_provider=is_default_vision_provider,
default_vision_model=default_vision_model,
is_public=llm_provider_model.is_public,
is_auto_mode=llm_provider_model.is_auto_mode,
groups=groups,
@@ -425,3 +369,38 @@ class OpenRouterFinalModelResponse(BaseModel):
int | None
) # From OpenRouter API context_length (may be missing for some models)
supports_image_input: bool
class DefaultModel(BaseModel):
provider_id: int
model_name: str
@classmethod
def from_model_config(
cls, model_config: ModelConfigurationModel | None
) -> DefaultModel | None:
if not model_config:
return None
return cls(
provider_id=model_config.llm_provider_id,
model_name=model_config.name,
)
class LLMProviderResponse(BaseModel, Generic[T]):
providers: list[T]
default_text: DefaultModel | None = None
default_vision: DefaultModel | None = None
@classmethod
def from_models(
cls,
providers: list[T],
default_text: DefaultModel | None = None,
default_vision: DefaultModel | None = None,
) -> LLMProviderResponse[T]:
return cls(
providers=providers,
default_text=default_text,
default_vision=default_vision,
)

View File

@@ -13,13 +13,13 @@ from fastapi import Request
from fastapi import Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.pat import get_hashed_pat_from_request
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_user
from onyx.cache.factory import get_cache_backend
from onyx.chat.chat_processing_checker import is_chat_session_processing
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import convert_chat_history_basic
@@ -67,6 +67,7 @@ from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.redis.redis_pool import get_redis_client
from onyx.secondary_llm_flows.chat_session_naming import generate_chat_session_name
from onyx.server.api_key_usage import check_api_key_usage
from onyx.server.query_and_chat.models import ChatFeedbackRequest
@@ -313,7 +314,7 @@ def get_chat_session(
]
try:
is_processing = is_chat_session_processing(session_id, get_cache_backend())
is_processing = is_chat_session_processing(session_id, get_redis_client())
# Edit the last message to indicate loading (Overriding default message value)
if is_processing and chat_message_details:
last_msg = chat_message_details[-1]
@@ -910,10 +911,11 @@ async def search_chats(
def stop_chat_session(
chat_session_id: UUID,
user: User = Depends(current_user), # noqa: ARG001
redis_client: Redis = Depends(get_redis_client),
) -> dict[str, str]:
"""
Stop a chat session by setting a stop signal.
Stop a chat session by setting a stop signal in Redis.
This endpoint is called by the frontend when the user clicks the stop button.
"""
set_fence(chat_session_id, get_cache_backend(), True)
set_fence(chat_session_id, redis_client, True)
return {"message": "Chat session stopped"}

View File

@@ -19,6 +19,7 @@ class ApplicationStatus(str, Enum):
PAYMENT_REMINDER = "payment_reminder"
GRACE_PERIOD = "grace_period"
GATED_ACCESS = "gated_access"
SEAT_LIMIT_EXCEEDED = "seat_limit_exceeded"
class Notification(BaseModel):
@@ -82,6 +83,10 @@ class Settings(BaseModel):
# Default Assistant settings
disable_default_assistant: bool | None = False
# Seat usage - populated by license enforcement when seat limit is exceeded
seat_count: int | None = None
used_seats: int | None = None
# OpenSearch migration
opensearch_indexing_enabled: bool = False

View File

@@ -1,4 +1,3 @@
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
@@ -7,8 +6,11 @@ from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -31,22 +33,30 @@ def load_settings() -> Settings:
logger.error(f"Error loading settings from KV store: {str(e)}")
settings = Settings()
cache = get_cache_backend()
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
redis_client = get_redis_client(tenant_id=tenant_id)
try:
value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
if value is not None:
assert isinstance(value, bytes)
anonymous_user_enabled = int(value.decode("utf-8")) == 1
else:
# Default to False
anonymous_user_enabled = False
cache.set(OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL)
# Optionally store the default back to Redis
redis_client.set(
OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL
)
except Exception as e:
logger.error(f"Error loading anonymous user setting from cache: {str(e)}")
# Log the error and reset to default
logger.error(f"Error loading anonymous user setting from Redis: {str(e)}")
anonymous_user_enabled = False
settings.anonymous_user_enabled = anonymous_user_enabled
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
# Override user knowledge setting if disabled via environment variable
if DISABLE_USER_KNOWLEDGE:
settings.user_knowledge_enabled = False
@@ -56,10 +66,11 @@ def load_settings() -> Settings:
def store_settings(settings: Settings) -> None:
cache = get_cache_backend()
tenant_id = get_current_tenant_id() if MULTI_TENANT else None
redis_client = get_redis_client(tenant_id=tenant_id)
if settings.anonymous_user_enabled is not None:
cache.set(
redis_client.set(
OnyxRedisLocks.ANONYMOUS_USER_ENABLED,
"1" if settings.anonymous_user_enabled else "0",
ex=SETTINGS_TTL,

View File

@@ -25,6 +25,7 @@ from onyx.db.enums import EmbeddingPrecision
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.search_settings import get_active_search_settings
@@ -254,14 +255,18 @@ def setup_postgres(db_session: Session) -> None:
logger.notice("Setting up default OpenAI LLM for dev.")
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
provider_name = "DevEnvPresetOpenAI"
existing = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
model_req = LLMProviderUpsertRequest(
name="DevEnvPresetOpenAI",
id=existing.id if existing else None,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=GEN_AI_API_KEY,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=llm_model,
is_public=True,
groups=[],
model_configurations=[
@@ -273,7 +278,9 @@ def setup_postgres(db_session: Session) -> None:
new_llm_provider = upsert_llm_provider(
llm_provider_upsert_request=model_req, db_session=db_session
)
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
update_default_provider(
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
)
def update_default_multipass_indexing(db_session: Session) -> None:

View File

@@ -257,7 +257,7 @@ exceptiongroup==1.3.0
# via
# braintrust
# fastmcp
fastapi==0.128.0
fastapi==0.133.1
# via
# fastapi-limiter
# fastapi-users
@@ -1155,6 +1155,7 @@ typing-inspect==0.9.0
# via dataclasses-json
typing-inspection==0.4.2
# via
# fastapi
# mcp
# pydantic
# pydantic-settings
@@ -1216,7 +1217,7 @@ websockets==15.0.1
# via
# fastmcp
# google-genai
werkzeug==3.1.5
werkzeug==3.1.6
# via sendgrid
wrapt==1.17.3
# via

View File

@@ -125,7 +125,7 @@ executing==2.2.1
# via stack-data
faker==40.1.2
# via onyx
fastapi==0.128.0
fastapi==0.133.1
# via
# onyx
# onyx-devtools
@@ -619,6 +619,7 @@ typing-extensions==4.15.0
# typing-inspection
typing-inspection==0.4.2
# via
# fastapi
# mcp
# pydantic
# pydantic-settings

View File

@@ -90,7 +90,7 @@ docstring-parser==0.17.0
# via google-cloud-aiplatform
durationpy==0.10
# via kubernetes
fastapi==0.128.0
fastapi==0.133.1
# via onyx
fastavro==1.12.1
# via cohere
@@ -398,6 +398,7 @@ typing-extensions==4.15.0
# typing-inspection
typing-inspection==0.4.2
# via
# fastapi
# mcp
# pydantic
# pydantic-settings

View File

@@ -108,7 +108,7 @@ durationpy==0.10
# via kubernetes
einops==0.8.1
# via onyx
fastapi==0.128.0
fastapi==0.133.1
# via
# onyx
# sentry-sdk
@@ -525,6 +525,7 @@ typing-extensions==4.15.0
# typing-inspection
typing-inspection==0.4.2
# via
# fastapi
# mcp
# pydantic
# pydantic-settings

View File

@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
# Prepare the test request payload
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"model": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
# Prepare the test request payload with invalid credentials
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"model": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,

View File

@@ -28,7 +28,6 @@ def ensure_default_llm_provider(db_session: Session) -> None:
provider=LlmProviderNames.OPENAI,
api_key=os.environ.get("OPENAI_API_KEY", "test"),
is_public=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini",
@@ -41,7 +40,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
llm_provider_upsert_request=llm_provider_request,
db_session=db_session,
)
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
# Rollback to clear the pending transaction state
db_session.rollback()

View File

@@ -47,7 +47,6 @@ def test_answer_with_only_anthropic_provider(
name=provider_name,
provider=LlmProviderNames.ANTHROPIC,
api_key=anthropic_api_key,
default_model_name=anthropic_model,
is_public=True,
groups=[],
model_configurations=[
@@ -59,7 +58,7 @@ def test_answer_with_only_anthropic_provider(
)
try:
update_default_provider(anthropic_provider.id, db_session)
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
test_user = create_test_user(db_session, email_prefix="anthropic_only")
chat_session = create_chat_session(

View File

@@ -1,219 +0,0 @@
"""External dependency unit tests for startup recovery (Step 10g).
Seeds ``UserFile`` records in stuck states (PROCESSING, DELETING,
needs_project_sync) then calls ``recover_stuck_user_files`` and verifies
the drain loops pick them up via ``FOR UPDATE SKIP LOCKED``.
Uses real PostgreSQL (via ``db_session`` / ``tenant_context`` fixtures).
The per-file ``_impl`` functions are mocked so no real file store or
connector is needed — we only verify that recovery finds and dispatches
the correct files.
"""
from collections.abc import Generator
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.background.periodic_poller import recover_stuck_user_files
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_IMPL_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
def _create_user_file(
db_session: Session,
user_id: object,
*,
status: UserFileStatus = UserFileStatus.PROCESSING,
needs_project_sync: bool = False,
needs_persona_sync: bool = False,
) -> UserFile:
uf = UserFile(
id=uuid4(),
user_id=user_id,
file_id=f"test_file_{uuid4().hex[:8]}",
name=f"test_{uuid4().hex[:8]}.txt",
file_type="text/plain",
status=status,
needs_project_sync=needs_project_sync,
needs_persona_sync=needs_persona_sync,
)
db_session.add(uf)
db_session.commit()
db_session.refresh(uf)
return uf
@pytest.fixture()
def _cleanup_user_files(db_session: Session) -> Generator[list[UserFile], None, None]:
"""Track created UserFile rows and delete them after each test."""
created: list[UserFile] = []
yield created
for uf in created:
existing = db_session.get(UserFile, uf.id)
if existing:
db_session.delete(existing)
db_session.commit()
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestRecoverProcessingFiles:
"""Files in PROCESSING status are re-processed via the processing drain loop."""
def test_processing_files_recovered(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "recovery_proc")
uf = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
_cleanup_user_files.append(uf)
mock_impl = MagicMock()
with patch(f"{_IMPL_MODULE}._process_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert (
str(uf.id) in called_ids
), f"Expected file {uf.id} to be recovered but got: {called_ids}"
def test_completed_files_not_recovered(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "recovery_comp")
uf = _create_user_file(db_session, user.id, status=UserFileStatus.COMPLETED)
_cleanup_user_files.append(uf)
mock_impl = MagicMock()
with patch(f"{_IMPL_MODULE}._process_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert (
str(uf.id) not in called_ids
), f"COMPLETED file {uf.id} should not have been recovered"
class TestRecoverDeletingFiles:
"""Files in DELETING status are recovered via the delete drain loop."""
def test_deleting_files_recovered(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "recovery_del")
uf = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
_cleanup_user_files.append(uf)
mock_impl = MagicMock()
with patch(f"{_IMPL_MODULE}._delete_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert (
str(uf.id) in called_ids
), f"Expected file {uf.id} to be recovered for deletion but got: {called_ids}"
class TestRecoverSyncFiles:
"""Files needing project/persona sync are recovered via the sync drain loop."""
def test_needs_project_sync_recovered(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "recovery_sync")
uf = _create_user_file(
db_session,
user.id,
status=UserFileStatus.COMPLETED,
needs_project_sync=True,
)
_cleanup_user_files.append(uf)
mock_impl = MagicMock()
with patch(f"{_IMPL_MODULE}._project_sync_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert (
str(uf.id) in called_ids
), f"Expected file {uf.id} to be recovered for sync but got: {called_ids}"
def test_needs_persona_sync_recovered(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "recovery_psync")
uf = _create_user_file(
db_session,
user.id,
status=UserFileStatus.COMPLETED,
needs_persona_sync=True,
)
_cleanup_user_files.append(uf)
mock_impl = MagicMock()
with patch(f"{_IMPL_MODULE}._project_sync_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
assert (
str(uf.id) in called_ids
), f"Expected file {uf.id} to be recovered for persona sync but got: {called_ids}"
class TestRecoveryMultipleFiles:
"""Recovery processes all stuck files in one pass, not just the first."""
def test_multiple_processing_files(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
_cleanup_user_files: list[UserFile],
) -> None:
user = create_test_user(db_session, "recovery_multi")
files = []
for _ in range(3):
uf = _create_user_file(
db_session, user.id, status=UserFileStatus.PROCESSING
)
_cleanup_user_files.append(uf)
files.append(uf)
mock_impl = MagicMock()
with patch(f"{_IMPL_MODULE}._process_user_file_impl", mock_impl):
recover_stuck_user_files(TEST_TENANT_ID)
called_ids = {call.kwargs["user_file_id"] for call in mock_impl.call_args_list}
expected_ids = {str(uf.id) for uf in files}
assert expected_ids.issubset(called_ids), (
f"Expected all {len(files)} files to be recovered. "
f"Missing: {expected_ids - called_ids}"
)

View File

@@ -1,64 +0,0 @@
"""Fixtures for cache backend tests.
Requires a running PostgreSQL instance (and Redis for parity tests).
Run with::
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/cache/
"""
from collections.abc import Generator
import pytest
from sqlalchemy import delete
from onyx.cache.interface import CacheBackend
from onyx.cache.postgres_backend import PostgresCacheBackend
from onyx.cache.redis_backend import RedisCacheBackend
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.models import CacheStore
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from tests.external_dependency_unit.constants import TEST_TENANT_ID
@pytest.fixture(scope="session", autouse=True)
def _init_db() -> Generator[None, None, None]:
SqlEngine.init_engine(pool_size=5, max_overflow=2)
CacheStore.__table__.create(get_sqlalchemy_engine(), checkfirst=True)
yield
with get_session_with_tenant(tenant_id=TEST_TENANT_ID) as session:
session.execute(delete(CacheStore))
session.commit()
@pytest.fixture(autouse=True)
def _tenant_context() -> Generator[None, None, None]:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
try:
yield
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@pytest.fixture
def pg_cache() -> PostgresCacheBackend:
return PostgresCacheBackend(TEST_TENANT_ID)
@pytest.fixture
def redis_cache() -> RedisCacheBackend:
from onyx.redis.redis_pool import redis_pool
return RedisCacheBackend(redis_pool.get_client(TEST_TENANT_ID))
@pytest.fixture(params=["postgres", "redis"], ids=["postgres", "redis"])
def cache(
request: pytest.FixtureRequest,
pg_cache: PostgresCacheBackend,
redis_cache: RedisCacheBackend,
) -> CacheBackend:
if request.param == "postgres":
return pg_cache
return redis_cache

View File

@@ -1,98 +0,0 @@
"""Parameterized tests that run the same CacheBackend operations against
both Redis and PostgreSQL, asserting identical return values.
Each test runs twice (once per backend) via the ``cache`` fixture defined
in conftest.py.
"""
import time
from uuid import uuid4
from onyx.cache.interface import CacheBackend
def _key() -> str:
return f"parity_{uuid4().hex[:12]}"
class TestKVParity:
def test_get_missing(self, cache: CacheBackend) -> None:
assert cache.get(_key()) is None
def test_get_set(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, b"value")
assert cache.get(k) == b"value"
def test_overwrite(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, b"a")
cache.set(k, b"b")
assert cache.get(k) == b"b"
def test_set_string(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, "hello")
assert cache.get(k) == b"hello"
def test_set_int(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, 42)
assert cache.get(k) == b"42"
def test_delete(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, b"x")
cache.delete(k)
assert cache.get(k) is None
def test_exists(self, cache: CacheBackend) -> None:
k = _key()
assert not cache.exists(k)
cache.set(k, b"x")
assert cache.exists(k)
class TestTTLParity:
def test_ttl_missing(self, cache: CacheBackend) -> None:
assert cache.ttl(_key()) == -2
def test_ttl_no_expiry(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, b"x")
assert cache.ttl(k) == -1
def test_ttl_remaining(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, b"x", ex=10)
remaining = cache.ttl(k)
assert 8 <= remaining <= 10
def test_set_with_ttl_expires(self, cache: CacheBackend) -> None:
k = _key()
cache.set(k, b"x", ex=1)
assert cache.get(k) == b"x"
time.sleep(1.5)
assert cache.get(k) is None
class TestLockParity:
def test_acquire_release(self, cache: CacheBackend) -> None:
lock = cache.lock(f"parity_lock_{uuid4().hex[:8]}")
assert lock.acquire(blocking=False)
assert lock.owned()
lock.release()
assert not lock.owned()
class TestListParity:
def test_rpush_blpop(self, cache: CacheBackend) -> None:
k = f"parity_list_{uuid4().hex[:8]}"
cache.rpush(k, b"item")
result = cache.blpop([k], timeout=1)
assert result is not None
assert result[1] == b"item"
def test_blpop_timeout(self, cache: CacheBackend) -> None:
result = cache.blpop([f"parity_empty_{uuid4().hex[:8]}"], timeout=1)
assert result is None

View File

@@ -1,161 +0,0 @@
"""Tests for PgRedisKVStore's cache layer integration with CacheBackend.
Verifies that the KV store correctly uses the CacheBackend for caching
in front of PostgreSQL: cache hits, cache misses falling through to PG,
cache population after PG reads, cache invalidation on delete, and
graceful degradation when the cache backend raises.
Requires running PostgreSQL.
"""
import json
from collections.abc import Generator
from unittest.mock import MagicMock
import pytest
from sqlalchemy import delete
from onyx.cache.interface import CacheBackend
from onyx.cache.postgres_backend import PostgresCacheBackend
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.models import CacheStore
from onyx.db.models import KVStore
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.key_value_store.store import PgRedisKVStore
from onyx.key_value_store.store import REDIS_KEY_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from tests.external_dependency_unit.constants import TEST_TENANT_ID
@pytest.fixture(scope="session", autouse=True)
def _init_db() -> Generator[None, None, None]:
SqlEngine.init_engine(pool_size=5, max_overflow=2)
engine = get_sqlalchemy_engine()
CacheStore.__table__.create(engine, checkfirst=True)
KVStore.__table__.create(engine, checkfirst=True)
yield
with get_session_with_tenant(tenant_id=TEST_TENANT_ID) as session:
session.execute(delete(CacheStore))
session.execute(delete(KVStore))
session.commit()
@pytest.fixture(autouse=True)
def _tenant_context() -> Generator[None, None, None]:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
try:
yield
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@pytest.fixture(autouse=True)
def _clean_kv(
_tenant_context: None,
) -> Generator[None, None, None]:
yield
with get_session_with_tenant(tenant_id=TEST_TENANT_ID) as session:
session.execute(delete(KVStore))
session.execute(delete(CacheStore))
session.commit()
@pytest.fixture
def pg_cache() -> PostgresCacheBackend:
return PostgresCacheBackend(TEST_TENANT_ID)
@pytest.fixture
def kv_store(pg_cache: PostgresCacheBackend) -> PgRedisKVStore:
return PgRedisKVStore(cache=pg_cache)
class TestStoreAndLoad:
def test_store_populates_cache_and_pg(
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
) -> None:
kv_store.store("k1", {"hello": "world"})
cached = pg_cache.get(REDIS_KEY_PREFIX + "k1")
assert cached is not None
assert json.loads(cached) == {"hello": "world"}
loaded = kv_store.load("k1")
assert loaded == {"hello": "world"}
def test_load_returns_cached_value_without_pg_hit(
self, pg_cache: PostgresCacheBackend
) -> None:
"""If the cache already has the value, PG should not be queried."""
pg_cache.set(REDIS_KEY_PREFIX + "cached_only", json.dumps({"from": "cache"}))
kv = PgRedisKVStore(cache=pg_cache)
assert kv.load("cached_only") == {"from": "cache"}
def test_load_falls_through_to_pg_on_cache_miss(
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
) -> None:
kv_store.store("k2", [1, 2, 3])
pg_cache.delete(REDIS_KEY_PREFIX + "k2")
assert pg_cache.get(REDIS_KEY_PREFIX + "k2") is None
loaded = kv_store.load("k2")
assert loaded == [1, 2, 3]
repopulated = pg_cache.get(REDIS_KEY_PREFIX + "k2")
assert repopulated is not None
assert json.loads(repopulated) == [1, 2, 3]
def test_load_with_refresh_cache_skips_cache(
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
) -> None:
kv_store.store("k3", "original")
pg_cache.set(REDIS_KEY_PREFIX + "k3", json.dumps("stale"))
loaded = kv_store.load("k3", refresh_cache=True)
assert loaded == "original"
class TestDelete:
def test_delete_removes_from_cache_and_pg(
self, kv_store: PgRedisKVStore, pg_cache: PostgresCacheBackend
) -> None:
kv_store.store("del_me", "bye")
kv_store.delete("del_me")
assert pg_cache.get(REDIS_KEY_PREFIX + "del_me") is None
with pytest.raises(KvKeyNotFoundError):
kv_store.load("del_me")
def test_delete_missing_key_raises(self, kv_store: PgRedisKVStore) -> None:
with pytest.raises(KvKeyNotFoundError):
kv_store.delete("nonexistent")
class TestCacheFailureGracefulDegradation:
def test_store_succeeds_when_cache_set_raises(self) -> None:
failing_cache = MagicMock(spec=CacheBackend)
failing_cache.set.side_effect = ConnectionError("cache down")
kv = PgRedisKVStore(cache=failing_cache)
kv.store("resilient", {"data": True})
working_cache = MagicMock(spec=CacheBackend)
working_cache.get.return_value = None
kv_reader = PgRedisKVStore(cache=working_cache)
loaded = kv_reader.load("resilient")
assert loaded == {"data": True}
def test_load_falls_through_when_cache_get_raises(self) -> None:
failing_cache = MagicMock(spec=CacheBackend)
failing_cache.get.side_effect = ConnectionError("cache down")
failing_cache.set.side_effect = ConnectionError("cache down")
kv = PgRedisKVStore(cache=failing_cache)
kv.store("survive", 42)
loaded = kv.load("survive")
assert loaded == 42

View File

@@ -1,218 +0,0 @@
"""Tests for PostgresCacheBackend against real PostgreSQL.
Covers every method on the backend: KV CRUD, TTL behaviour, advisory
locks (acquire / release / contention), list operations (rpush / blpop),
and the periodic cleanup function.
"""
import time
from uuid import uuid4
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
from onyx.cache.postgres_backend import PostgresCacheBackend
def _key() -> str:
return f"test_{uuid4().hex[:12]}"
# ------------------------------------------------------------------
# Basic KV
# ------------------------------------------------------------------
class TestKV:
def test_get_set(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"hello")
assert pg_cache.get(k) == b"hello"
def test_get_missing(self, pg_cache: PostgresCacheBackend) -> None:
assert pg_cache.get(_key()) is None
def test_set_overwrite(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"first")
pg_cache.set(k, b"second")
assert pg_cache.get(k) == b"second"
def test_set_string_value(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, "string_val")
assert pg_cache.get(k) == b"string_val"
def test_set_int_value(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, 42)
assert pg_cache.get(k) == b"42"
def test_delete(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"to_delete")
pg_cache.delete(k)
assert pg_cache.get(k) is None
def test_delete_missing_is_noop(self, pg_cache: PostgresCacheBackend) -> None:
pg_cache.delete(_key())
def test_exists(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
assert not pg_cache.exists(k)
pg_cache.set(k, b"x")
assert pg_cache.exists(k)
# ------------------------------------------------------------------
# TTL
# ------------------------------------------------------------------
class TestTTL:
def test_set_with_ttl_expires(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"ephemeral", ex=1)
assert pg_cache.get(k) == b"ephemeral"
time.sleep(1.5)
assert pg_cache.get(k) is None
def test_ttl_no_expiry(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"forever")
assert pg_cache.ttl(k) == -1
def test_ttl_missing_key(self, pg_cache: PostgresCacheBackend) -> None:
assert pg_cache.ttl(_key()) == -2
def test_ttl_remaining(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"x", ex=10)
remaining = pg_cache.ttl(k)
assert 8 <= remaining <= 10
def test_ttl_expired_key(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"x", ex=1)
time.sleep(1.5)
assert pg_cache.ttl(k) == -2
def test_expire_adds_ttl(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"x")
assert pg_cache.ttl(k) == -1
pg_cache.expire(k, 10)
assert 8 <= pg_cache.ttl(k) <= 10
def test_exists_respects_ttl(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"x", ex=1)
assert pg_cache.exists(k)
time.sleep(1.5)
assert not pg_cache.exists(k)
# ------------------------------------------------------------------
# Locks
# ------------------------------------------------------------------
class TestLock:
def test_acquire_release(self, pg_cache: PostgresCacheBackend) -> None:
lock = pg_cache.lock(f"lock_{uuid4().hex[:8]}")
assert lock.acquire(blocking=False)
assert lock.owned()
lock.release()
assert not lock.owned()
def test_contention(self, pg_cache: PostgresCacheBackend) -> None:
name = f"contention_{uuid4().hex[:8]}"
lock1 = pg_cache.lock(name)
lock2 = pg_cache.lock(name)
assert lock1.acquire(blocking=False)
assert not lock2.acquire(blocking=False)
lock1.release()
assert lock2.acquire(blocking=False)
lock2.release()
def test_context_manager(self, pg_cache: PostgresCacheBackend) -> None:
with pg_cache.lock(f"ctx_{uuid4().hex[:8]}") as lock:
assert lock.owned()
assert not lock.owned()
def test_blocking_timeout(self, pg_cache: PostgresCacheBackend) -> None:
name = f"timeout_{uuid4().hex[:8]}"
holder = pg_cache.lock(name)
holder.acquire(blocking=False)
waiter = pg_cache.lock(name, timeout=0.3)
start = time.monotonic()
assert not waiter.acquire(blocking=True, blocking_timeout=0.3)
elapsed = time.monotonic() - start
assert elapsed >= 0.25
holder.release()
# ------------------------------------------------------------------
# List (rpush / blpop)
# ------------------------------------------------------------------
class TestList:
def test_rpush_blpop(self, pg_cache: PostgresCacheBackend) -> None:
k = f"list_{uuid4().hex[:8]}"
pg_cache.rpush(k, b"item1")
result = pg_cache.blpop([k], timeout=1)
assert result is not None
assert result == (k.encode(), b"item1")
def test_blpop_timeout(self, pg_cache: PostgresCacheBackend) -> None:
result = pg_cache.blpop([f"empty_{uuid4().hex[:8]}"], timeout=1)
assert result is None
def test_fifo_order(self, pg_cache: PostgresCacheBackend) -> None:
k = f"fifo_{uuid4().hex[:8]}"
pg_cache.rpush(k, b"first")
time.sleep(0.01)
pg_cache.rpush(k, b"second")
r1 = pg_cache.blpop([k], timeout=1)
r2 = pg_cache.blpop([k], timeout=1)
assert r1 is not None and r1[1] == b"first"
assert r2 is not None and r2[1] == b"second"
def test_multiple_keys(self, pg_cache: PostgresCacheBackend) -> None:
k1 = f"mk1_{uuid4().hex[:8]}"
k2 = f"mk2_{uuid4().hex[:8]}"
pg_cache.rpush(k2, b"from_k2")
result = pg_cache.blpop([k1, k2], timeout=1)
assert result is not None
assert result == (k2.encode(), b"from_k2")
# ------------------------------------------------------------------
# Cleanup
# ------------------------------------------------------------------
class TestCleanup:
def test_removes_expired_rows(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"stale", ex=1)
time.sleep(1.5)
cleanup_expired_cache_entries()
assert pg_cache.ttl(k) == -2
def test_preserves_unexpired_rows(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"fresh", ex=300)
cleanup_expired_cache_entries()
assert pg_cache.get(k) == b"fresh"
def test_preserves_no_ttl_rows(self, pg_cache: PostgresCacheBackend) -> None:
k = _key()
pg_cache.set(k, b"permanent")
cleanup_expired_cache_entries()
assert pg_cache.get(k) == b"permanent"

View File

@@ -29,6 +29,7 @@ from onyx.server.manage.llm.api import (
test_llm_configuration as run_test_llm_configuration,
)
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
@@ -44,15 +45,14 @@ def _create_test_provider(
db_session: Session,
name: str,
api_key: str = "sk-test-key-00000000000000000000000000000000000",
) -> None:
) -> LLMProviderView:
"""Helper to create a test LLM provider in the database."""
upsert_llm_provider(
return upsert_llm_provider(
LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=api_key,
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
],
@@ -102,17 +102,11 @@ class TestLLMConfigurationEndpoint:
# This should complete without exception
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None, # New provider (not in DB)
provider=LlmProviderNames.OPENAI,
api_key="sk-new-test-key-0000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -152,17 +146,11 @@ class TestLLMConfigurationEndpoint:
with pytest.raises(HTTPException) as exc_info:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-invalid-key-00000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -194,7 +182,9 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
_create_test_provider(db_session, provider_name, api_key=original_api_key)
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -202,17 +192,12 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=False - should use stored key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name, # Existing provider
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=None, # Not providing a new key
api_key_changed=False, # Using existing key
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -246,7 +231,9 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
_create_test_provider(db_session, provider_name, api_key=original_api_key)
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -254,17 +241,12 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=True - should use new key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name, # Existing provider
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=new_api_key, # Providing a new key
api_key_changed=True, # Key is being changed
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -297,7 +279,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database with custom_config
upsert_llm_provider(
provider = upsert_llm_provider(
LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
@@ -305,7 +287,6 @@ class TestLLMConfigurationEndpoint:
api_key_changed=True,
custom_config=original_custom_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -321,18 +302,13 @@ class TestLLMConfigurationEndpoint:
# Test with custom_config_changed=False - should use stored config
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name,
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
custom_config=None, # Not providing new config
custom_config_changed=False, # Using existing config
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -368,17 +344,11 @@ class TestLLMConfigurationEndpoint:
for model_name in test_models:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name, is_visible=True
)
],
model=model_name,
),
_=_create_mock_admin(),
db_session=db_session,
@@ -442,7 +412,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_initial_model,
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -452,7 +421,7 @@ class TestDefaultProviderEndpoint:
)
# Set provider 1 as the default provider explicitly
update_default_provider(provider_1.id, db_session)
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
# Step 2: Call run_test_default_provider - should use provider 1's default model
with patch(
@@ -472,7 +441,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_2_api_key,
api_key_changed=True,
default_model_name=provider_2_default_model,
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -499,11 +467,11 @@ class TestDefaultProviderEndpoint:
# Step 5: Update provider 1's default model
upsert_llm_provider(
LLMProviderUpsertRequest(
id=provider_1.id,
name=provider_1_name,
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_updated_model, # Changed
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -512,6 +480,9 @@ class TestDefaultProviderEndpoint:
db_session=db_session,
)
# Set provider 1's default model to the updated model
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
# Step 6: Call run_test_default_provider - should use new model on provider 1
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -524,7 +495,7 @@ class TestDefaultProviderEndpoint:
captured_llms.clear()
# Step 7: Change the default provider to provider 2
update_default_provider(provider_2.id, db_session)
update_default_provider(provider_2.id, provider_2_default_model, db_session)
# Step 8: Call run_test_default_provider - should use provider 2
with patch(
@@ -596,7 +567,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -605,7 +575,7 @@ class TestDefaultProviderEndpoint:
),
db_session=db_session,
)
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
# Test should fail
with patch(

View File

@@ -49,7 +49,6 @@ def _create_test_provider(
api_key_changed=True,
api_base=api_base,
custom_config=custom_config,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
],
@@ -91,14 +90,14 @@ class TestLLMProviderChanges:
the API key should be blocked.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="https://attacker.example.com",
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -125,16 +124,16 @@ class TestLLMProviderChanges:
Changing api_base IS allowed when the API key is also being changed.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
api_base="https://custom-endpoint.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -159,14 +158,16 @@ class TestLLMProviderChanges:
original_api_base = "https://original.example.com/v1"
try:
_create_test_provider(db_session, provider_name, api_base=original_api_base)
provider = _create_test_provider(
db_session, provider_name, api_base=original_api_base
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base=original_api_base,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -190,14 +191,14 @@ class TestLLMProviderChanges:
changes. This allows model-only updates when provider has no custom base URL.
"""
try:
_create_test_provider(db_session, provider_name, api_base=None)
view = _create_test_provider(db_session, provider_name, api_base=None)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=view.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -223,14 +224,16 @@ class TestLLMProviderChanges:
original_api_base = "https://original.example.com/v1"
try:
_create_test_provider(db_session, provider_name, api_base=original_api_base)
provider = _create_test_provider(
db_session, provider_name, api_base=original_api_base
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base=None,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -259,14 +262,14 @@ class TestLLMProviderChanges:
users have full control over their deployment.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="https://custom.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -297,7 +300,6 @@ class TestLLMProviderChanges:
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
api_base="https://custom.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -322,7 +324,7 @@ class TestLLMProviderChanges:
redirect LLM API requests).
"""
try:
_create_test_provider(
provider = _create_test_provider(
db_session,
provider_name,
custom_config={"SOME_CONFIG": "original_value"},
@@ -330,11 +332,11 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -362,15 +364,15 @@ class TestLLMProviderChanges:
without changing the API key.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -399,7 +401,7 @@ class TestLLMProviderChanges:
new_config = {"AWS_REGION_NAME": "us-west-2"}
try:
_create_test_provider(
provider = _create_test_provider(
db_session,
provider_name,
custom_config={"AWS_REGION_NAME": "us-east-1"},
@@ -407,13 +409,13 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=True,
custom_config=new_config,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -438,17 +440,17 @@ class TestLLMProviderChanges:
original_config = {"AWS_REGION_NAME": "us-east-1"}
try:
_create_test_provider(
provider = _create_test_provider(
db_session, provider_name, custom_config=original_config
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config=original_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -474,7 +476,7 @@ class TestLLMProviderChanges:
new_config = {"AWS_REGION_NAME": "eu-west-1"}
try:
_create_test_provider(
provider = _create_test_provider(
db_session,
provider_name,
custom_config={"AWS_REGION_NAME": "us-east-1"},
@@ -482,10 +484,10 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config=new_config,
default_model_name="gpt-4o-mini",
custom_config_changed=True,
)
@@ -530,14 +532,8 @@ def test_upload_with_custom_config_then_change(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
model=default_model_name,
api_key_changed=False,
custom_config_changed=True,
custom_config=custom_config,
@@ -546,11 +542,10 @@ def test_upload_with_custom_config_then_change(
db_session=db_session,
)
put_llm_provider(
provider = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
custom_config=custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -569,14 +564,9 @@ def test_upload_with_custom_config_then_change(
# Turn auto mode off
run_llm_config_test(
LLMTestRequest(
name=name,
id=provider.id,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
model=default_model_name,
api_key_changed=False,
custom_config_changed=False,
),
@@ -586,9 +576,9 @@ def test_upload_with_custom_config_then_change(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=provider.id,
name=name,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
@@ -616,13 +606,13 @@ def test_upload_with_custom_config_then_change(
)
# Check inside the database and check that custom_config is the same as the original
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not provider:
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not db_provider:
assert False, "Provider not found in the database"
assert provider.custom_config == custom_config, (
assert db_provider.custom_config == custom_config, (
f"Expected custom_config {custom_config}, "
f"but got {provider.custom_config}"
f"but got {db_provider.custom_config}"
)
finally:
db_session.rollback()
@@ -642,11 +632,10 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
}
try:
put_llm_provider(
view = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config=original_custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -665,9 +654,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=view.id,
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config={
"vertex_credentials": _mask_string(
original_custom_config["vertex_credentials"]
@@ -706,7 +695,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
) -> None:
"""LLM test should restore masked sensitive custom config values before invocation."""
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
provider = LlmProviderNames.VERTEX_AI.value
provider_name = LlmProviderNames.VERTEX_AI.value
default_model_name = "gemini-2.5-pro"
original_custom_config = {
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
@@ -719,11 +708,10 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
return ""
try:
put_llm_provider(
provider = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
provider=provider_name,
custom_config=original_custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -742,14 +730,9 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
id=provider.id,
provider=provider_name,
model=default_model_name,
api_key_changed=False,
custom_config_changed=True,
custom_config={

View File

@@ -15,9 +15,11 @@ import pytest
from sqlalchemy.orm import Session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_auto_mode_providers
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import sync_auto_mode_models
from onyx.db.llm import update_default_provider
@@ -135,7 +137,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name=expected_default_model,
model_configurations=[], # No model configs provided
),
is_creation=True,
@@ -163,13 +164,8 @@ class TestAutoModeSyncFeature:
if mc.name in all_expected_models:
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
# Verify the default model was set correctly
assert (
provider.default_model_name == expected_default_model
), f"Default model should be '{expected_default_model}'"
# Step 4: Set the provider as default
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, expected_default_model, db_session)
# Step 5: Fetch the default provider and verify
default_model = fetch_default_llm_model(db_session)
@@ -238,7 +234,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -317,7 +312,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=False, # Not in auto mode initially
default_model_name="gpt-4",
model_configurations=initial_models,
),
is_creation=True,
@@ -326,13 +320,13 @@ class TestAutoModeSyncFeature:
)
# Verify initial state: all models are visible
provider = fetch_existing_llm_provider(
initial_provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is False
assert initial_provider is not None
assert initial_provider.is_auto_mode is False
for mc in provider.model_configurations:
for mc in initial_provider.model_configurations:
assert (
mc.is_visible is True
), f"Initial model '{mc.name}' should be visible"
@@ -344,12 +338,12 @@ class TestAutoModeSyncFeature:
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=initial_provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None, # Not changing API key
api_key_changed=False,
is_auto_mode=True, # Now enabling auto mode
default_model_name=auto_mode_default,
model_configurations=[], # Auto mode will sync from config
),
is_creation=False, # This is an update
@@ -360,15 +354,15 @@ class TestAutoModeSyncFeature:
# Step 3: Verify model visibility after auto mode transition
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
provider_view = fetch_llm_provider_view(
provider_name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is True
assert provider_view is not None
assert provider_view.is_auto_mode is True
# Build a map of model name -> visibility
model_visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
mc.name: mc.is_visible for mc in provider_view.model_configurations
}
# Models in auto mode config should be visible
@@ -388,9 +382,6 @@ class TestAutoModeSyncFeature:
model_visibility[model_name] is False
), f"Model '{model_name}' not in auto config should NOT be visible"
# Verify the default model was updated
assert provider.default_model_name == auto_mode_default
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
@@ -432,8 +423,12 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o",
is_visible=True,
)
],
),
is_creation=True,
_=_create_mock_admin(),
@@ -535,7 +530,6 @@ class TestAutoModeSyncFeature:
api_key=provider_1_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_1_default_model,
model_configurations=[],
),
is_creation=True,
@@ -549,7 +543,7 @@ class TestAutoModeSyncFeature:
name=provider_1_name, db_session=db_session
)
assert provider_1 is not None
update_default_provider(provider_1.id, db_session)
update_default_provider(provider_1.id, provider_1_default_model, db_session)
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
@@ -563,7 +557,6 @@ class TestAutoModeSyncFeature:
api_key=provider_2_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_2_default_model,
model_configurations=[],
),
is_creation=True,
@@ -584,7 +577,7 @@ class TestAutoModeSyncFeature:
name=provider_2_name, db_session=db_session
)
assert provider_2 is not None
update_default_provider(provider_2.id, db_session)
update_default_provider(provider_2.id, provider_2_default_model, db_session)
# Step 5: Verify provider 2 is now the default
db_session.expire_all()
@@ -644,7 +637,6 @@ class TestAutoModeMissingFlows:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -701,3 +693,364 @@ class TestAutoModeMissingFlows:
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
class TestAutoModeTransitionsAndResync:
"""Tests for auto/manual transitions, config evolution, and sync idempotency."""
def test_auto_to_manual_mode_preserves_models_and_stops_syncing(
self,
db_session: Session,
provider_name: str,
) -> None:
"""Disabling auto mode should preserve the current model list and
prevent future syncs from altering visibility.
Steps:
1. Create provider in auto mode — models synced from config.
2. Update provider to manual mode (is_auto_mode=False).
3. Verify all models remain with unchanged visibility.
4. Call sync_auto_mode_models with a *different* config.
5. Verify fetch_auto_mode_providers excludes this provider, so the
periodic task would never call sync on it.
"""
initial_config = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
try:
# Step 1: Create in auto mode
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=initial_config,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility_before = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility_before == {"gpt-4o": True, "gpt-4o-mini": True}
# Step 2: Switch to manual mode
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
is_auto_mode=False,
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
),
],
),
is_creation=False,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 3: Models unchanged
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is False
visibility_after = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility_after == visibility_before
# Step 4-5: Provider excluded from auto mode queries
auto_providers = fetch_auto_mode_providers(db_session)
auto_provider_ids = {p.id for p in auto_providers}
assert provider.id not in auto_provider_ids
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_resync_adds_new_and_hides_removed_models(
self,
db_session: Session,
provider_name: str,
) -> None:
"""When the GitHub config changes between syncs, a subsequent sync
should add newly listed models and hide models that were removed.
Steps:
1. Create provider in auto mode with config v1: [gpt-4o, gpt-4o-mini].
2. Sync with config v2: [gpt-4o, gpt-4-turbo] (gpt-4o-mini removed,
gpt-4-turbo added).
3. Verify gpt-4o still visible, gpt-4o-mini hidden, gpt-4-turbo added
and visible.
"""
config_v1 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
config_v2 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4-turbo"],
)
try:
# Step 1: Create with config v1
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config_v1,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 2: Re-sync with config v2
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config_v2,
)
assert changes > 0
# Step 3: Verify
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
# gpt-4o: still in config -> visible
assert visibility["gpt-4o"] is True
# gpt-4o-mini: removed from config -> hidden (not deleted)
assert "gpt-4o-mini" in visibility, "Removed model should still exist in DB"
assert visibility["gpt-4o-mini"] is False
# gpt-4-turbo: newly added -> visible
assert visibility["gpt-4-turbo"] is True
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_sync_is_idempotent(
self,
db_session: Session,
provider_name: str,
) -> None:
"""Running sync twice with the same config should produce zero
changes on the second call."""
config = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini", "gpt-4-turbo"],
)
try:
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
# First explicit sync (may report changes if creation already synced)
sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config,
)
# Snapshot state after first sync
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
snapshot = {
mc.name: (mc.is_visible, mc.display_name)
for mc in provider.model_configurations
}
# Second sync — should be a no-op
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config,
)
assert (
changes == 0
), f"Expected 0 changes on idempotent re-sync, got {changes}"
# State should be identical
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
current = {
mc.name: (mc.is_visible, mc.display_name)
for mc in provider.model_configurations
}
assert current == snapshot
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_default_model_hidden_when_removed_from_config(
self,
db_session: Session,
provider_name: str,
) -> None:
"""When the current default model is removed from the config, sync
should hide it. The default model flow row should still exist (it
points at the ModelConfiguration), but the model is no longer visible.
Steps:
1. Create provider with config: default=gpt-4o, additional=[gpt-4o-mini].
2. Set gpt-4o as the global default.
3. Re-sync with config: default=gpt-4o-mini (gpt-4o removed entirely).
4. Verify gpt-4o is hidden, gpt-4o-mini is visible, and
fetch_default_llm_model still returns a result (the flow row persists).
"""
config_v1 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
config_v2 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o-mini",
additional_models=[],
)
try:
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config_v1,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 2: Set gpt-4o as global default
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
update_default_provider(provider.id, "gpt-4o", db_session)
default_before = fetch_default_llm_model(db_session)
assert default_before is not None
assert default_before.name == "gpt-4o"
# Step 3: Re-sync with config v2 (gpt-4o removed)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config_v2,
)
assert changes > 0
# Step 4: Verify visibility
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility["gpt-4o"] is False, "Removed default should be hidden"
assert visibility["gpt-4o-mini"] is True, "New default should be visible"
# The LLMModelFlow row for gpt-4o still exists (is_default=True),
# but the model is hidden. fetch_default_llm_model filters on
# is_visible=True, so it should NOT return gpt-4o.
db_session.expire_all()
default_after = fetch_default_llm_model(db_session)
assert (
default_after is None or default_after.name != "gpt-4o"
), "Hidden model should not be returned as the default"
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)

View File

@@ -64,7 +64,6 @@ def _create_provider(
name=name,
provider=provider,
api_key="sk-ant-api03-...",
default_model_name="claude-3-5-sonnet-20240620",
is_public=is_public,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -154,7 +153,9 @@ def test_user_sends_message_to_private_provider(
)
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
update_default_provider(public_provider_id, db_session)
update_default_provider(
public_provider_id, "claude-3-5-sonnet-20240620", db_session
)
try:
# Create chat session

View File

@@ -42,7 +42,6 @@ def _create_llm_provider_and_model(
name=provider_name,
provider="openai",
api_key="test-api-key",
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name,

View File

@@ -434,7 +434,6 @@ class TestSlackBotFederatedSearch:
name=f"test-llm-provider-{uuid4().hex[:8]}",
provider=LlmProviderNames.OPENAI,
api_key=api_key,
default_model_name="gpt-4o",
is_public=True,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -448,7 +447,7 @@ class TestSlackBotFederatedSearch:
db_session=db_session,
)
update_default_provider(provider_view.id, db_session)
update_default_provider(provider_view.id, "gpt-4o", db_session)
def _teardown_common_mocks(self, patches: list) -> None:
"""Stop all patches"""

View File

@@ -4,10 +4,12 @@ from uuid import uuid4
import requests
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
@@ -32,7 +34,6 @@ class LLMProviderManager:
llm_provider = LLMProviderUpsertRequest(
name=name or f"test-provider-{uuid4()}",
provider=provider or LlmProviderNames.OPENAI,
default_model_name=default_model_name or "gpt-4o-mini",
api_key=api_key or os.environ["OPENAI_API_KEY"],
api_base=api_base,
api_version=api_version,
@@ -65,7 +66,7 @@ class LLMProviderManager:
name=response_data["name"],
provider=response_data["provider"],
api_key=response_data["api_key"],
default_model_name=response_data["default_model_name"],
default_model_name=default_model_name or "gpt-4o-mini",
is_public=response_data["is_public"],
is_auto_mode=response_data.get("is_auto_mode", False),
groups=response_data["groups"],
@@ -75,9 +76,19 @@ class LLMProviderManager:
)
if set_as_default:
if default_model_name is None:
default_model_name = "gpt-4o-mini"
set_default_response = requests.post(
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
headers=user_performing_action.headers,
f"{API_SERVER_URL}/admin/llm/default",
json={
"provider_id": response_data["id"],
"model_name": default_model_name,
},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
set_default_response.raise_for_status()
@@ -104,7 +115,7 @@ class LLMProviderManager:
headers=user_performing_action.headers,
)
response.raise_for_status()
return [LLMProviderView(**ug) for ug in response.json()]
return [LLMProviderView(**p) for p in response.json()["providers"]]
@staticmethod
def verify(
@@ -113,7 +124,11 @@ class LLMProviderManager:
verify_deleted: bool = False,
) -> None:
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
default_model = LLMProviderManager.get_default_model(user_performing_action)
for fetched_llm_provider in all_llm_providers:
model_names = [
model.name for model in fetched_llm_provider.model_configurations
]
if llm_provider.id == fetched_llm_provider.id:
if verify_deleted:
raise ValueError(
@@ -126,11 +141,30 @@ class LLMProviderManager:
if (
fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider
and llm_provider.default_model_name
== fetched_llm_provider.default_model_name
and (
default_model is None or default_model.model_name in model_names
)
and llm_provider.is_public == fetched_llm_provider.is_public
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
):
return
if not verify_deleted:
raise ValueError(f"LLM Provider {llm_provider.id} not found")
@staticmethod
def get_default_model(
user_performing_action: DATestUser | None = None,
) -> DefaultModel | None:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
default_text = response.json().get("default_text")
if default_text is None:
return None
return DefaultModel(**default_text)

View File

@@ -128,7 +128,7 @@ class DATestLLMProvider(BaseModel):
name: str
provider: str
api_key: str
default_model_name: str
default_model_name: str | None = None
is_public: bool
is_auto_mode: bool = False
groups: list[int]

View File

@@ -42,12 +42,10 @@ def _create_provider_with_api(
llm_provider_data = {
"name": name,
"provider": provider_type,
"default_model_name": default_model,
"api_key": "test-api-key-for-auto-mode-testing",
"api_base": None,
"api_version": None,
"custom_config": None,
"fast_default_model_name": default_model,
"is_public": True,
"is_auto_mode": is_auto_mode,
"groups": [],
@@ -72,7 +70,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
headers=admin_user.headers,
)
response.raise_for_status()
for provider in response.json():
for provider in response.json()["providers"]:
if provider["id"] == provider_id:
return provider
raise ValueError(f"Provider with id {provider_id} not found")
@@ -219,15 +217,6 @@ def test_auto_mode_provider_gets_synced_from_github_config(
"is_visible"
], "Outdated model should not be visible after sync"
# Verify default model was set from GitHub config
expected_default = (
default_model["name"] if isinstance(default_model, dict) else default_model
)
assert synced_provider["default_model_name"] == expected_default, (
f"Default model should be {expected_default}, "
f"got {synced_provider['default_model_name']}"
)
def test_manual_mode_provider_not_affected_by_auto_sync(
reset: None, # noqa: ARG001
@@ -273,7 +262,3 @@ def test_manual_mode_provider_not_affected_by_auto_sync(
f"Manual mode provider models should not change. "
f"Initial: {initial_models}, Current: {current_models}"
)
assert (
updated_provider["default_model_name"] == custom_model
), f"Manual mode default model should remain {custom_model}"

View File

@@ -6,20 +6,21 @@ from sqlalchemy.orm import Session
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_user_group_ids
from onyx.db.models import LLMModelFlow
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import LLMProvider__Persona
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import ModelConfiguration
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_llm_for_persona
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.persona import PersonaManager
@@ -41,24 +42,30 @@ def _create_llm_provider(
is_public: bool,
is_default: bool,
) -> LLMProviderModel:
provider = LLMProviderModel(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=default_model_name,
deployment_name=None,
is_public=is_public,
# Use None instead of False to avoid unique constraint violation
# The is_default_provider column has unique=True, so only one True and one False allowed
is_default_provider=is_default if is_default else None,
is_default_vision_provider=False,
default_vision_model=None,
_provider = upsert_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_base=None,
api_version=None,
custom_config=None,
is_public=is_public,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name,
is_visible=True,
)
],
),
db_session=db_session,
)
db_session.add(provider)
db_session.flush()
if is_default:
update_default_provider(_provider.id, default_model_name, db_session)
provider = db_session.get(LLMProviderModel, _provider.id)
if not provider:
raise ValueError(f"Provider {name} not found")
return provider
@@ -270,24 +277,6 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
provider_name=restricted_provider.name,
)
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
# resolve the default provider when the fallback path is triggered.
default_model_config = ModelConfiguration(
llm_provider_id=default_provider.id,
name=default_provider.default_model_name,
is_visible=True,
)
db_session.add(default_model_config)
db_session.flush()
db_session.add(
LLMModelFlow(
model_configuration_id=default_model_config.id,
llm_model_flow_type=LLMModelFlowType.CHAT,
is_default=True,
)
)
db_session.flush()
access_group = UserGroup(name="persona-group")
db_session.add(access_group)
db_session.flush()
@@ -321,13 +310,19 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
persona=persona,
user=admin_model,
)
assert allowed_llm.config.model_name == restricted_provider.default_model_name
assert (
allowed_llm.config.model_name
== restricted_provider.model_configurations[0].name
)
fallback_llm = get_llm_for_persona(
persona=persona,
user=basic_model,
)
assert fallback_llm.config.model_name == default_provider.default_model_name
assert (
fallback_llm.config.model_name
== default_provider.model_configurations[0].name
)
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
@@ -346,6 +341,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
name="public-provider",
is_public=True,
set_as_default=True,
default_model_name="gpt-4o",
user_performing_action=admin_user,
)
@@ -365,7 +361,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=basic_user.headers,
)
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
provider_names = [p["name"] for p in providers]
# Public provider should be visible
@@ -380,7 +376,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=admin_user.headers,
)
assert admin_response.status_code == 200
admin_providers = admin_response.json()
admin_providers = admin_response.json()["providers"]
admin_provider_names = [p["name"] for p in admin_providers]
assert public_provider.name in admin_provider_names
@@ -396,6 +392,7 @@ def test_provider_delete_clears_persona_references(reset: None) -> None: # noqa
name="default-provider",
is_public=True,
set_as_default=True,
default_model_name="gpt-4o",
user_performing_action=admin_user,
)

View File

@@ -107,7 +107,7 @@ def test_authorized_persona_access_returns_filtered_providers(
# Should succeed
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should include the restricted provider since basic_user can access the persona
provider_names = [p["name"] for p in providers]
@@ -140,7 +140,7 @@ def test_persona_id_zero_applies_rbac(
# Should succeed (persona_id=0 refers to default persona, which is public)
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should NOT include the restricted provider since basic_user is not in group2
provider_names = [p["name"] for p in providers]
@@ -182,7 +182,7 @@ def test_admin_can_query_any_persona(
# Should succeed - admins can access any persona
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should include the restricted provider
provider_names = [p["name"] for p in providers]
@@ -223,7 +223,7 @@ def test_public_persona_accessible_to_all(
# Should succeed
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should return the public provider
assert len(providers) > 0

View File

@@ -72,6 +72,9 @@ def test_cold_startup_default_assistant() -> None:
assert (
"read_file" in tool_names
), "Default assistant should have FileReaderTool attached"
assert (
"python" in tool_names
), "Default assistant should have PythonTool attached"
# Also verify by display names for clarity
assert (
@@ -86,8 +89,11 @@ def test_cold_startup_default_assistant() -> None:
assert (
"File Reader" in tool_display_names
), "Default assistant should have File Reader tool"
# Should have exactly 5 tools
assert (
len(tool_associations) == 5
), f"Default assistant should have exactly 5 tools attached, got {len(tool_associations)}"
"Code Interpreter" in tool_display_names
), "Default assistant should have Code Interpreter tool"
# Should have exactly 6 tools
assert (
len(tool_associations) == 6
), f"Default assistant should have exactly 6 tools attached, got {len(tool_associations)}"

View File

@@ -1,160 +0,0 @@
"""Integration test for the full user-file lifecycle in no-vector-DB mode.
Covers: upload → COMPLETED → unlink from project → delete → gone.
The entire lifecycle is handled by FastAPI BackgroundTasks (no Celery workers
needed). The conftest-level ``pytestmark`` ensures these tests are skipped
when the server is running with vector DB enabled.
"""
import time
from uuid import UUID
import requests
from onyx.db.enums import UserFileStatus
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.project import ProjectManager
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
POLL_INTERVAL_SECONDS = 1
POLL_TIMEOUT_SECONDS = 30
def _poll_file_status(
file_id: UUID,
user: DATestUser,
target_status: UserFileStatus,
timeout: int = POLL_TIMEOUT_SECONDS,
) -> None:
"""Poll GET /user/projects/file/{file_id} until the file reaches *target_status*."""
deadline = time.time() + timeout
while time.time() < deadline:
resp = requests.get(
f"{API_SERVER_URL}/user/projects/file/{file_id}",
headers=user.headers,
)
if resp.ok:
status = resp.json().get("status")
if status == target_status.value:
return
time.sleep(POLL_INTERVAL_SECONDS)
raise TimeoutError(
f"File {file_id} did not reach {target_status.value} within {timeout}s"
)
def _file_is_gone(file_id: UUID, user: DATestUser, timeout: int = 15) -> None:
"""Poll until GET /user/projects/file/{file_id} returns 404."""
deadline = time.time() + timeout
while time.time() < deadline:
resp = requests.get(
f"{API_SERVER_URL}/user/projects/file/{file_id}",
headers=user.headers,
)
if resp.status_code == 404:
return
time.sleep(POLL_INTERVAL_SECONDS)
raise TimeoutError(
f"File {file_id} still accessible after {timeout}s (expected 404)"
)
def test_file_upload_process_delete_lifecycle(
reset: None, # noqa: ARG001
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
) -> None:
"""Full lifecycle: upload → COMPLETED → unlink → delete → 404.
Validates that the API server handles all background processing
(via FastAPI BackgroundTasks) without any Celery workers running.
"""
project = ProjectManager.create(
name="lifecycle-test", user_performing_action=admin_user
)
file_content = b"Integration test file content for lifecycle verification."
upload_result = ProjectManager.upload_files(
project_id=project.id,
files=[("lifecycle.txt", file_content)],
user_performing_action=admin_user,
)
assert upload_result.user_files, "Expected at least one file in upload response"
user_file = upload_result.user_files[0]
file_id = user_file.id
_poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED)
project_files = ProjectManager.get_project_files(project.id, admin_user)
assert any(
f.id == file_id for f in project_files
), "File should be listed in project files after processing"
# Unlink the file from the project so the delete endpoint will proceed
unlink_resp = requests.delete(
f"{API_SERVER_URL}/user/projects/{project.id}/files/{file_id}",
headers=admin_user.headers,
)
assert (
unlink_resp.status_code == 204
), f"Expected 204 on unlink, got {unlink_resp.status_code}: {unlink_resp.text}"
delete_resp = requests.delete(
f"{API_SERVER_URL}/user/projects/file/{file_id}",
headers=admin_user.headers,
)
assert (
delete_resp.ok
), f"Delete request failed: {delete_resp.status_code} {delete_resp.text}"
body = delete_resp.json()
assert (
body["has_associations"] is False
), f"File still has associations after unlink: {body}"
_file_is_gone(file_id, admin_user)
project_files_after = ProjectManager.get_project_files(project.id, admin_user)
assert not any(
f.id == file_id for f in project_files_after
), "Deleted file should not appear in project files"
def test_delete_blocked_while_associated(
reset: None, # noqa: ARG001
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
) -> None:
"""Deleting a file that still belongs to a project should return
has_associations=True without actually deleting the file."""
project = ProjectManager.create(
name="assoc-test", user_performing_action=admin_user
)
upload_result = ProjectManager.upload_files(
project_id=project.id,
files=[("assoc.txt", b"associated file content")],
user_performing_action=admin_user,
)
file_id = upload_result.user_files[0].id
_poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED)
# Attempt to delete while still linked
delete_resp = requests.delete(
f"{API_SERVER_URL}/user/projects/file/{file_id}",
headers=admin_user.headers,
)
assert delete_resp.ok
body = delete_resp.json()
assert body["has_associations"] is True, "Should report existing associations"
assert project.name in body["project_names"]
# File should still be accessible
get_resp = requests.get(
f"{API_SERVER_URL}/user/projects/file/{file_id}",
headers=admin_user.headers,
)
assert get_resp.status_code == 200, "File should still exist after blocked delete"

View File

@@ -9,6 +9,19 @@ from redis.exceptions import RedisError
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
# Fields we assert on across all tests
_ASSERT_FIELDS = {
"application_status",
"ee_features_enabled",
"seat_count",
"used_seats",
}
def _pick(settings: Settings) -> dict:
"""Extract only the fields under test from a Settings object."""
return settings.model_dump(include=_ASSERT_FIELDS)
@pytest.fixture
def base_settings() -> Settings:
@@ -27,17 +40,17 @@ class TestApplyLicenseStatusToSettings:
def test_enforcement_disabled_enables_ee_features(
self, base_settings: Settings
) -> None:
"""When LICENSE_ENFORCEMENT_ENABLED=False, EE features are enabled.
If we're running the EE apply function, EE code was loaded via
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES, so features should be on.
"""
"""When LICENSE_ENFORCEMENT_ENABLED=False, EE features are enabled."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
assert base_settings.ee_features_enabled is False
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is True
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
}
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", True)
@@ -46,13 +59,60 @@ class TestApplyLicenseStatusToSettings:
from ee.onyx.server.settings.api import apply_license_status_to_settings
result = apply_license_status_to_settings(base_settings)
assert result.ee_features_enabled is True
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
}
@pytest.mark.parametrize(
"license_status,expected_app_status,expected_ee_enabled",
"license_status,used_seats,seats,expected",
[
(ApplicationStatus.GATED_ACCESS, ApplicationStatus.GATED_ACCESS, False),
(ApplicationStatus.ACTIVE, ApplicationStatus.ACTIVE, True),
(
ApplicationStatus.GATED_ACCESS,
3,
10,
{
"application_status": ApplicationStatus.GATED_ACCESS,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
},
),
(
ApplicationStatus.ACTIVE,
3,
10,
{
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
},
),
(
ApplicationStatus.ACTIVE,
10,
10,
{
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
},
),
(
ApplicationStatus.GRACE_PERIOD,
3,
10,
{
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
},
),
],
)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@@ -63,25 +123,80 @@ class TestApplyLicenseStatusToSettings:
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
license_status: ApplicationStatus | None,
expected_app_status: ApplicationStatus,
expected_ee_enabled: bool,
license_status: ApplicationStatus,
used_seats: int,
seats: int,
expected: dict,
base_settings: Settings,
) -> None:
"""Self-hosted: license status controls both application_status and ee_features_enabled."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
mock_get_tenant.return_value = "test_tenant"
if license_status is None:
mock_get_metadata.return_value = None
else:
mock_metadata = MagicMock()
mock_metadata.status = license_status
mock_get_metadata.return_value = mock_metadata
mock_metadata = MagicMock()
mock_metadata.status = license_status
mock_metadata.used_seats = used_seats
mock_metadata.seats = seats
mock_get_metadata.return_value = mock_metadata
result = apply_license_status_to_settings(base_settings)
assert result.application_status == expected_app_status
assert result.ee_features_enabled is expected_ee_enabled
assert _pick(result) == expected
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
def test_seat_limit_exceeded_sets_status_and_counts(
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
base_settings: Settings,
) -> None:
"""Seat limit exceeded sets SEAT_LIMIT_EXCEEDED with counts, keeps EE enabled."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
mock_get_tenant.return_value = "test_tenant"
mock_metadata = MagicMock()
mock_metadata.status = ApplicationStatus.ACTIVE
mock_metadata.used_seats = 15
mock_metadata.seats = 10
mock_get_metadata.return_value = mock_metadata
result = apply_license_status_to_settings(base_settings)
assert _pick(result) == {
"application_status": ApplicationStatus.SEAT_LIMIT_EXCEEDED,
"ee_features_enabled": True,
"seat_count": 10,
"used_seats": 15,
}
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
def test_expired_license_takes_precedence_over_seat_limit(
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
base_settings: Settings,
) -> None:
"""Expired license (GATED_ACCESS) takes precedence over seat limit exceeded."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
mock_get_tenant.return_value = "test_tenant"
mock_metadata = MagicMock()
mock_metadata.status = ApplicationStatus.GATED_ACCESS
mock_metadata.used_seats = 15
mock_metadata.seats = 10
mock_get_metadata.return_value = mock_metadata
result = apply_license_status_to_settings(base_settings)
assert _pick(result) == {
"application_status": ApplicationStatus.GATED_ACCESS,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", True)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@@ -105,8 +220,12 @@ class TestApplyLicenseStatusToSettings:
mock_get_metadata.return_value = None
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.GATED_ACCESS
assert result.ee_features_enabled is False
assert _pick(result) == {
"application_status": ApplicationStatus.GATED_ACCESS,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", False)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@@ -130,8 +249,12 @@ class TestApplyLicenseStatusToSettings:
mock_get_metadata.return_value = None
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is False
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@@ -150,8 +273,12 @@ class TestApplyLicenseStatusToSettings:
mock_get_metadata.side_effect = RedisError("Connection failed")
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is False
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
class TestSettingsDefaultEEDisabled:

View File

@@ -1,166 +0,0 @@
"""Unit tests for stop_signal_checker and chat_processing_checker.
These modules are safety-critical — they control whether a chat stream
continues or stops. The tests use a simple in-memory CacheBackend stub
so no external services are needed.
"""
from uuid import uuid4
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheLock
from onyx.chat.chat_processing_checker import is_chat_session_processing
from onyx.chat.chat_processing_checker import set_processing_status
from onyx.chat.stop_signal_checker import FENCE_TTL
from onyx.chat.stop_signal_checker import is_connected
from onyx.chat.stop_signal_checker import reset_cancel_status
from onyx.chat.stop_signal_checker import set_fence
class _MemoryCacheBackend(CacheBackend):
"""Minimal in-memory CacheBackend for unit tests."""
def __init__(self) -> None:
self._store: dict[str, bytes] = {}
def get(self, key: str) -> bytes | None:
return self._store.get(key)
def set(
self,
key: str,
value: str | bytes | int | float,
ex: int | None = None, # noqa: ARG002
) -> None:
if isinstance(value, bytes):
self._store[key] = value
else:
self._store[key] = str(value).encode()
def delete(self, key: str) -> None:
self._store.pop(key, None)
def exists(self, key: str) -> bool:
return key in self._store
def expire(self, key: str, seconds: int) -> None:
pass
def ttl(self, key: str) -> int:
return -2 if key not in self._store else -1
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
raise NotImplementedError
def rpush(self, key: str, value: str | bytes) -> None:
raise NotImplementedError
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
raise NotImplementedError
# ── stop_signal_checker ──────────────────────────────────────────────
class TestSetFence:
def test_set_fence_true_creates_key(self) -> None:
cache = _MemoryCacheBackend()
sid = uuid4()
set_fence(sid, cache, True)
assert not is_connected(sid, cache)
def test_set_fence_false_removes_key(self) -> None:
cache = _MemoryCacheBackend()
sid = uuid4()
set_fence(sid, cache, True)
set_fence(sid, cache, False)
assert is_connected(sid, cache)
def test_set_fence_false_noop_when_absent(self) -> None:
cache = _MemoryCacheBackend()
sid = uuid4()
set_fence(sid, cache, False)
assert is_connected(sid, cache)
def test_set_fence_uses_ttl(self) -> None:
"""Verify set_fence passes ex=FENCE_TTL to cache.set."""
calls: list[dict[str, object]] = []
cache = _MemoryCacheBackend()
original_set = cache.set
def tracking_set(
key: str,
value: str | bytes | int | float,
ex: int | None = None,
) -> None:
calls.append({"key": key, "ex": ex})
original_set(key, value, ex=ex)
cache.set = tracking_set # type: ignore[assignment]
set_fence(uuid4(), cache, True)
assert len(calls) == 1
assert calls[0]["ex"] == FENCE_TTL
class TestIsConnected:
def test_connected_when_no_fence(self) -> None:
cache = _MemoryCacheBackend()
assert is_connected(uuid4(), cache)
def test_disconnected_when_fence_set(self) -> None:
cache = _MemoryCacheBackend()
sid = uuid4()
set_fence(sid, cache, True)
assert not is_connected(sid, cache)
def test_sessions_are_isolated(self) -> None:
cache = _MemoryCacheBackend()
sid1, sid2 = uuid4(), uuid4()
set_fence(sid1, cache, True)
assert not is_connected(sid1, cache)
assert is_connected(sid2, cache)
class TestResetCancelStatus:
def test_clears_fence(self) -> None:
cache = _MemoryCacheBackend()
sid = uuid4()
set_fence(sid, cache, True)
reset_cancel_status(sid, cache)
assert is_connected(sid, cache)
def test_noop_when_no_fence(self) -> None:
cache = _MemoryCacheBackend()
reset_cancel_status(uuid4(), cache)
# ── chat_processing_checker ──────────────────────────────────────────
class TestSetProcessingStatus:
def test_set_true_marks_processing(self) -> None:
cache = _MemoryCacheBackend()
sid = uuid4()
set_processing_status(sid, cache, True)
assert is_chat_session_processing(sid, cache)
def test_set_false_clears_processing(self) -> None:
cache = _MemoryCacheBackend()
sid = uuid4()
set_processing_status(sid, cache, True)
set_processing_status(sid, cache, False)
assert not is_chat_session_processing(sid, cache)
class TestIsChatSessionProcessing:
def test_not_processing_by_default(self) -> None:
cache = _MemoryCacheBackend()
assert not is_chat_session_processing(uuid4(), cache)
def test_sessions_are_isolated(self) -> None:
cache = _MemoryCacheBackend()
sid1, sid2 = uuid4(), uuid4()
set_processing_status(sid1, cache, True)
assert is_chat_session_processing(sid1, cache)
assert not is_chat_session_processing(sid2, cache)

View File

@@ -1,163 +0,0 @@
"""Unit tests for federated OAuth state generation and verification.
Uses unittest.mock to patch get_cache_backend so no external services
are needed. Verifies the generate -> verify round-trip, one-time-use
semantics, TTL propagation, and error handling.
"""
from unittest.mock import patch
import pytest
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheLock
from onyx.federated_connectors.oauth_utils import generate_oauth_state
from onyx.federated_connectors.oauth_utils import OAUTH_STATE_TTL
from onyx.federated_connectors.oauth_utils import OAuthSession
from onyx.federated_connectors.oauth_utils import verify_oauth_state
class _MemoryCacheBackend(CacheBackend):
"""Minimal in-memory CacheBackend for unit tests."""
def __init__(self) -> None:
self._store: dict[str, bytes] = {}
self.set_calls: list[dict[str, object]] = []
def get(self, key: str) -> bytes | None:
return self._store.get(key)
def set(
self,
key: str,
value: str | bytes | int | float,
ex: int | None = None,
) -> None:
self.set_calls.append({"key": key, "ex": ex})
if isinstance(value, bytes):
self._store[key] = value
else:
self._store[key] = str(value).encode()
def delete(self, key: str) -> None:
self._store.pop(key, None)
def exists(self, key: str) -> bool:
return key in self._store
def expire(self, key: str, seconds: int) -> None:
pass
def ttl(self, key: str) -> int:
return -2 if key not in self._store else -1
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
raise NotImplementedError
def rpush(self, key: str, value: str | bytes) -> None:
raise NotImplementedError
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
raise NotImplementedError
def _patched(cache: _MemoryCacheBackend): # type: ignore[no-untyped-def]
return patch(
"onyx.federated_connectors.oauth_utils.get_cache_backend",
return_value=cache,
)
class TestGenerateAndVerifyRoundTrip:
def test_round_trip_basic(self) -> None:
cache = _MemoryCacheBackend()
with _patched(cache):
state = generate_oauth_state(
federated_connector_id=42,
user_id="user-abc",
)
session = verify_oauth_state(state)
assert session.federated_connector_id == 42
assert session.user_id == "user-abc"
assert session.redirect_uri is None
assert session.additional_data == {}
def test_round_trip_with_all_fields(self) -> None:
cache = _MemoryCacheBackend()
with _patched(cache):
state = generate_oauth_state(
federated_connector_id=7,
user_id="user-xyz",
redirect_uri="https://example.com/callback",
additional_data={"scope": "read"},
)
session = verify_oauth_state(state)
assert session.federated_connector_id == 7
assert session.user_id == "user-xyz"
assert session.redirect_uri == "https://example.com/callback"
assert session.additional_data == {"scope": "read"}
class TestOneTimeUse:
def test_verify_deletes_state(self) -> None:
cache = _MemoryCacheBackend()
with _patched(cache):
state = generate_oauth_state(federated_connector_id=1, user_id="u")
verify_oauth_state(state)
with pytest.raises(ValueError, match="OAuth state not found"):
verify_oauth_state(state)
class TestTTLPropagation:
def test_default_ttl(self) -> None:
cache = _MemoryCacheBackend()
with _patched(cache):
generate_oauth_state(federated_connector_id=1, user_id="u")
assert len(cache.set_calls) == 1
assert cache.set_calls[0]["ex"] == OAUTH_STATE_TTL
def test_custom_ttl(self) -> None:
cache = _MemoryCacheBackend()
with _patched(cache):
generate_oauth_state(federated_connector_id=1, user_id="u", ttl=600)
assert cache.set_calls[0]["ex"] == 600
class TestVerifyInvalidState:
def test_missing_state_raises(self) -> None:
cache = _MemoryCacheBackend()
with _patched(cache):
state = generate_oauth_state(federated_connector_id=1, user_id="u")
# Manually clear the cache to simulate expiration
cache._store.clear()
with pytest.raises(ValueError, match="OAuth state not found"):
verify_oauth_state(state)
class TestOAuthSessionSerialization:
def test_to_dict_from_dict_round_trip(self) -> None:
session = OAuthSession(
federated_connector_id=5,
user_id="u-123",
redirect_uri="https://redir.example.com",
additional_data={"key": "val"},
)
d = session.to_dict()
restored = OAuthSession.from_dict(d)
assert restored.federated_connector_id == 5
assert restored.user_id == "u-123"
assert restored.redirect_uri == "https://redir.example.com"
assert restored.additional_data == {"key": "val"}
def test_from_dict_defaults(self) -> None:
minimal = {"federated_connector_id": 1, "user_id": "u"}
session = OAuthSession.from_dict(minimal)
assert session.redirect_uri is None
assert session.additional_data == {}

View File

@@ -44,7 +44,6 @@ def _build_provider_view(
id=1,
name="test-provider",
provider=provider,
default_model_name="test-model",
model_configurations=[
ModelConfigurationView(
name="test-model",
@@ -62,7 +61,6 @@ def _build_provider_view(
groups=[],
personas=[],
deployment_name=None,
default_vision_model=None,
)

View File

@@ -1,196 +0,0 @@
"""Tests for tool construction when DISABLE_VECTOR_DB is True.
Verifies that:
- SearchTool.is_available() returns False when vector DB is disabled
- OpenURLTool.is_available() returns False when vector DB is disabled
- The force-add SearchTool block is suppressed when DISABLE_VECTOR_DB
- FileReaderTool.is_available() returns True when vector DB is disabled
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FileReaderTool
APP_CONFIGS_MODULE = "onyx.configs.app_configs"
FILE_READER_MODULE = "onyx.tools.tool_implementations.file_reader.file_reader_tool"
# ------------------------------------------------------------------
# SearchTool.is_available()
# ------------------------------------------------------------------
class TestSearchToolAvailability:
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", True)
def test_unavailable_when_vector_db_disabled(self) -> None:
from onyx.tools.tool_implementations.search.search_tool import SearchTool
assert SearchTool.is_available(MagicMock()) is False
@patch("onyx.db.connector.check_user_files_exist", return_value=True)
@patch(
"onyx.tools.tool_implementations.search.search_tool.check_federated_connectors_exist",
return_value=False,
)
@patch(
"onyx.tools.tool_implementations.search.search_tool.check_connectors_exist",
return_value=False,
)
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", False)
def test_available_when_vector_db_enabled_and_files_exist(
self,
mock_connectors: MagicMock, # noqa: ARG002
mock_federated: MagicMock, # noqa: ARG002
mock_user_files: MagicMock, # noqa: ARG002
) -> None:
from onyx.tools.tool_implementations.search.search_tool import SearchTool
assert SearchTool.is_available(MagicMock()) is True
# ------------------------------------------------------------------
# OpenURLTool.is_available()
# ------------------------------------------------------------------
class TestOpenURLToolAvailability:
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", True)
def test_unavailable_when_vector_db_disabled(self) -> None:
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
assert OpenURLTool.is_available(MagicMock()) is False
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", False)
def test_available_when_vector_db_enabled(self) -> None:
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
assert OpenURLTool.is_available(MagicMock()) is True
# ------------------------------------------------------------------
# FileReaderTool.is_available()
# ------------------------------------------------------------------
class TestFileReaderToolAvailability:
@patch(f"{FILE_READER_MODULE}.DISABLE_VECTOR_DB", True)
def test_available_when_vector_db_disabled(self) -> None:
assert FileReaderTool.is_available(MagicMock()) is True
@patch(f"{FILE_READER_MODULE}.DISABLE_VECTOR_DB", False)
def test_unavailable_when_vector_db_enabled(self) -> None:
assert FileReaderTool.is_available(MagicMock()) is False
# ------------------------------------------------------------------
# Force-add SearchTool suppression
# ------------------------------------------------------------------
class TestForceAddSearchToolGuard:
def test_force_add_block_checks_disable_vector_db(self) -> None:
"""The force-add SearchTool block in construct_tools should include
`not DISABLE_VECTOR_DB` so that forced search is also suppressed
without a vector DB."""
import inspect
from onyx.tools.tool_constructor import construct_tools
source = inspect.getsource(construct_tools)
assert "DISABLE_VECTOR_DB" in source, (
"construct_tools should reference DISABLE_VECTOR_DB "
"to suppress force-adding SearchTool"
)
# ------------------------------------------------------------------
# Persona API — _validate_vector_db_knowledge
# ------------------------------------------------------------------
class TestValidateVectorDbKnowledge:
@patch(
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
True,
)
def test_rejects_document_set_ids(self) -> None:
from fastapi import HTTPException
from onyx.server.features.persona.api import _validate_vector_db_knowledge
request = MagicMock()
request.document_set_ids = [1]
request.hierarchy_node_ids = []
request.document_ids = []
with __import__("pytest").raises(HTTPException) as exc_info:
_validate_vector_db_knowledge(request)
assert exc_info.value.status_code == 400
assert "document sets" in exc_info.value.detail
@patch(
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
True,
)
def test_rejects_hierarchy_node_ids(self) -> None:
from fastapi import HTTPException
from onyx.server.features.persona.api import _validate_vector_db_knowledge
request = MagicMock()
request.document_set_ids = []
request.hierarchy_node_ids = [1]
request.document_ids = []
with __import__("pytest").raises(HTTPException) as exc_info:
_validate_vector_db_knowledge(request)
assert exc_info.value.status_code == 400
assert "hierarchy nodes" in exc_info.value.detail
@patch(
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
True,
)
def test_rejects_document_ids(self) -> None:
from fastapi import HTTPException
from onyx.server.features.persona.api import _validate_vector_db_knowledge
request = MagicMock()
request.document_set_ids = []
request.hierarchy_node_ids = []
request.document_ids = ["doc-abc"]
with __import__("pytest").raises(HTTPException) as exc_info:
_validate_vector_db_knowledge(request)
assert exc_info.value.status_code == 400
assert "documents" in exc_info.value.detail
@patch(
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
True,
)
def test_allows_user_files_only(self) -> None:
from onyx.server.features.persona.api import _validate_vector_db_knowledge
request = MagicMock()
request.document_set_ids = []
request.hierarchy_node_ids = []
request.document_ids = []
_validate_vector_db_knowledge(request)
@patch(
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
False,
)
def test_allows_everything_when_vector_db_enabled(self) -> None:
from onyx.server.features.persona.api import _validate_vector_db_knowledge
request = MagicMock()
request.document_set_ids = [1, 2]
request.hierarchy_node_ids = [3]
request.document_ids = ["doc-x"]
_validate_vector_db_knowledge(request)

View File

@@ -1,237 +0,0 @@
"""Tests for the FileReaderTool.
Verifies:
- Tool definition schema is well-formed
- File ID validation (allowlist, UUID format)
- Character range extraction and clamping
- Error handling for missing parameters and non-text files
- is_available() reflects DISABLE_VECTOR_DB
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.models import ToolCallException
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FILE_ID_FIELD
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FileReaderTool
from onyx.tools.tool_implementations.file_reader.file_reader_tool import MAX_NUM_CHARS
from onyx.tools.tool_implementations.file_reader.file_reader_tool import NUM_CHARS_FIELD
from onyx.tools.tool_implementations.file_reader.file_reader_tool import (
START_CHAR_FIELD,
)
TOOL_MODULE = "onyx.tools.tool_implementations.file_reader.file_reader_tool"
_PLACEMENT = Placement(turn_index=0)
def _make_tool(
user_file_ids: list | None = None,
chat_file_ids: list | None = None,
) -> FileReaderTool:
emitter = MagicMock()
return FileReaderTool(
tool_id=99,
emitter=emitter,
user_file_ids=user_file_ids or [],
chat_file_ids=chat_file_ids or [],
)
def _text_file(content: str, filename: str = "test.txt") -> InMemoryChatFile:
return InMemoryChatFile(
file_id="some-file-id",
content=content.encode("utf-8"),
file_type=ChatFileType.PLAIN_TEXT,
filename=filename,
)
# ------------------------------------------------------------------
# Tool metadata
# ------------------------------------------------------------------
class TestToolMetadata:
def test_tool_name(self) -> None:
tool = _make_tool()
assert tool.name == "read_file"
def test_tool_definition_schema(self) -> None:
tool = _make_tool()
defn = tool.tool_definition()
assert defn["type"] == "function"
func = defn["function"]
assert func["name"] == "read_file"
props = func["parameters"]["properties"]
assert FILE_ID_FIELD in props
assert START_CHAR_FIELD in props
assert NUM_CHARS_FIELD in props
assert func["parameters"]["required"] == [FILE_ID_FIELD]
# ------------------------------------------------------------------
# File ID validation
# ------------------------------------------------------------------
class TestFileIdValidation:
def test_rejects_invalid_uuid(self) -> None:
tool = _make_tool()
with pytest.raises(ToolCallException, match="Invalid file_id"):
tool._validate_file_id("not-a-uuid")
def test_rejects_file_not_in_allowlist(self) -> None:
tool = _make_tool(user_file_ids=[uuid4()])
other_id = uuid4()
with pytest.raises(ToolCallException, match="not in available files"):
tool._validate_file_id(str(other_id))
def test_accepts_user_file_id(self) -> None:
uid = uuid4()
tool = _make_tool(user_file_ids=[uid])
assert tool._validate_file_id(str(uid)) == uid
def test_accepts_chat_file_id(self) -> None:
cid = uuid4()
tool = _make_tool(chat_file_ids=[cid])
assert tool._validate_file_id(str(cid)) == cid
# ------------------------------------------------------------------
# run() — character range extraction
# ------------------------------------------------------------------
class TestRun:
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
@patch(f"{TOOL_MODULE}.load_user_file")
def test_returns_full_content_by_default(
self,
mock_load_user_file: MagicMock,
mock_get_session: MagicMock,
) -> None:
uid = uuid4()
content = "Hello, world!"
mock_load_user_file.return_value = _text_file(content)
mock_get_session.return_value.__enter__.return_value = MagicMock()
tool = _make_tool(user_file_ids=[uid])
resp = tool.run(
placement=_PLACEMENT,
override_kwargs=MagicMock(),
**{FILE_ID_FIELD: str(uid)},
)
assert content in resp.llm_facing_response
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
@patch(f"{TOOL_MODULE}.load_user_file")
def test_respects_start_char_and_num_chars(
self,
mock_load_user_file: MagicMock,
mock_get_session: MagicMock,
) -> None:
uid = uuid4()
content = "abcdefghijklmnop"
mock_load_user_file.return_value = _text_file(content)
mock_get_session.return_value.__enter__.return_value = MagicMock()
tool = _make_tool(user_file_ids=[uid])
resp = tool.run(
placement=_PLACEMENT,
override_kwargs=MagicMock(),
**{FILE_ID_FIELD: str(uid), START_CHAR_FIELD: 4, NUM_CHARS_FIELD: 6},
)
assert "efghij" in resp.llm_facing_response
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
@patch(f"{TOOL_MODULE}.load_user_file")
def test_clamps_num_chars_to_max(
self,
mock_load_user_file: MagicMock,
mock_get_session: MagicMock,
) -> None:
uid = uuid4()
content = "x" * (MAX_NUM_CHARS + 500)
mock_load_user_file.return_value = _text_file(content)
mock_get_session.return_value.__enter__.return_value = MagicMock()
tool = _make_tool(user_file_ids=[uid])
resp = tool.run(
placement=_PLACEMENT,
override_kwargs=MagicMock(),
**{FILE_ID_FIELD: str(uid), NUM_CHARS_FIELD: MAX_NUM_CHARS + 9999},
)
assert f"Characters 0-{MAX_NUM_CHARS}" in resp.llm_facing_response
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
@patch(f"{TOOL_MODULE}.load_user_file")
def test_includes_continuation_hint(
self,
mock_load_user_file: MagicMock,
mock_get_session: MagicMock,
) -> None:
uid = uuid4()
content = "x" * 100
mock_load_user_file.return_value = _text_file(content)
mock_get_session.return_value.__enter__.return_value = MagicMock()
tool = _make_tool(user_file_ids=[uid])
resp = tool.run(
placement=_PLACEMENT,
override_kwargs=MagicMock(),
**{FILE_ID_FIELD: str(uid), NUM_CHARS_FIELD: 10},
)
assert "use start_char=10 to continue reading" in resp.llm_facing_response
def test_raises_on_missing_file_id(self) -> None:
tool = _make_tool()
with pytest.raises(ToolCallException, match="Missing required"):
tool.run(
placement=_PLACEMENT,
override_kwargs=MagicMock(),
)
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
@patch(f"{TOOL_MODULE}.load_user_file")
def test_raises_on_non_text_file(
self,
mock_load_user_file: MagicMock,
mock_get_session: MagicMock,
) -> None:
uid = uuid4()
mock_load_user_file.return_value = InMemoryChatFile(
file_id="img",
content=b"\x89PNG",
file_type=ChatFileType.IMAGE,
filename="photo.png",
)
mock_get_session.return_value.__enter__.return_value = MagicMock()
tool = _make_tool(user_file_ids=[uid])
with pytest.raises(ToolCallException, match="not a text file"):
tool.run(
placement=_PLACEMENT,
override_kwargs=MagicMock(),
**{FILE_ID_FIELD: str(uid)},
)
# ------------------------------------------------------------------
# is_available()
# ------------------------------------------------------------------
class TestIsAvailable:
@patch(f"{TOOL_MODULE}.DISABLE_VECTOR_DB", True)
def test_available_when_vector_db_disabled(self) -> None:
assert FileReaderTool.is_available(MagicMock()) is True
@patch(f"{TOOL_MODULE}.DISABLE_VECTOR_DB", False)
def test_unavailable_when_vector_db_enabled(self) -> None:
assert FileReaderTool.is_available(MagicMock()) is False

View File

@@ -163,3 +163,16 @@ Add clear comments:
- Any TODOs you add in the code must be accompanied by either the name/username
of the owner of that TODO, or an issue number for an issue referencing that
piece of work.
- Avoid module-level logic that runs on import, which leads to import-time side
effects. Essentially every piece of meaningful logic should exist within some
function that has to be explicitly invoked. Acceptable exceptions to this may
include loading environment variables or setting up loggers.
- If you find yourself needing something like this, you may want that logic to
exist in a file dedicated for manual execution (contains `if __name__ ==
"__main__":`) which should not be imported by anything else.
- Related to the above, do not conflate Python scripts you intend to run from
the command line (contains `if __name__ == "__main__":`) with modules you
intend to import from elsewhere. If for some unlikely reason they have to be
the same file, any logic specific to executing the file (including imports)
should be contained in the `if __name__ == "__main__":` block.
- Generally these executable files exist in `backend/scripts/`.

View File

@@ -468,7 +468,7 @@ services:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
test: ["CMD", "mc", "ready", "local"]
interval: 30s
timeout: 20s
retries: 3

View File

@@ -16,15 +16,12 @@
# This overlay:
# - Moves Vespa (index), both model servers, and code-interpreter to profiles
# so they do not start by default
# - Moves the background worker to the "background" profile (the API server
# handles all background work via FastAPI BackgroundTasks)
# - Makes the depends_on references to removed services optional
# - Sets DISABLE_VECTOR_DB=true on the api_server
# - Makes the depends_on references to those services optional
# - Sets DISABLE_VECTOR_DB=true on backend services
#
# To selectively bring services back:
# --profile vectordb Vespa + indexing model server
# --profile inference Inference model server
# --profile background Background worker (Celery)
# --profile code-interpreter Code interpreter
# =============================================================================
@@ -46,10 +43,20 @@ services:
- DISABLE_VECTOR_DB=true
- FILE_STORE_BACKEND=postgres
# Move the background worker to a profile so it does not start by default.
# The API server handles all background work in NO_VECTOR_DB mode.
background:
profiles: ["background"]
depends_on:
index:
condition: service_started
required: false
indexing_model_server:
condition: service_started
required: false
inference_model_server:
condition: service_started
required: false
environment:
- DISABLE_VECTOR_DB=true
- FILE_STORE_BACKEND=postgres
# Move Vespa and indexing model server to a profile so they do not start.
index:

View File

@@ -293,7 +293,7 @@ services:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
test: ["CMD", "mc", "ready", "local"]
interval: 30s
timeout: 20s
retries: 3

View File

@@ -298,7 +298,7 @@ services:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
test: ["CMD", "mc", "ready", "local"]
interval: 30s
timeout: 20s
retries: 3

View File

@@ -335,7 +335,7 @@ services:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
test: ["CMD", "mc", "ready", "local"]
interval: 30s
timeout: 20s
retries: 3

View File

@@ -232,7 +232,7 @@ services:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
test: ["CMD", "mc", "ready", "local"]
interval: 30s
timeout: 20s
retries: 3

View File

@@ -520,7 +520,7 @@ services:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
test: ["CMD", "mc", "ready", "local"]
interval: 30s
timeout: 20s
retries: 3

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_beat.replicaCount) 0) }}
{{- if gt (int .Values.celery_beat.replicaCount) 0 }}
apiVersion: apps/v1
kind: Deployment
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_heavy.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_heavy.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_heavy.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_heavy.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_heavy.replicaCount) 0) }}
{{- if gt (int .Values.celery_worker_heavy.replicaCount) 0 }}
apiVersion: apps/v1
kind: Deployment
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_light.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_light.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_light.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_light.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_light.replicaCount) 0) }}
{{- if gt (int .Values.celery_worker_light.replicaCount) 0 }}
apiVersion: apps/v1
kind: Deployment
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_monitoring.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_monitoring.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_monitoring.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_monitoring.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_monitoring.replicaCount) 0) }}
{{- if gt (int .Values.celery_worker_monitoring.replicaCount) 0 }}
apiVersion: apps/v1
kind: Deployment
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_primary.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_primary.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_primary.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_primary.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_primary.replicaCount) 0) }}
{{- if gt (int .Values.celery_worker_primary.replicaCount) 0 }}
apiVersion: apps/v1
kind: Deployment
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_user_file_processing.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_user_file_processing.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_user_file_processing.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_user_file_processing.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_user_file_processing.replicaCount) 0) }}
{{- if gt (int .Values.celery_worker_user_file_processing.replicaCount) 0 }}
apiVersion: apps/v1
kind: Deployment
metadata:

View File

@@ -28,9 +28,7 @@ postgresql:
# -- Master toggle for vector database support. When false:
# - Sets DISABLE_VECTOR_DB=true on all backend pods
# - Skips the indexing model server deployment (embeddings not needed)
# - Skips ALL celery worker deployments (beat, primary, light, heavy,
# monitoring, user-file-processing, docprocessing, docfetching) — the
# API server handles background work via FastAPI BackgroundTasks
# - Skips docprocessing and docfetching celery workers
# - You should also set vespa.enabled=false and opensearch.enabled=false
# to prevent those subcharts from deploying
vectorDB:

View File

@@ -10,7 +10,7 @@ requires-python = ">=3.11"
dependencies = [
"aioboto3==15.1.0",
"cohere==5.6.1",
"fastapi==0.128.0",
"fastapi==0.133.1",
"google-cloud-aiplatform==1.121.0",
"google-genai==1.52.0",
"litellm==1.81.6",

View File

@@ -51,6 +51,7 @@ func NewRootCommand() *cobra.Command {
cmd.AddCommand(NewRunCICommand())
cmd.AddCommand(NewScreenshotDiffCommand())
cmd.AddCommand(NewWebCommand())
cmd.AddCommand(NewWhoisCommand())
return cmd
}

159
tools/ods/cmd/whois.go Normal file
View File

@@ -0,0 +1,159 @@
package cmd
import (
"fmt"
"os"
"regexp"
"strings"
"text/tabwriter"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/onyx-dot-app/onyx/tools/ods/internal/kube"
)
var safeIdentifier = regexp.MustCompile(`^[a-zA-Z0-9_\-]+$`)
// NewWhoisCommand creates the whois command for looking up users/tenants.
func NewWhoisCommand() *cobra.Command {
var ctx string
cmd := &cobra.Command{
Use: "whois <email-fragment or tenant-id>",
Short: "Look up users and admins by email or tenant ID",
Long: `Look up tenant and user information from the data plane PostgreSQL database.
Requires: AWS SSO login, kubectl access to the EKS cluster.
Two modes (auto-detected):
Email fragment:
ods whois chris
→ Searches user_tenant_mapping for emails matching '%chris%'
Tenant ID:
ods whois tenant_abcd1234-...
→ Lists all admin emails in that tenant
Cluster connection is configured via KUBE_CTX_* environment variables.
Each variable is a space-separated tuple: "cluster region namespace"
export KUBE_CTX_DATA_PLANE="<cluster> <region> <namespace>"
export KUBE_CTX_CONTROL_PLANE="<cluster> <region> <namespace>"
etc...
Use -c to select which context (default: data_plane).`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
runWhois(args[0], ctx)
},
}
cmd.Flags().StringVarP(&ctx, "context", "c", "data_plane", "cluster context name (maps to KUBE_CTX_<NAME> env var)")
return cmd
}
func clusterFromEnv(name string) *kube.Cluster {
envKey := "KUBE_CTX_" + strings.ToUpper(name)
val := os.Getenv(envKey)
if val == "" {
log.Fatalf("Environment variable %s is not set.\n\nSet it as a space-separated tuple:\n export %s=\"<cluster> <region> <namespace>\"", envKey, envKey)
}
parts := strings.Fields(val)
if len(parts) != 3 {
log.Fatalf("%s must be a space-separated tuple of 3 values (cluster region namespace), got: %q", envKey, val)
}
return &kube.Cluster{Name: parts[0], Region: parts[1], Namespace: parts[2]}
}
// queryPod runs a SQL query via pginto on the given pod and returns cleaned output lines.
func queryPod(c *kube.Cluster, pod, sql string) []string {
raw, err := c.ExecOnPod(pod, "pginto", "-A", "-t", "-F", "\t", "-c", sql)
if err != nil {
log.Fatalf("Query failed: %v", err)
}
var lines []string
for _, line := range strings.Split(strings.TrimSpace(raw), "\n") {
line = strings.TrimSpace(line)
if line != "" && !strings.HasPrefix(line, "Connecting to ") {
lines = append(lines, line)
}
}
return lines
}
func runWhois(query string, ctx string) {
c := clusterFromEnv(ctx)
if err := c.EnsureContext(); err != nil {
log.Fatalf("Failed to ensure cluster context: %v", err)
}
log.Info("Finding api-server pod...")
pod, err := c.FindPod("api-server")
if err != nil {
log.Fatalf("Failed to find api-server pod: %v", err)
}
log.Debugf("Using pod: %s", pod)
if strings.HasPrefix(query, "tenant_") {
findAdminsByTenant(c, pod, query)
} else {
findByEmail(c, pod, query)
}
}
func findByEmail(c *kube.Cluster, pod, fragment string) {
fragment = strings.NewReplacer("'", "", `"`, "", `;`, "", `\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(fragment)
sql := fmt.Sprintf(
`SELECT email, tenant_id, active FROM public.user_tenant_mapping WHERE email LIKE '%%%s%%' ORDER BY email;`,
fragment,
)
log.Infof("Searching for emails matching '%%%s%%'...", fragment)
lines := queryPod(c, pod, sql)
if len(lines) == 0 {
fmt.Println("No results found.")
return
}
fmt.Println()
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
_, _ = fmt.Fprintln(w, "EMAIL\tTENANT ID\tACTIVE")
_, _ = fmt.Fprintln(w, "-----\t---------\t------")
for _, line := range lines {
_, _ = fmt.Fprintln(w, line)
}
_ = w.Flush()
}
func findAdminsByTenant(c *kube.Cluster, pod, tenantID string) {
if !safeIdentifier.MatchString(tenantID) {
log.Fatalf("Invalid tenant ID: %q (must be alphanumeric, hyphens, underscores only)", tenantID)
}
sql := fmt.Sprintf(
`SELECT email FROM "%s"."user" WHERE role = 'ADMIN' AND is_active = true AND email NOT LIKE 'api_key__%%' ORDER BY email;`,
tenantID,
)
log.Infof("Fetching admin emails for %s...", tenantID)
lines := queryPod(c, pod, sql)
if len(lines) == 0 {
fmt.Println("No admin users found for this tenant.")
return
}
fmt.Println()
fmt.Println("EMAIL")
fmt.Println("-----")
for _, line := range lines {
fmt.Println(line)
}
}

View File

@@ -0,0 +1,90 @@
package kube
import (
"bytes"
"fmt"
"os/exec"
"strings"
log "github.com/sirupsen/logrus"
)
// Cluster holds the connection info for a Kubernetes cluster.
type Cluster struct {
Name string
Region string
Namespace string
}
// EnsureContext makes sure the cluster exists in kubeconfig, calling
// aws eks update-kubeconfig only if the context is missing.
func (c *Cluster) EnsureContext() error {
// Check if context already exists in kubeconfig
cmd := exec.Command("kubectl", "config", "get-contexts", c.Name, "--no-headers")
if err := cmd.Run(); err == nil {
log.Debugf("Context %s already exists, skipping aws eks update-kubeconfig", c.Name)
return nil
}
log.Infof("Context %s not found, fetching kubeconfig from AWS...", c.Name)
cmd = exec.Command("aws", "eks", "update-kubeconfig", "--region", c.Region, "--name", c.Name, "--alias", c.Name)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("aws eks update-kubeconfig failed: %w\n%s", err, string(out))
}
return nil
}
// kubectlArgs returns common kubectl flags to target this cluster without mutating global context.
func (c *Cluster) kubectlArgs() []string {
return []string{"--context", c.Name, "--namespace", c.Namespace}
}
// FindPod returns the name of the first Running/Ready pod matching the given substring.
func (c *Cluster) FindPod(substring string) (string, error) {
args := append(c.kubectlArgs(), "get", "po",
"--field-selector", "status.phase=Running",
"--no-headers",
"-o", "custom-columns=NAME:.metadata.name,READY:.status.conditions[?(@.type=='Ready')].status",
)
cmd := exec.Command("kubectl", args...)
out, err := cmd.Output()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
return "", fmt.Errorf("kubectl get po failed: %w\n%s", err, string(exitErr.Stderr))
}
return "", fmt.Errorf("kubectl get po failed: %w", err)
}
for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") {
fields := strings.Fields(line)
if len(fields) < 2 {
continue
}
name, ready := fields[0], fields[1]
if strings.Contains(name, substring) && ready == "True" {
log.Debugf("Found pod: %s", name)
return name, nil
}
}
return "", fmt.Errorf("no ready pod found matching %q", substring)
}
// ExecOnPod runs a command on a pod and returns its stdout.
func (c *Cluster) ExecOnPod(pod string, command ...string) (string, error) {
args := append(c.kubectlArgs(), "exec", pod, "--")
args = append(args, command...)
log.Debugf("Running: kubectl %s", strings.Join(args, " "))
cmd := exec.Command("kubectl", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("kubectl exec failed: %w\n%s", err, stderr.String())
}
return stdout.String(), nil
}

15
uv.lock generated
View File

@@ -1688,17 +1688,18 @@ wheels = [
[[package]]
name = "fastapi"
version = "0.128.0"
version = "0.133.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "annotated-doc" },
{ name = "pydantic" },
{ name = "starlette" },
{ name = "typing-extensions" },
{ name = "typing-inspection" },
]
sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" }
sdist = { url = "https://files.pythonhosted.org/packages/22/6f/0eafed8349eea1fa462238b54a624c8b408cd1ba2795c8e64aa6c34f8ab7/fastapi-0.133.1.tar.gz", hash = "sha256:ed152a45912f102592976fde6cbce7dae1a8a1053da94202e51dd35d184fadd6", size = 378741, upload-time = "2026-02-25T18:18:17.398Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" },
{ url = "https://files.pythonhosted.org/packages/d2/c9/a175a7779f3599dfa4adfc97a6ce0e157237b3d7941538604aadaf97bfb6/fastapi-0.133.1-py3-none-any.whl", hash = "sha256:658f34ba334605b1617a65adf2ea6461901bdb9af3a3080d63ff791ecf7dc2e2", size = 109029, upload-time = "2026-02-25T18:18:18.578Z" },
]
[[package]]
@@ -4612,7 +4613,7 @@ requires-dist = [
{ name = "einops", marker = "extra == 'model-server'", specifier = "==0.8.1" },
{ name = "exa-py", marker = "extra == 'backend'", specifier = "==1.15.4" },
{ name = "faker", marker = "extra == 'dev'", specifier = "==40.1.2" },
{ name = "fastapi", specifier = "==0.128.0" },
{ name = "fastapi", specifier = "==0.133.1" },
{ name = "fastapi-limiter", marker = "extra == 'backend'", specifier = "==0.1.6" },
{ name = "fastapi-users", marker = "extra == 'backend'", specifier = "==15.0.4" },
{ name = "fastapi-users-db-sqlalchemy", marker = "extra == 'backend'", specifier = "==7.0.0" },
@@ -8079,14 +8080,14 @@ wheels = [
[[package]]
name = "werkzeug"
version = "3.1.5"
version = "3.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markupsafe" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5a/70/1469ef1d3542ae7c2c7b72bd5e3a4e6ee69d7978fa8a3af05a38eca5becf/werkzeug-3.1.5.tar.gz", hash = "sha256:6a548b0e88955dd07ccb25539d7d0cc97417ee9e179677d22c7041c8f078ce67", size = 864754, upload-time = "2026-01-08T17:49:23.247Z" }
sdist = { url = "https://files.pythonhosted.org/packages/61/f1/ee81806690a87dab5f5653c1f146c92bc066d7f4cebc603ef88eb9e13957/werkzeug-3.1.6.tar.gz", hash = "sha256:210c6bede5a420a913956b4791a7f4d6843a43b6fcee4dfa08a65e93007d0d25", size = 864736, upload-time = "2026-02-19T15:17:18.884Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ad/e4/8d97cca767bcc1be76d16fb76951608305561c6e056811587f36cb1316a8/werkzeug-3.1.5-py3-none-any.whl", hash = "sha256:5111e36e91086ece91f93268bb39b4a35c1e6f1feac762c9c822ded0a4e322dc", size = 225025, upload-time = "2026-01-08T17:49:21.859Z" },
{ url = "https://files.pythonhosted.org/packages/4d/ec/d58832f89ede95652fd01f4f24236af7d32b70cab2196dfcc2d2fd13c5c2/werkzeug-3.1.6-py3-none-any.whl", hash = "sha256:7ddf3357bb9564e407607f988f683d72038551200c704012bb9a4c523d42f131", size = 225166, upload-time = "2026-02-19T15:17:17.475Z" },
]
[[package]]

View File

@@ -1,11 +1,7 @@
import "@opal/components/buttons/Button/styles.css";
import "@opal/components/tooltip.css";
import {
Interactive,
type InteractiveBaseProps,
type InteractiveContainerWidthVariant,
} from "@opal/core";
import type { SizeVariant } from "@opal/shared";
import { Interactive, type InteractiveBaseProps } from "@opal/core";
import type { SizeVariant, WidthVariant } from "@opal/shared";
import type { TooltipSide } from "@opal/components";
import type { IconFunctionComponent } from "@opal/types";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
@@ -91,7 +87,7 @@ type ButtonProps = InteractiveBaseProps &
tooltip?: string;
/** Width preset. `"auto"` shrink-wraps, `"full"` stretches to parent width. */
width?: InteractiveContainerWidthVariant;
width?: WidthVariant;
/** Which side the tooltip appears on. */
tooltipSide?: TooltipSide;

View File

@@ -1,6 +1,6 @@
/* Hoverable — item transitions */
.hoverable-item {
transition: opacity 200ms ease-in-out;
transition: opacity 150ms ease-in-out;
}
.hoverable-item[data-hoverable-variant="opacity-on-hover"] {

Some files were not shown because too many files have changed in this diff Show More