mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-28 05:05:48 +00:00
Compare commits
13 Commits
feat/remov
...
embed_imag
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
edb69df30c | ||
|
|
6c8d088789 | ||
|
|
cdf2bfeb46 | ||
|
|
98aea79433 | ||
|
|
d2deefd1f1 | ||
|
|
18b90d405d | ||
|
|
8394e8837b | ||
|
|
f06df891c4 | ||
|
|
d6d5e72c18 | ||
|
|
449f5d62f9 | ||
|
|
4d256c5666 | ||
|
|
2e53496f46 | ||
|
|
63a206706a |
@@ -1,37 +0,0 @@
|
||||
"""add cache_store table
|
||||
|
||||
Revision ID: 2664261bfaab
|
||||
Revises: c0c937d5c9e5
|
||||
Create Date: 2026-02-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2664261bfaab"
|
||||
down_revision = "c0c937d5c9e5"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"cache_store",
|
||||
sa.Column("key", sa.String(), nullable=False),
|
||||
sa.Column("value", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("key"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_cache_store_expires",
|
||||
"cache_store",
|
||||
["expires_at"],
|
||||
postgresql_where=sa.text("expires_at IS NOT NULL"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_cache_store_expires", table_name="cache_store")
|
||||
op.drop_table("cache_store")
|
||||
@@ -241,7 +241,8 @@ _VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
|
||||
"check-for-index-attempt-cleanup",
|
||||
"check-for-doc-permissions-sync",
|
||||
"check-for-external-group-sync",
|
||||
"migrate-chunks-from-vespa-to-opensearch",
|
||||
"check-for-documents-for-opensearch-migration",
|
||||
"migrate-documents-from-vespa-to-opensearch",
|
||||
}
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
|
||||
@@ -414,31 +414,34 @@ def _process_user_file_with_indexing(
|
||||
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
|
||||
|
||||
|
||||
def _process_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
"""Core implementation for processing a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips all Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"_process_user_file_impl - Starting id={user_file_id}")
|
||||
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
|
||||
start = time.monotonic()
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"_process_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
documents: list[Document] = []
|
||||
try:
|
||||
@@ -446,15 +449,15 @@ def _process_user_file_impl(
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not uf:
|
||||
task_logger.warning(
|
||||
f"_process_user_file_impl - UserFile not found id={user_file_id}"
|
||||
f"process_single_user_file - UserFile not found id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
if uf.status != UserFileStatus.PROCESSING:
|
||||
task_logger.info(
|
||||
f"_process_user_file_impl - Skipping id={user_file_id} status={uf.status}"
|
||||
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[uf.file_id],
|
||||
@@ -468,6 +471,7 @@ def _process_user_file_impl(
|
||||
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
|
||||
)
|
||||
|
||||
# update the document id to userfile id in the documents
|
||||
for document in documents:
|
||||
document.id = str(user_file_id)
|
||||
document.source = DocumentSource.USER_FILE
|
||||
@@ -489,8 +493,9 @@ def _process_user_file_impl(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"_process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
# don't update the status if the user file is being deleted
|
||||
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if (
|
||||
current_user_file
|
||||
@@ -499,42 +504,33 @@ def _process_user_file_impl(
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return
|
||||
return None
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"_process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
# Attempt to mark the file as failed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if uf:
|
||||
# don't update the status if the user file is being deleted
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.exception(
|
||||
f"_process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
_process_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
|
||||
soft_time_limit=300,
|
||||
@@ -585,38 +581,36 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _delete_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
"""Core implementation for deleting a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock (Celery path).
|
||||
When redis_locking=False, skips Redis operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"_delete_user_file_impl - Starting id={user_file_id}")
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
"""Process a single user file delete."""
|
||||
task_logger.info(f"process_single_user_file_delete - Starting id={user_file_id}")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"_delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
return None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"_delete_user_file_impl - User file not found id={user_file_id}"
|
||||
f"process_single_user_file_delete - User file not found id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
# 1) Delete vector DB chunks (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -654,6 +648,7 @@ def _delete_user_file_impl(
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
|
||||
file_store = get_default_file_store()
|
||||
try:
|
||||
file_store.delete_file(user_file.file_id)
|
||||
@@ -661,33 +656,26 @@ def _delete_user_file_impl(
|
||||
user_file_id_to_plaintext_file_name(user_file.id)
|
||||
)
|
||||
except Exception as e:
|
||||
# This block executed only if the file is not found in the filestore
|
||||
task_logger.exception(
|
||||
f"_delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
)
|
||||
|
||||
# 3) Finally, delete the UserFile row
|
||||
db_session.delete(user_file)
|
||||
db_session.commit()
|
||||
task_logger.info(f"_delete_user_file_impl - Completed id={user_file_id}")
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Completed id={user_file_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"_delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file_delete - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
_delete_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -759,30 +747,32 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _project_sync_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
"""Core implementation for syncing a user file's project/persona metadata.
|
||||
"""Process a single user file project sync."""
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Starting id={user_file_id}"
|
||||
)
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"_project_sync_user_file_impl - Starting id={user_file_id}")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"_project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -793,10 +783,11 @@ def _project_sync_user_file_impl(
|
||||
).scalar_one_or_none()
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"_project_sync_user_file_impl - User file not found id={user_file_id}"
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
# Sync project metadata to vector DB (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -831,7 +822,7 @@ def _project_sync_user_file_impl(
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"_project_sync_user_file_impl - User file id={user_file_id}"
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
@@ -844,21 +835,11 @@ def _project_sync_user_file_impl(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"_project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
_project_sync_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,276 +0,0 @@
|
||||
"""Periodic poller for NO_VECTOR_DB deployments.
|
||||
|
||||
Replaces Celery Beat and background workers with a lightweight daemon thread
|
||||
that runs from the API server process. Two responsibilities:
|
||||
|
||||
1. Recovery polling (every 30 s): re-processes user files stuck in
|
||||
PROCESSING / DELETING / needs_sync states via the drain loops defined
|
||||
in ``task_utils.py``.
|
||||
|
||||
2. Periodic task execution (configurable intervals): runs LLM model updates
|
||||
and scheduled evals at their configured cadences, with Postgres advisory
|
||||
lock deduplication across multiple API server instances.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
RECOVERY_INTERVAL_SECONDS = 30
|
||||
PERIODIC_TASK_LOCK_BASE = 20_000
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task definitions
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PeriodicTaskDef:
|
||||
name: str
|
||||
interval_seconds: float
|
||||
lock_id: int
|
||||
run_fn: Callable[[], None]
|
||||
last_run_at: float = field(default=0.0)
|
||||
|
||||
|
||||
def _run_auto_llm_update() -> None:
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
|
||||
if not AUTO_LLM_CONFIG_URL:
|
||||
return
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
sync_llm_models_from_github(db_session)
|
||||
|
||||
|
||||
def _run_cache_cleanup() -> None:
|
||||
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
|
||||
|
||||
cleanup_expired_cache_entries()
|
||||
|
||||
|
||||
def _run_scheduled_eval() -> None:
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
|
||||
|
||||
if not all(
|
||||
[
|
||||
BRAINTRUST_API_KEY,
|
||||
SCHEDULED_EVAL_PROJECT,
|
||||
SCHEDULED_EVAL_DATASET_NAMES,
|
||||
SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
]
|
||||
):
|
||||
return
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from onyx.evals.eval import run_eval
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
|
||||
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
for dataset_name in SCHEDULED_EVAL_DATASET_NAMES:
|
||||
try:
|
||||
run_eval(
|
||||
configuration=EvalConfigurationOptions(
|
||||
search_permissions_email=SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
dataset_name=dataset_name,
|
||||
no_send_logs=False,
|
||||
braintrust_project=SCHEDULED_EVAL_PROJECT,
|
||||
experiment_name=f"{dataset_name} - {run_timestamp}",
|
||||
),
|
||||
remote_dataset_name=dataset_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Failed scheduled eval for dataset {dataset_name}"
|
||||
)
|
||||
|
||||
|
||||
_CACHE_CLEANUP_INTERVAL_SECONDS = 300
|
||||
|
||||
|
||||
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
|
||||
tasks: list[_PeriodicTaskDef] = []
|
||||
if CACHE_BACKEND == CacheBackendType.POSTGRES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="cache-cleanup",
|
||||
interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 2,
|
||||
run_fn=_run_cache_cleanup,
|
||||
)
|
||||
)
|
||||
if AUTO_LLM_CONFIG_URL:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="auto-llm-update",
|
||||
interval_seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE,
|
||||
run_fn=_run_auto_llm_update,
|
||||
)
|
||||
)
|
||||
if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="scheduled-eval",
|
||||
interval_seconds=7 * 24 * 3600,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 1,
|
||||
run_fn=_run_scheduled_eval,
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task runner with advisory lock dedup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _try_run_periodic_task(task_def: _PeriodicTaskDef) -> None:
|
||||
"""Run *task_def* if its interval has elapsed and no peer holds the lock."""
|
||||
now = time.monotonic()
|
||||
if now - task_def.last_run_at < task_def.interval_seconds:
|
||||
return
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
acquired = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
{"id": task_def.lock_id},
|
||||
).scalar()
|
||||
if not acquired:
|
||||
return
|
||||
|
||||
try:
|
||||
task_def.run_fn()
|
||||
task_def.last_run_at = now
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Error running periodic task {task_def.name}"
|
||||
)
|
||||
finally:
|
||||
db_session.execute(
|
||||
text("SELECT pg_advisory_unlock(:id)"),
|
||||
{"id": task_def.lock_id},
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Recovery / drain loop runner
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_drain_loops(tenant_id: str) -> None:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
drain_processing_loop(tenant_id)
|
||||
drain_delete_loop(tenant_id)
|
||||
drain_project_sync_loop(tenant_id)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Startup recovery (10g)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def recover_stuck_user_files(tenant_id: str) -> None:
|
||||
"""Run all drain loops once to re-process files left in intermediate states.
|
||||
|
||||
Called from ``lifespan()`` on startup when ``DISABLE_VECTOR_DB`` is set.
|
||||
"""
|
||||
logger.info("recover_stuck_user_files - Checking for stuck user files")
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("recover_stuck_user_files - Error during recovery")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Daemon thread (10f)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_shutdown_event = threading.Event()
|
||||
_poller_thread: threading.Thread | None = None
|
||||
|
||||
|
||||
def _poller_loop(tenant_id: str) -> None:
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
periodic_tasks = _build_periodic_tasks()
|
||||
logger.info(
|
||||
f"Periodic poller started with {len(periodic_tasks)} periodic task(s): "
|
||||
f"{[t.name for t in periodic_tasks]}"
|
||||
)
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Periodic poller - Error in recovery polling")
|
||||
|
||||
for task_def in periodic_tasks:
|
||||
try:
|
||||
_try_run_periodic_task(task_def)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Unhandled error checking task {task_def.name}"
|
||||
)
|
||||
|
||||
_shutdown_event.wait(RECOVERY_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def start_periodic_poller(tenant_id: str) -> None:
|
||||
"""Start the periodic poller daemon thread."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
_shutdown_event.clear()
|
||||
_poller_thread = threading.Thread(
|
||||
target=_poller_loop,
|
||||
args=(tenant_id,),
|
||||
daemon=True,
|
||||
name="no-vectordb-periodic-poller",
|
||||
)
|
||||
_poller_thread.start()
|
||||
logger.info("Periodic poller thread started")
|
||||
|
||||
|
||||
def stop_periodic_poller() -> None:
|
||||
"""Signal the periodic poller to stop and wait for it to exit."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
if _poller_thread is None:
|
||||
return
|
||||
_shutdown_event.set()
|
||||
_poller_thread.join(timeout=10)
|
||||
if _poller_thread.is_alive():
|
||||
logger.warning("Periodic poller thread did not stop within timeout")
|
||||
_poller_thread = None
|
||||
logger.info("Periodic poller thread stopped")
|
||||
@@ -1,35 +1,3 @@
|
||||
"""Background task utilities.
|
||||
|
||||
Contains query-history report helpers (used by all deployment modes) and
|
||||
in-process background task execution helpers for NO_VECTOR_DB mode:
|
||||
|
||||
- Postgres advisory lock-based concurrency semaphore (10d)
|
||||
- Drain loops that process all pending user file work (10e)
|
||||
- Entry points wired to FastAPI BackgroundTasks (10c)
|
||||
|
||||
Advisory locks are session-level: they persist until explicitly released via
|
||||
``pg_advisory_unlock`` or until the DB connection closes. The semaphore
|
||||
session is kept open for the entire drain loop so the slot stays held, and
|
||||
released in a ``finally`` block before the connection returns to the pool.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query-history report helpers (pre-existing, used by all modes)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
QUERY_REPORT_NAME_PREFIX = "query-history"
|
||||
|
||||
|
||||
@@ -41,173 +9,3 @@ def construct_query_history_report_name(
|
||||
|
||||
def extract_task_id_from_query_history_report_name(name: str) -> str:
|
||||
return name.removeprefix(f"{QUERY_REPORT_NAME_PREFIX}-").removesuffix(".csv")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Postgres advisory lock semaphore (NO_VECTOR_DB mode)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
BACKGROUND_TASK_SLOT_BASE = 10_000
|
||||
BACKGROUND_TASK_MAX_CONCURRENCY = 4
|
||||
|
||||
|
||||
def try_acquire_semaphore_slot(db_session: Session) -> int | None:
|
||||
"""Try to acquire one of N advisory lock slots.
|
||||
|
||||
Returns the slot number (0-based) if acquired, ``None`` if all slots are
|
||||
taken. ``pg_try_advisory_lock`` is non-blocking — returns ``false``
|
||||
immediately when the lock is held by another session.
|
||||
"""
|
||||
for slot in range(BACKGROUND_TASK_MAX_CONCURRENCY):
|
||||
lock_id = BACKGROUND_TASK_SLOT_BASE + slot
|
||||
acquired = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
{"id": lock_id},
|
||||
).scalar()
|
||||
if acquired:
|
||||
return slot
|
||||
return None
|
||||
|
||||
|
||||
def release_semaphore_slot(db_session: Session, slot: int) -> None:
|
||||
"""Release a previously acquired advisory lock slot."""
|
||||
lock_id = BACKGROUND_TASK_SLOT_BASE + slot
|
||||
db_session.execute(
|
||||
text("SELECT pg_advisory_unlock(:id)"),
|
||||
{"id": lock_id},
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Work-claiming helpers (FOR UPDATE SKIP LOCKED)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _claim_next_processing_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next file in PROCESSING status."""
|
||||
return db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.PROCESSING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
|
||||
|
||||
def _claim_next_deleting_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next file in DELETING status."""
|
||||
return db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.DELETING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
|
||||
|
||||
def _claim_next_sync_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next file needing project/persona sync."""
|
||||
return db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Drain loops — acquire a semaphore slot then process *all* pending work
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def drain_processing_loop(tenant_id: str) -> None:
|
||||
"""Process all pending PROCESSING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_process_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as sem_session:
|
||||
slot = try_acquire_semaphore_slot(sem_session)
|
||||
if slot is None:
|
||||
logger.info("drain_processing_loop - All semaphore slots taken, skipping")
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
with get_session_with_current_tenant() as claim_session:
|
||||
file_id = _claim_next_processing_file(claim_session)
|
||||
if file_id is None:
|
||||
break
|
||||
_process_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
finally:
|
||||
release_semaphore_slot(sem_session, slot)
|
||||
|
||||
|
||||
def drain_delete_loop(tenant_id: str) -> None:
|
||||
"""Delete all pending DELETING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_delete_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as sem_session:
|
||||
slot = try_acquire_semaphore_slot(sem_session)
|
||||
if slot is None:
|
||||
logger.info("drain_delete_loop - All semaphore slots taken, skipping")
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
with get_session_with_current_tenant() as claim_session:
|
||||
file_id = _claim_next_deleting_file(claim_session)
|
||||
if file_id is None:
|
||||
break
|
||||
_delete_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
finally:
|
||||
release_semaphore_slot(sem_session, slot)
|
||||
|
||||
|
||||
def drain_project_sync_loop(tenant_id: str) -> None:
    """Sync all pending project/persona metadata for user files.

    Semaphore-guarded like the other drain loops: only one drainer runs
    at a time; concurrent callers skip immediately.
    """
    from onyx.background.celery.tasks.user_file_processing.tasks import (
        _project_sync_user_file_impl,
    )
    from onyx.db.engine.sql_engine import get_session_with_current_tenant

    with get_session_with_current_tenant() as sem_session:
        slot = try_acquire_semaphore_slot(sem_session)
        if slot is None:
            logger.info("drain_project_sync_loop - All semaphore slots taken, skipping")
            return

        try:
            # Claim sync work one file at a time with short-lived sessions.
            while True:
                with get_session_with_current_tenant() as claim_session:
                    next_file_id = _claim_next_sync_file(claim_session)
                if next_file_id is None:
                    break
                _project_sync_user_file_impl(
                    user_file_id=str(next_file_id),
                    tenant_id=tenant_id,
                    redis_locking=False,
                )
        finally:
            # Always free the slot, even if a sync raised.
            release_semaphore_slot(sem_session, slot)
|
||||
|
||||
51
backend/onyx/cache/factory.py
vendored
51
backend/onyx/cache/factory.py
vendored
@@ -1,51 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
|
||||
|
||||
def _build_redis_backend(tenant_id: str) -> CacheBackend:
    """Build a Redis-backed cache; key prefixing is handled by the tenant client."""
    # Imported lazily so Redis deps are never loaded when unused.
    from onyx.cache.redis_backend import RedisCacheBackend
    from onyx.redis.redis_pool import redis_pool

    client = redis_pool.get_client(tenant_id)
    return RedisCacheBackend(client)
|
||||
|
||||
|
||||
def _build_postgres_backend(tenant_id: str) -> CacheBackend:
    """Build a PostgreSQL-backed cache scoped to *tenant_id*."""
    # Imported lazily so the Postgres backend is only loaded when selected.
    from onyx.cache.postgres_backend import PostgresCacheBackend

    return PostgresCacheBackend(tenant_id)
|
||||
|
||||
|
||||
# Maps each supported CacheBackendType to a builder taking a tenant_id.
# Builders import their implementations lazily so the unselected backend's
# dependencies are never loaded.
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
    CacheBackendType.REDIS: _build_redis_backend,
    CacheBackendType.POSTGRES: _build_postgres_backend,
}
|
||||
|
||||
|
||||
def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend:
    """Return a tenant-aware ``CacheBackend``.

    If *tenant_id* is ``None``, the current tenant is read from the
    thread-local context variable (same behaviour as ``get_redis_client``).
    """
    if tenant_id is None:
        from shared_configs.contextvars import get_current_tenant_id

        tenant_id = get_current_tenant_id()

    builder = _BACKEND_BUILDERS.get(CACHE_BACKEND)
    if builder is not None:
        return builder(tenant_id)

    raise ValueError(
        f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. "
        f"Supported values: {[t.value for t in CacheBackendType]}"
    )
|
||||
|
||||
|
||||
def get_shared_cache_backend() -> CacheBackend:
    """Return a ``CacheBackend`` in the shared (cross-tenant) namespace."""
    from shared_configs.configs import DEFAULT_REDIS_PREFIX

    # The shared namespace is just the default prefix used as a tenant id.
    shared_namespace = DEFAULT_REDIS_PREFIX
    return get_cache_backend(tenant_id=shared_namespace)
|
||||
96
backend/onyx/cache/interface.py
vendored
96
backend/onyx/cache/interface.py
vendored
@@ -1,96 +0,0 @@
|
||||
import abc
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class CacheBackendType(str, Enum):
    """Backend selector for the cache layer.

    Values match the accepted ``CACHE_BACKEND`` environment variable
    settings ("redis" or "postgres").
    """

    REDIS = "redis"
    POSTGRES = "postgres"
|
||||
|
||||
|
||||
class CacheLock(abc.ABC):
|
||||
"""Abstract distributed lock returned by CacheBackend.lock()."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def release(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def owned(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def __enter__(self) -> "CacheLock":
|
||||
self.acquire()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
self.release()
|
||||
|
||||
|
||||
class CacheBackend(abc.ABC):
    """Thin abstraction over a key-value cache with TTL, locks, and blocking lists.

    Covers the subset of Redis operations used outside of Celery. When
    CACHE_BACKEND=postgres, a PostgreSQL-backed implementation is used instead.
    """

    # -- basic key/value ---------------------------------------------------

    @abc.abstractmethod
    def get(self, key: str) -> bytes | None:
        """Return the raw value for *key*, or ``None`` if absent or expired."""
        raise NotImplementedError

    @abc.abstractmethod
    def set(
        self,
        key: str,
        value: str | bytes | int | float,
        ex: int | None = None,
    ) -> None:
        """Store *value* under *key*; *ex* is an optional TTL in seconds."""
        raise NotImplementedError

    @abc.abstractmethod
    def delete(self, key: str) -> None:
        """Remove *key*. Implementations do not raise when the key is missing."""
        raise NotImplementedError

    @abc.abstractmethod
    def exists(self, key: str) -> bool:
        """Return True if *key* is present (and not expired)."""
        raise NotImplementedError

    # -- TTL ---------------------------------------------------------------

    @abc.abstractmethod
    def expire(self, key: str, seconds: int) -> None:
        """Set/replace the TTL on an existing key."""
        raise NotImplementedError

    @abc.abstractmethod
    def ttl(self, key: str) -> int:
        """Return remaining TTL in seconds. -1 if no expiry, -2 if key missing."""
        raise NotImplementedError

    # -- distributed lock --------------------------------------------------

    @abc.abstractmethod
    def lock(self, name: str, timeout: float | None = None) -> CacheLock:
        """Return a distributed lock handle (semantics of ``timeout`` are
        backend-specific; see the concrete implementations)."""
        raise NotImplementedError

    # -- blocking list (used by MCP OAuth BLPOP pattern) -------------------

    @abc.abstractmethod
    def rpush(self, key: str, value: str | bytes) -> None:
        """Append *value* to the list stored at *key*."""
        raise NotImplementedError

    @abc.abstractmethod
    def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
        """Block until a value is available on one of *keys*, or *timeout* expires.

        Returns ``(key, value)`` or ``None`` on timeout.
        """
        raise NotImplementedError
|
||||
293
backend/onyx/cache/postgres_backend.py
vendored
293
backend/onyx/cache/postgres_backend.py
vendored
@@ -1,293 +0,0 @@
|
||||
"""PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments.
|
||||
|
||||
Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks
|
||||
for distributed locking, and a polling loop for the BLPOP pattern.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import Connection
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
from onyx.db.models import CacheStore
|
||||
|
||||
_LIST_KEY_PREFIX = "_q:"
|
||||
_LIST_ITEM_TTL_SECONDS = 3600
|
||||
_BLPOP_POLL_INTERVAL = 0.25
|
||||
|
||||
|
||||
def _list_item_key(key: str) -> str:
|
||||
return f"{_LIST_KEY_PREFIX}{key}:{time.time_ns()}"
|
||||
|
||||
|
||||
def _to_bytes(value: str | bytes | int | float) -> bytes:
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
return str(value).encode()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lock
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class PostgresCacheLock(CacheLock):
    """Advisory-lock-based distributed lock.

    The lock is tied to a dedicated database connection. Releasing
    the lock (or closing the connection) frees it.

    NOTE: Unlike Redis locks, advisory locks do not auto-expire after
    ``timeout`` seconds. They are released when ``release()`` is
    called or when the underlying connection is closed.
    """

    def __init__(self, lock_id: int, timeout: float | None) -> None:
        # 64-bit signed key for pg_advisory_lock (computed by
        # PostgresCacheBackend._lock_id_for).
        self._lock_id = lock_id
        self._timeout = timeout
        # Dedicated connection; held open only while the lock is held.
        self._conn: Connection | None = None
        self._acquired = False

    def acquire(
        self,
        blocking: bool = True,
        blocking_timeout: float | None = None,
    ) -> bool:
        """Try to take the advisory lock; returns True on success.

        On failure, the dedicated connection is closed so no DB
        resources stay pinned.
        """
        from onyx.db.engine.sql_engine import get_sqlalchemy_engine

        self._conn = get_sqlalchemy_engine().connect()

        if not blocking:
            if self._try_lock():
                return True
            self._conn.close()
            self._conn = None
            return False

        # NOTE(review): ``self._timeout`` doubles here as a max *wait*
        # time, whereas the Redis lock uses ``timeout`` as an auto-release
        # TTL — confirm callers don't depend on the Redis semantics.
        # Also note ``or`` treats blocking_timeout=0 as unset.
        effective_timeout = blocking_timeout or self._timeout
        deadline = (time.monotonic() + effective_timeout) if effective_timeout else None
        while True:
            if self._try_lock():
                return True
            if deadline is not None and time.monotonic() >= deadline:
                self._conn.close()
                self._conn = None
                return False
            # pg_try_advisory_lock has no bounded-wait variant, so poll.
            time.sleep(0.1)

    def release(self) -> None:
        # No-op unless we actually hold the lock.
        if not self._acquired or self._conn is None:
            return
        try:
            self._conn.execute(
                text("SELECT pg_advisory_unlock(:id)"), {"id": self._lock_id}
            )
        finally:
            # Closing the connection would free the advisory lock anyway;
            # state is reset even if the explicit unlock call fails.
            self._acquired = False
            self._conn.close()
            self._conn = None

    def owned(self) -> bool:
        return self._acquired

    def _try_lock(self) -> bool:
        # Non-blocking attempt; pg_try_advisory_lock returns a boolean
        # immediately instead of waiting.
        assert self._conn is not None
        result = self._conn.execute(
            text("SELECT pg_try_advisory_lock(:id)"), {"id": self._lock_id}
        ).scalar()
        if result:
            self._acquired = True
            return True
        return False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Backend
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class PostgresCacheBackend(CacheBackend):
    """``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL.

    Each operation opens and closes its own database session so the backend
    is safe to share across threads. Tenant isolation is handled by
    SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``).
    """

    def __init__(self, tenant_id: str) -> None:
        # Tenant whose schema every query is routed to.
        self._tenant_id = tenant_id

    # -- basic key/value ---------------------------------------------------

    def get(self, key: str) -> bytes | None:
        """Return the value for *key*, or None when missing or expired."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        # Expired rows are filtered out at read time rather than deleted;
        # physical cleanup happens in cleanup_expired_cache_entries().
        stmt = select(CacheStore.value).where(
            CacheStore.key == key,
            or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()),
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            value = session.execute(stmt).scalar_one_or_none()
        if value is None:
            return None
        return bytes(value)

    def set(
        self,
        key: str,
        value: str | bytes | int | float,
        ex: int | None = None,
    ) -> None:
        """Upsert *value* under *key*; *ex* is an optional TTL in seconds."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        value_bytes = _to_bytes(value)
        # NOTE: ``if ex`` means ex=0 stores with no expiry (mirrors falsy check).
        expires_at = datetime.now(timezone.utc) + timedelta(seconds=ex) if ex else None
        # INSERT ... ON CONFLICT gives Redis SET semantics (create-or-replace).
        stmt = (
            pg_insert(CacheStore)
            .values(key=key, value=value_bytes, expires_at=expires_at)
            .on_conflict_do_update(
                index_elements=[CacheStore.key],
                set_={"value": value_bytes, "expires_at": expires_at},
            )
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            session.execute(stmt)
            session.commit()

    def delete(self, key: str) -> None:
        """Remove *key*; silently does nothing when the key is absent."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            session.execute(delete(CacheStore).where(CacheStore.key == key))
            session.commit()

    def exists(self, key: str) -> bool:
        """Return True if *key* is present and not expired."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        stmt = (
            select(CacheStore.key)
            .where(
                CacheStore.key == key,
                or_(
                    CacheStore.expires_at.is_(None),
                    CacheStore.expires_at > func.now(),
                ),
            )
            .limit(1)
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            return session.execute(stmt).first() is not None

    # -- TTL ---------------------------------------------------------------

    def expire(self, key: str, seconds: int) -> None:
        """Set/replace the TTL on *key*; no-op when the key is absent."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds)
        stmt = (
            update(CacheStore).where(CacheStore.key == key).values(expires_at=new_exp)
        )
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            session.execute(stmt)
            session.commit()

    def ttl(self, key: str) -> int:
        """Remaining TTL in seconds: -1 if no expiry, -2 if missing/expired
        (mirrors Redis TTL return codes)."""
        from onyx.db.engine.sql_engine import get_session_with_tenant

        stmt = select(CacheStore.expires_at).where(CacheStore.key == key)
        with get_session_with_tenant(tenant_id=self._tenant_id) as session:
            result = session.execute(stmt).first()
        if result is None:
            return -2
        expires_at: datetime | None = result[0]
        if expires_at is None:
            return -1
        remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
        # An expired-but-not-yet-GC'd row reports as missing, like Redis.
        if remaining <= 0:
            return -2
        return int(remaining)

    # -- distributed lock --------------------------------------------------

    def lock(self, name: str, timeout: float | None = None) -> CacheLock:
        """Return an advisory-lock-based lock keyed by tenant + *name*."""
        return PostgresCacheLock(self._lock_id_for(name), timeout)

    # -- blocking list (MCP OAuth BLPOP pattern) ---------------------------

    def rpush(self, key: str, value: str | bytes) -> None:
        # List emulation: each pushed item becomes its own row whose key
        # embeds the list name and a nanosecond timestamp (ordering), with
        # a TTL so abandoned items are garbage-collected.
        self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS)

    def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
        """Poll for the oldest queued item on any of *keys*.

        Returns ``(key, value)`` or ``None`` on timeout; ``timeout=0``
        means wait forever (Redis BLPOP semantics).
        """
        from onyx.db.engine.sql_engine import get_session_with_tenant

        deadline = (time.monotonic() + timeout) if timeout > 0 else None
        while True:
            for key in keys:
                # ';' is the character directly after ':' in ASCII, so the
                # half-open range [lower, upper) matches exactly the keys
                # with prefix "_q:<key>:".
                lower = f"{_LIST_KEY_PREFIX}{key}:"
                upper = f"{_LIST_KEY_PREFIX}{key};"
                # Oldest item first (keys embed a timestamp so lexicographic
                # order is chronological); SKIP LOCKED lets concurrent
                # poppers each claim a distinct row.
                stmt = (
                    select(CacheStore)
                    .where(
                        CacheStore.key >= lower,
                        CacheStore.key < upper,
                        or_(
                            CacheStore.expires_at.is_(None),
                            CacheStore.expires_at > func.now(),
                        ),
                    )
                    .order_by(CacheStore.key)
                    .limit(1)
                    .with_for_update(skip_locked=True)
                )
                with get_session_with_tenant(tenant_id=self._tenant_id) as session:
                    row = session.execute(stmt).scalars().first()
                    if row is not None:
                        value = bytes(row.value) if row.value else b""
                        # Pop = delete the claimed row in the same transaction.
                        session.delete(row)
                        session.commit()
                        return (key.encode(), value)
            if deadline is not None and time.monotonic() >= deadline:
                return None
            time.sleep(_BLPOP_POLL_INTERVAL)

    # -- helpers -----------------------------------------------------------

    def _lock_id_for(self, name: str) -> int:
        """Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
        # MD5 is used only as a stable hash here, not for security; the
        # tenant id is mixed in so different tenants never collide on
        # the same lock name.
        h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
        return struct.unpack("q", h[:8])[0]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def cleanup_expired_cache_entries() -> None:
    """Delete rows whose ``expires_at`` is in the past.

    Called by the periodic poller every 5 minutes.
    """
    from onyx.db.engine.sql_engine import get_session_with_current_tenant

    # Rows with NULL expires_at never expire and are left alone.
    expired = delete(CacheStore).where(
        CacheStore.expires_at.is_not(None),
        CacheStore.expires_at < func.now(),
    )
    with get_session_with_current_tenant() as session:
        session.execute(expired)
        session.commit()
|
||||
92
backend/onyx/cache/redis_backend.py
vendored
92
backend/onyx/cache/redis_backend.py
vendored
@@ -1,92 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from redis.client import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
|
||||
|
||||
class RedisCacheLock(CacheLock):
    """Adapter exposing ``redis.lock.Lock`` through the ``CacheLock`` interface."""

    def __init__(self, lock: RedisLock) -> None:
        self._lock = lock

    def acquire(
        self,
        blocking: bool = True,
        blocking_timeout: float | None = None,
    ) -> bool:
        acquired = self._lock.acquire(
            blocking=blocking,
            blocking_timeout=blocking_timeout,
        )
        return bool(acquired)

    def release(self) -> None:
        self._lock.release()

    def owned(self) -> bool:
        return bool(self._lock.owned())
|
||||
|
||||
|
||||
class RedisCacheBackend(CacheBackend):
    """``CacheBackend`` implementation that delegates to a ``redis.Redis`` client.

    This is a thin pass-through — every method maps 1-to-1 to the underlying
    Redis command. ``TenantRedis`` key-prefixing is handled by the client
    itself (provided by ``get_redis_client``).
    """

    def __init__(self, redis_client: Redis) -> None:
        self._redis = redis_client

    # -- basic key/value ---------------------------------------------------

    def get(self, key: str) -> bytes | None:
        raw = self._redis.get(key)
        if raw is None:
            return None
        return raw if isinstance(raw, bytes) else str(raw).encode()

    def set(
        self,
        key: str,
        value: str | bytes | int | float,
        ex: int | None = None,
    ) -> None:
        self._redis.set(key, value, ex=ex)

    def delete(self, key: str) -> None:
        self._redis.delete(key)

    def exists(self, key: str) -> bool:
        return bool(self._redis.exists(key))

    # -- TTL ---------------------------------------------------------------

    def expire(self, key: str, seconds: int) -> None:
        self._redis.expire(key, seconds)

    def ttl(self, key: str) -> int:
        return cast(int, self._redis.ttl(key))

    # -- distributed lock --------------------------------------------------

    def lock(self, name: str, timeout: float | None = None) -> CacheLock:
        return RedisCacheLock(self._redis.lock(name, timeout=timeout))

    # -- blocking list (MCP OAuth BLPOP pattern) ---------------------------

    def rpush(self, key: str, value: str | bytes) -> None:
        self._redis.rpush(key, value)

    def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
        popped = cast(list[bytes] | None, self._redis.blpop(keys, timeout=timeout))
        return None if popped is None else (popped[0], popped[1])
|
||||
@@ -6,7 +6,6 @@ from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
@@ -55,12 +54,6 @@ DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() ==
|
||||
# are disabled but core chat, tools, user file uploads, and Projects still work.
|
||||
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
|
||||
|
||||
# Which backend to use for caching, locks, and ephemeral state.
|
||||
# "redis" (default) or "postgres" (only valid when DISABLE_VECTOR_DB=true).
|
||||
CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
|
||||
@@ -5000,25 +5000,3 @@ class CodeInterpreterServer(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
|
||||
class CacheStore(Base):
    """Key-value cache table used by ``PostgresCacheBackend``.

    Replaces Redis for simple KV caching, locks, and list operations
    when ``CACHE_BACKEND=postgres`` (NO_VECTOR_DB deployments).

    Intentionally separate from ``KVStore``:
    - Stores raw bytes (LargeBinary) vs JSONB, matching Redis semantics.
    - Has ``expires_at`` for TTL; rows are periodically garbage-collected.
    - Holds ephemeral data (tokens, stop signals, lock state) not
      persistent application config, so cleanup can be aggressive.
    """

    __tablename__ = "cache_store"

    # Redis-style cache key; list emulation embeds the queue name and a
    # timestamp into the key itself (see PostgresCacheBackend.rpush).
    key: Mapped[str] = mapped_column(String, primary_key=True)
    # Raw payload bytes (Redis value semantics).
    value: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    # NULL = no TTL; expired rows are filtered on read and GC'd periodically.
    expires_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -11,10 +10,7 @@ from pydantic import ConfigDict
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -109,8 +105,8 @@ def upload_files_to_user_files_with_indexing(
|
||||
user: User,
|
||||
temp_id_map: dict[str, str] | None,
|
||||
db_session: Session,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> CategorizedFilesResult:
|
||||
# Validate project ownership if a project_id is provided
|
||||
if project_id is not None and user is not None:
|
||||
if not check_project_ownership(project_id, user.id, db_session):
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
@@ -131,27 +127,16 @@ def upload_files_to_user_files_with_indexing(
|
||||
logger.warning(
|
||||
f"File {rejected_file.filename} rejected for {rejected_file.reason}"
|
||||
)
|
||||
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
|
||||
background_tasks.add_task(drain_processing_loop, tenant_id)
|
||||
for user_file in user_files:
|
||||
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
|
||||
return CategorizedFilesResult(
|
||||
user_files=user_files,
|
||||
|
||||
@@ -32,14 +32,11 @@ from onyx.auth.schemas import UserUpdate
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import create_onyx_oauth_router
|
||||
from onyx.auth.users import fastapi_users
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.configs.app_configs import APP_HOST
|
||||
from onyx.configs.app_configs import APP_PORT
|
||||
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
@@ -257,53 +254,8 @@ def include_auth_router_with_prefix(
|
||||
)
|
||||
|
||||
|
||||
def validate_cache_backend_settings() -> None:
    """Validate that CACHE_BACKEND=postgres is only used with DISABLE_VECTOR_DB.

    The Postgres cache backend eliminates the Redis dependency, but only works
    when Celery is not running (which requires DISABLE_VECTOR_DB=true).
    """
    postgres_backend_selected = CACHE_BACKEND == CacheBackendType.POSTGRES
    if not postgres_backend_selected or DISABLE_VECTOR_DB:
        return

    raise RuntimeError(
        "CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true. "
        "The Postgres cache backend is only supported in no-vector-DB "
        "deployments where Celery is replaced by the in-process task runner."
    )
|
||||
|
||||
|
||||
def validate_no_vector_db_settings() -> None:
    """Validate that DISABLE_VECTOR_DB is not combined with incompatible settings.

    Raises RuntimeError if DISABLE_VECTOR_DB is set alongside MULTI_TENANT or ENABLE_CRAFT,
    since these modes require infrastructure that is removed in no-vector-DB deployments.
    """
    if not DISABLE_VECTOR_DB:
        # Nothing to validate when the vector DB is enabled.
        return

    if MULTI_TENANT:
        raise RuntimeError(
            "DISABLE_VECTOR_DB cannot be used with MULTI_TENANT. "
            "Multi-tenant deployments require the vector database for "
            "per-tenant document indexing and search. Run in single-tenant "
            "mode when disabling the vector database."
        )

    # Imported lazily, and only after the MULTI_TENANT check, to avoid
    # pulling in the build feature module unnecessarily.
    from onyx.server.features.build.configs import ENABLE_CRAFT

    if ENABLE_CRAFT:
        raise RuntimeError(
            "DISABLE_VECTOR_DB cannot be used with ENABLE_CRAFT. "
            "Onyx Craft requires background workers for sandbox lifecycle "
            "management, which are removed in no-vector-DB deployments. "
            "Disable Craft (ENABLE_CRAFT=false) when disabling the vector database."
        )
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
validate_no_vector_db_settings()
|
||||
validate_cache_backend_settings()
|
||||
|
||||
# Set recursion limit
|
||||
if SYSTEM_RECURSION_LIMIT is not None:
|
||||
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
|
||||
@@ -372,20 +324,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
await setup_auth_limiter()
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.background.periodic_poller import start_periodic_poller
|
||||
|
||||
recover_stuck_user_files(POSTGRES_DEFAULT_SCHEMA)
|
||||
start_periodic_poller(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
yield
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import stop_periodic_poller
|
||||
|
||||
stop_periodic_poller()
|
||||
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Form
|
||||
@@ -13,7 +12,13 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -29,6 +34,7 @@ from onyx.db.models import UserProject
|
||||
from onyx.db.persona import get_personas_by_ids
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import upload_files_to_user_files_with_indexing
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.projects.models import CategorizedFilesSnapshot
|
||||
from onyx.server.features.projects.models import ChatSessionRequest
|
||||
from onyx.server.features.projects.models import TokenCountResponse
|
||||
@@ -49,27 +55,7 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
def _trigger_user_file_project_sync(
|
||||
user_file_id: UUID,
|
||||
tenant_id: str,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> None:
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
background_tasks.add_task(drain_project_sync_loop, tenant_id)
|
||||
logger.info(f"Queued in-process project sync for user_file_id={user_file_id}")
|
||||
return
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
|
||||
queue_depth = get_user_file_project_sync_queue_depth(client_app)
|
||||
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
|
||||
logger.warning(
|
||||
@@ -125,7 +111,6 @@ def create_project(
|
||||
|
||||
@router.post("/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_user_files(
|
||||
bg_tasks: BackgroundTasks,
|
||||
files: list[UploadFile] = File(...),
|
||||
project_id: int | None = Form(None),
|
||||
temp_id_map: str | None = Form(None), # JSON string mapping hashed key -> temp_id
|
||||
@@ -152,12 +137,12 @@ def upload_user_files(
|
||||
user=user,
|
||||
temp_id_map=parsed_temp_id_map,
|
||||
db_session=db_session,
|
||||
background_tasks=bg_tasks if DISABLE_VECTOR_DB else None,
|
||||
)
|
||||
|
||||
return CategorizedFilesSnapshot.from_result(categorized_files_result)
|
||||
|
||||
except Exception as e:
|
||||
# Log error with type, message, and stack for easier debugging
|
||||
logger.exception(f"Error uploading files - {type(e).__name__}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
@@ -207,7 +192,6 @@ def get_files_in_project(
|
||||
def unlink_user_file_from_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
@@ -224,6 +208,7 @@ def unlink_user_file_from_project(
|
||||
if project is None:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
user_id = user.id
|
||||
user_file = (
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
|
||||
@@ -239,7 +224,7 @@ def unlink_user_file_from_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -252,7 +237,6 @@ def unlink_user_file_from_project(
|
||||
def link_user_file_to_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileSnapshot:
|
||||
@@ -284,7 +268,7 @@ def link_user_file_to_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
|
||||
return UserFileSnapshot.from_model(user_file)
|
||||
|
||||
@@ -442,7 +426,6 @@ def delete_project(
|
||||
@router.delete("/file/{file_id}", tags=PUBLIC_API_TAGS)
|
||||
def delete_user_file(
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileDeleteResult:
|
||||
@@ -475,25 +458,15 @@ def delete_user_file(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
|
||||
bg_tasks.add_task(drain_delete_loop, tenant_id)
|
||||
logger.info(f"Queued in-process delete for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
return UserFileDeleteResult(
|
||||
has_associations=False, project_names=[], assistant_names=[]
|
||||
)
|
||||
|
||||
@@ -257,7 +257,7 @@ exceptiongroup==1.3.0
|
||||
# via
|
||||
# braintrust
|
||||
# fastmcp
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# fastapi-limiter
|
||||
# fastapi-users
|
||||
@@ -1155,6 +1155,7 @@ typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -125,7 +125,7 @@ executing==2.2.1
|
||||
# via stack-data
|
||||
faker==40.1.2
|
||||
# via onyx
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
# onyx-devtools
|
||||
@@ -619,6 +619,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -90,7 +90,7 @@ docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via onyx
|
||||
fastavro==1.12.1
|
||||
# via cohere
|
||||
@@ -398,6 +398,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -108,7 +108,7 @@ durationpy==0.10
|
||||
# via kubernetes
|
||||
einops==0.8.1
|
||||
# via onyx
|
||||
fastapi==0.128.0
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
# sentry-sdk
|
||||
@@ -525,6 +525,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
"""External dependency unit tests for startup recovery (Step 10g).
|
||||
|
||||
Seeds ``UserFile`` records in stuck states (PROCESSING, DELETING,
|
||||
needs_project_sync) then calls ``recover_stuck_user_files`` and verifies
|
||||
the drain loops pick them up via ``FOR UPDATE SKIP LOCKED``.
|
||||
|
||||
Uses real PostgreSQL (via ``db_session`` / ``tenant_context`` fixtures).
|
||||
The per-file ``_impl`` functions are mocked so no real file store or
|
||||
connector is needed — we only verify that recovery finds and dispatches
|
||||
the correct files.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IMPL_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
|
||||
|
||||
def _create_user_file(
|
||||
db_session: Session,
|
||||
user_id: object,
|
||||
*,
|
||||
status: UserFileStatus = UserFileStatus.PROCESSING,
|
||||
needs_project_sync: bool = False,
|
||||
needs_persona_sync: bool = False,
|
||||
) -> UserFile:
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=status,
|
||||
needs_project_sync=needs_project_sync,
|
||||
needs_persona_sync=needs_persona_sync,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _cleanup_user_files(db_session: Session) -> Generator[list[UserFile], None, None]:
|
||||
"""Track created UserFile rows and delete them after each test."""
|
||||
created: list[UserFile] = []
|
||||
yield created
|
||||
for uf in created:
|
||||
existing = db_session.get(UserFile, uf.id)
|
||||
if existing:
|
||||
db_session.delete(existing)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecoverProcessingFiles:
|
||||
"""Files in PROCESSING status are re-processed via the processing drain loop."""
|
||||
|
||||
def test_processing_files_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_proc")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}._process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered but got: {called_ids}"
|
||||
|
||||
def test_completed_files_not_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_comp")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.COMPLETED)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}._process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) not in called_ids
|
||||
), f"COMPLETED file {uf.id} should not have been recovered"
|
||||
|
||||
|
||||
class TestRecoverDeletingFiles:
|
||||
"""Files in DELETING status are recovered via the delete drain loop."""
|
||||
|
||||
def test_deleting_files_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_del")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}._delete_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for deletion but got: {called_ids}"
|
||||
|
||||
|
||||
class TestRecoverSyncFiles:
|
||||
"""Files needing project/persona sync are recovered via the sync drain loop."""
|
||||
|
||||
def test_needs_project_sync_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_sync")
|
||||
uf = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_project_sync=True,
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}._project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for sync but got: {called_ids}"
|
||||
|
||||
def test_needs_persona_sync_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_psync")
|
||||
uf = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=True,
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}._project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for persona sync but got: {called_ids}"
|
||||
|
||||
|
||||
class TestRecoveryMultipleFiles:
|
||||
"""Recovery processes all stuck files in one pass, not just the first."""
|
||||
|
||||
def test_multiple_processing_files(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_multi")
|
||||
files = []
|
||||
for _ in range(3):
|
||||
uf = _create_user_file(
|
||||
db_session, user.id, status=UserFileStatus.PROCESSING
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}._process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = {call.kwargs["user_file_id"] for call in mock_impl.call_args_list}
|
||||
expected_ids = {str(uf.id) for uf in files}
|
||||
assert expected_ids.issubset(called_ids), (
|
||||
f"Expected all {len(files)} files to be recovered. "
|
||||
f"Missing: {expected_ids - called_ids}"
|
||||
)
|
||||
@@ -1,64 +0,0 @@
|
||||
"""Fixtures for cache backend tests.
|
||||
|
||||
Requires a running PostgreSQL instance (and Redis for parity tests).
|
||||
Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/cache/
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
from onyx.cache.redis_backend import RedisCacheBackend
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.models import CacheStore
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _init_db() -> Generator[None, None, None]:
|
||||
SqlEngine.init_engine(pool_size=5, max_overflow=2)
|
||||
CacheStore.__table__.create(get_sqlalchemy_engine(), checkfirst=True)
|
||||
yield
|
||||
with get_session_with_tenant(tenant_id=TEST_TENANT_ID) as session:
|
||||
session.execute(delete(CacheStore))
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _tenant_context() -> Generator[None, None, None]:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pg_cache() -> PostgresCacheBackend:
|
||||
return PostgresCacheBackend(TEST_TENANT_ID)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_cache() -> RedisCacheBackend:
|
||||
from onyx.redis.redis_pool import redis_pool
|
||||
|
||||
return RedisCacheBackend(redis_pool.get_client(TEST_TENANT_ID))
|
||||
|
||||
|
||||
@pytest.fixture(params=["postgres", "redis"], ids=["postgres", "redis"])
|
||||
def cache(
|
||||
request: pytest.FixtureRequest,
|
||||
pg_cache: PostgresCacheBackend,
|
||||
redis_cache: RedisCacheBackend,
|
||||
) -> CacheBackend:
|
||||
if request.param == "postgres":
|
||||
return pg_cache
|
||||
return redis_cache
|
||||
@@ -1,98 +0,0 @@
|
||||
"""Parameterized tests that run the same CacheBackend operations against
|
||||
both Redis and PostgreSQL, asserting identical return values.
|
||||
|
||||
Each test runs twice (once per backend) via the ``cache`` fixture defined
|
||||
in conftest.py.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
|
||||
|
||||
def _key() -> str:
|
||||
return f"parity_{uuid4().hex[:12]}"
|
||||
|
||||
|
||||
class TestKVParity:
|
||||
def test_get_missing(self, cache: CacheBackend) -> None:
|
||||
assert cache.get(_key()) is None
|
||||
|
||||
def test_get_set(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"value")
|
||||
assert cache.get(k) == b"value"
|
||||
|
||||
def test_overwrite(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"a")
|
||||
cache.set(k, b"b")
|
||||
assert cache.get(k) == b"b"
|
||||
|
||||
def test_set_string(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, "hello")
|
||||
assert cache.get(k) == b"hello"
|
||||
|
||||
def test_set_int(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, 42)
|
||||
assert cache.get(k) == b"42"
|
||||
|
||||
def test_delete(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x")
|
||||
cache.delete(k)
|
||||
assert cache.get(k) is None
|
||||
|
||||
def test_exists(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
assert not cache.exists(k)
|
||||
cache.set(k, b"x")
|
||||
assert cache.exists(k)
|
||||
|
||||
|
||||
class TestTTLParity:
|
||||
def test_ttl_missing(self, cache: CacheBackend) -> None:
|
||||
assert cache.ttl(_key()) == -2
|
||||
|
||||
def test_ttl_no_expiry(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x")
|
||||
assert cache.ttl(k) == -1
|
||||
|
||||
def test_ttl_remaining(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x", ex=10)
|
||||
remaining = cache.ttl(k)
|
||||
assert 8 <= remaining <= 10
|
||||
|
||||
def test_set_with_ttl_expires(self, cache: CacheBackend) -> None:
|
||||
k = _key()
|
||||
cache.set(k, b"x", ex=1)
|
||||
assert cache.get(k) == b"x"
|
||||
time.sleep(1.5)
|
||||
assert cache.get(k) is None
|
||||
|
||||
|
||||
class TestLockParity:
|
||||
def test_acquire_release(self, cache: CacheBackend) -> None:
|
||||
lock = cache.lock(f"parity_lock_{uuid4().hex[:8]}")
|
||||
assert lock.acquire(blocking=False)
|
||||
assert lock.owned()
|
||||
lock.release()
|
||||
assert not lock.owned()
|
||||
|
||||
|
||||
class TestListParity:
|
||||
def test_rpush_blpop(self, cache: CacheBackend) -> None:
|
||||
k = f"parity_list_{uuid4().hex[:8]}"
|
||||
cache.rpush(k, b"item")
|
||||
result = cache.blpop([k], timeout=1)
|
||||
assert result is not None
|
||||
assert result[1] == b"item"
|
||||
|
||||
def test_blpop_timeout(self, cache: CacheBackend) -> None:
|
||||
result = cache.blpop([f"parity_empty_{uuid4().hex[:8]}"], timeout=1)
|
||||
assert result is None
|
||||
@@ -1,218 +0,0 @@
|
||||
"""Tests for PostgresCacheBackend against real PostgreSQL.
|
||||
|
||||
Covers every method on the backend: KV CRUD, TTL behaviour, advisory
|
||||
locks (acquire / release / contention), list operations (rpush / blpop),
|
||||
and the periodic cleanup function.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.cache.postgres_backend import cleanup_expired_cache_entries
|
||||
from onyx.cache.postgres_backend import PostgresCacheBackend
|
||||
|
||||
|
||||
def _key() -> str:
|
||||
return f"test_{uuid4().hex[:12]}"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Basic KV
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKV:
|
||||
def test_get_set(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"hello")
|
||||
assert pg_cache.get(k) == b"hello"
|
||||
|
||||
def test_get_missing(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
assert pg_cache.get(_key()) is None
|
||||
|
||||
def test_set_overwrite(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"first")
|
||||
pg_cache.set(k, b"second")
|
||||
assert pg_cache.get(k) == b"second"
|
||||
|
||||
def test_set_string_value(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, "string_val")
|
||||
assert pg_cache.get(k) == b"string_val"
|
||||
|
||||
def test_set_int_value(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, 42)
|
||||
assert pg_cache.get(k) == b"42"
|
||||
|
||||
def test_delete(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"to_delete")
|
||||
pg_cache.delete(k)
|
||||
assert pg_cache.get(k) is None
|
||||
|
||||
def test_delete_missing_is_noop(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
pg_cache.delete(_key())
|
||||
|
||||
def test_exists(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
assert not pg_cache.exists(k)
|
||||
pg_cache.set(k, b"x")
|
||||
assert pg_cache.exists(k)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TTL
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTTL:
|
||||
def test_set_with_ttl_expires(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"ephemeral", ex=1)
|
||||
assert pg_cache.get(k) == b"ephemeral"
|
||||
time.sleep(1.5)
|
||||
assert pg_cache.get(k) is None
|
||||
|
||||
def test_ttl_no_expiry(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"forever")
|
||||
assert pg_cache.ttl(k) == -1
|
||||
|
||||
def test_ttl_missing_key(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
assert pg_cache.ttl(_key()) == -2
|
||||
|
||||
def test_ttl_remaining(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=10)
|
||||
remaining = pg_cache.ttl(k)
|
||||
assert 8 <= remaining <= 10
|
||||
|
||||
def test_ttl_expired_key(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=1)
|
||||
time.sleep(1.5)
|
||||
assert pg_cache.ttl(k) == -2
|
||||
|
||||
def test_expire_adds_ttl(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x")
|
||||
assert pg_cache.ttl(k) == -1
|
||||
pg_cache.expire(k, 10)
|
||||
assert 8 <= pg_cache.ttl(k) <= 10
|
||||
|
||||
def test_exists_respects_ttl(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"x", ex=1)
|
||||
assert pg_cache.exists(k)
|
||||
time.sleep(1.5)
|
||||
assert not pg_cache.exists(k)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Locks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLock:
|
||||
def test_acquire_release(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
lock = pg_cache.lock(f"lock_{uuid4().hex[:8]}")
|
||||
assert lock.acquire(blocking=False)
|
||||
assert lock.owned()
|
||||
lock.release()
|
||||
assert not lock.owned()
|
||||
|
||||
def test_contention(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
name = f"contention_{uuid4().hex[:8]}"
|
||||
lock1 = pg_cache.lock(name)
|
||||
lock2 = pg_cache.lock(name)
|
||||
|
||||
assert lock1.acquire(blocking=False)
|
||||
assert not lock2.acquire(blocking=False)
|
||||
|
||||
lock1.release()
|
||||
assert lock2.acquire(blocking=False)
|
||||
lock2.release()
|
||||
|
||||
def test_context_manager(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
with pg_cache.lock(f"ctx_{uuid4().hex[:8]}") as lock:
|
||||
assert lock.owned()
|
||||
assert not lock.owned()
|
||||
|
||||
def test_blocking_timeout(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
name = f"timeout_{uuid4().hex[:8]}"
|
||||
holder = pg_cache.lock(name)
|
||||
holder.acquire(blocking=False)
|
||||
|
||||
waiter = pg_cache.lock(name, timeout=0.3)
|
||||
start = time.monotonic()
|
||||
assert not waiter.acquire(blocking=True, blocking_timeout=0.3)
|
||||
elapsed = time.monotonic() - start
|
||||
assert elapsed >= 0.25
|
||||
|
||||
holder.release()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# List (rpush / blpop)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestList:
|
||||
def test_rpush_blpop(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = f"list_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k, b"item1")
|
||||
result = pg_cache.blpop([k], timeout=1)
|
||||
assert result is not None
|
||||
assert result == (k.encode(), b"item1")
|
||||
|
||||
def test_blpop_timeout(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
result = pg_cache.blpop([f"empty_{uuid4().hex[:8]}"], timeout=1)
|
||||
assert result is None
|
||||
|
||||
def test_fifo_order(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = f"fifo_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k, b"first")
|
||||
time.sleep(0.01)
|
||||
pg_cache.rpush(k, b"second")
|
||||
|
||||
r1 = pg_cache.blpop([k], timeout=1)
|
||||
r2 = pg_cache.blpop([k], timeout=1)
|
||||
assert r1 is not None and r1[1] == b"first"
|
||||
assert r2 is not None and r2[1] == b"second"
|
||||
|
||||
def test_multiple_keys(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k1 = f"mk1_{uuid4().hex[:8]}"
|
||||
k2 = f"mk2_{uuid4().hex[:8]}"
|
||||
pg_cache.rpush(k2, b"from_k2")
|
||||
|
||||
result = pg_cache.blpop([k1, k2], timeout=1)
|
||||
assert result is not None
|
||||
assert result == (k2.encode(), b"from_k2")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanup:
|
||||
def test_removes_expired_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"stale", ex=1)
|
||||
time.sleep(1.5)
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.ttl(k) == -2
|
||||
|
||||
def test_preserves_unexpired_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"fresh", ex=300)
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.get(k) == b"fresh"
|
||||
|
||||
def test_preserves_no_ttl_rows(self, pg_cache: PostgresCacheBackend) -> None:
|
||||
k = _key()
|
||||
pg_cache.set(k, b"permanent")
|
||||
cleanup_expired_cache_entries()
|
||||
assert pg_cache.get(k) == b"permanent"
|
||||
@@ -1,160 +0,0 @@
|
||||
"""Integration test for the full user-file lifecycle in no-vector-DB mode.
|
||||
|
||||
Covers: upload → COMPLETED → unlink from project → delete → gone.
|
||||
|
||||
The entire lifecycle is handled by FastAPI BackgroundTasks (no Celery workers
|
||||
needed). The conftest-level ``pytestmark`` ensures these tests are skipped
|
||||
when the server is running with vector DB enabled.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.project import ProjectManager
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
POLL_INTERVAL_SECONDS = 1
|
||||
POLL_TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
def _poll_file_status(
|
||||
file_id: UUID,
|
||||
user: DATestUser,
|
||||
target_status: UserFileStatus,
|
||||
timeout: int = POLL_TIMEOUT_SECONDS,
|
||||
) -> None:
|
||||
"""Poll GET /user/projects/file/{file_id} until the file reaches *target_status*."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
resp = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=user.headers,
|
||||
)
|
||||
if resp.ok:
|
||||
status = resp.json().get("status")
|
||||
if status == target_status.value:
|
||||
return
|
||||
time.sleep(POLL_INTERVAL_SECONDS)
|
||||
raise TimeoutError(
|
||||
f"File {file_id} did not reach {target_status.value} within {timeout}s"
|
||||
)
|
||||
|
||||
|
||||
def _file_is_gone(file_id: UUID, user: DATestUser, timeout: int = 15) -> None:
|
||||
"""Poll until GET /user/projects/file/{file_id} returns 404."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
resp = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=user.headers,
|
||||
)
|
||||
if resp.status_code == 404:
|
||||
return
|
||||
time.sleep(POLL_INTERVAL_SECONDS)
|
||||
raise TimeoutError(
|
||||
f"File {file_id} still accessible after {timeout}s (expected 404)"
|
||||
)
|
||||
|
||||
|
||||
def test_file_upload_process_delete_lifecycle(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Full lifecycle: upload → COMPLETED → unlink → delete → 404.
|
||||
|
||||
Validates that the API server handles all background processing
|
||||
(via FastAPI BackgroundTasks) without any Celery workers running.
|
||||
"""
|
||||
project = ProjectManager.create(
|
||||
name="lifecycle-test", user_performing_action=admin_user
|
||||
)
|
||||
|
||||
file_content = b"Integration test file content for lifecycle verification."
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("lifecycle.txt", file_content)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert upload_result.user_files, "Expected at least one file in upload response"
|
||||
|
||||
user_file = upload_result.user_files[0]
|
||||
file_id = user_file.id
|
||||
|
||||
_poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED)
|
||||
|
||||
project_files = ProjectManager.get_project_files(project.id, admin_user)
|
||||
assert any(
|
||||
f.id == file_id for f in project_files
|
||||
), "File should be listed in project files after processing"
|
||||
|
||||
# Unlink the file from the project so the delete endpoint will proceed
|
||||
unlink_resp = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/{project.id}/files/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert (
|
||||
unlink_resp.status_code == 204
|
||||
), f"Expected 204 on unlink, got {unlink_resp.status_code}: {unlink_resp.text}"
|
||||
|
||||
delete_resp = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert (
|
||||
delete_resp.ok
|
||||
), f"Delete request failed: {delete_resp.status_code} {delete_resp.text}"
|
||||
body = delete_resp.json()
|
||||
assert (
|
||||
body["has_associations"] is False
|
||||
), f"File still has associations after unlink: {body}"
|
||||
|
||||
_file_is_gone(file_id, admin_user)
|
||||
|
||||
project_files_after = ProjectManager.get_project_files(project.id, admin_user)
|
||||
assert not any(
|
||||
f.id == file_id for f in project_files_after
|
||||
), "Deleted file should not appear in project files"
|
||||
|
||||
|
||||
def test_delete_blocked_while_associated(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deleting a file that still belongs to a project should return
|
||||
has_associations=True without actually deleting the file."""
|
||||
project = ProjectManager.create(
|
||||
name="assoc-test", user_performing_action=admin_user
|
||||
)
|
||||
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("assoc.txt", b"associated file content")],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
file_id = upload_result.user_files[0].id
|
||||
|
||||
_poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED)
|
||||
|
||||
# Attempt to delete while still linked
|
||||
delete_resp = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_resp.ok
|
||||
body = delete_resp.json()
|
||||
assert body["has_associations"] is True, "Should report existing associations"
|
||||
assert project.name in body["project_names"]
|
||||
|
||||
# File should still be accessible
|
||||
get_resp = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert get_resp.status_code == 200, "File should still exist after blocked delete"
|
||||
@@ -1,291 +0,0 @@
|
||||
"""Tests for the _impl functions' redis_locking parameter.
|
||||
|
||||
Verifies that:
|
||||
- redis_locking=True acquires/releases Redis locks and clears queued keys
|
||||
- redis_locking=False skips all Redis operations entirely
|
||||
- Both paths execute the same business logic (DB lookup, status check)
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_delete_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_process_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_project_sync_user_file_impl,
|
||||
)
|
||||
|
||||
TASKS_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
|
||||
|
||||
def _mock_session_returning_none() -> MagicMock:
|
||||
"""Return a mock session whose .get() returns None (file not found)."""
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
return session
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _process_user_file_impl
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_acquires_and_releases_lock(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
user_file_id = str(uuid4())
|
||||
_process_user_file_impl(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_called_once_with(tenant_id="test-tenant")
|
||||
redis_client.delete.assert_called_once()
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_skips_when_lock_held(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
_process_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
lock.acquire.assert_called_once()
|
||||
mock_get_session.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_false_skips_redis_entirely(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
_process_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_not_called()
|
||||
mock_get_session.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_both_paths_call_db_get(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
"""Both redis_locking=True and False should call db_session.get(UserFile, ...)."""
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
uid = str(uuid4())
|
||||
|
||||
_process_user_file_impl(user_file_id=uid, tenant_id="t", redis_locking=True)
|
||||
call_count_true = session.get.call_count
|
||||
|
||||
session.reset_mock()
|
||||
mock_get_session.reset_mock()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
_process_user_file_impl(user_file_id=uid, tenant_id="t", redis_locking=False)
|
||||
call_count_false = session.get.call_count
|
||||
|
||||
assert call_count_true == call_count_false == 1
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _delete_user_file_impl
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_acquires_and_releases_lock(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
_delete_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_called_once()
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_skips_when_lock_held(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
_delete_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
lock.acquire.assert_called_once()
|
||||
mock_get_session.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_false_skips_redis_entirely(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
_delete_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_not_called()
|
||||
mock_get_session.assert_called_once()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _project_sync_user_file_impl
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProjectSyncUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_acquires_and_releases_lock(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
_project_sync_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_called_once()
|
||||
redis_client.delete.assert_called_once()
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_skips_when_lock_held(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
_project_sync_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
lock.acquire.assert_called_once()
|
||||
mock_get_session.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_false_skips_redis_entirely(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
_project_sync_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_not_called()
|
||||
mock_get_session.assert_called_once()
|
||||
@@ -1,421 +0,0 @@
|
||||
"""Tests for no-vector-DB user file processing paths.
|
||||
|
||||
Verifies that when DISABLE_VECTOR_DB is True:
|
||||
- _process_user_file_impl calls _process_user_file_without_vector_db (not indexing)
|
||||
- _process_user_file_without_vector_db extracts text, counts tokens, stores plaintext,
|
||||
sets status=COMPLETED and chunk_count=0
|
||||
- _delete_user_file_impl skips vector DB chunk deletion
|
||||
- _project_sync_user_file_impl skips vector DB metadata update
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_delete_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_process_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_process_user_file_without_vector_db,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_project_sync_user_file_impl,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.enums import UserFileStatus
|
||||
|
||||
TASKS_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
LLM_FACTORY_MODULE = "onyx.llm.factory"
|
||||
|
||||
|
||||
def _make_documents(texts: list[str]) -> list[Document]:
|
||||
"""Build a list of Document objects with the given section texts."""
|
||||
return [
|
||||
Document(
|
||||
id=str(uuid4()),
|
||||
source=DocumentSource.USER_FILE,
|
||||
sections=[TextSection(text=t)],
|
||||
semantic_identifier=f"test-doc-{i}",
|
||||
metadata={},
|
||||
)
|
||||
for i, t in enumerate(texts)
|
||||
]
|
||||
|
||||
|
||||
def _make_user_file(
|
||||
*,
|
||||
status: UserFileStatus = UserFileStatus.PROCESSING,
|
||||
file_id: str = "test-file-id",
|
||||
name: str = "test.txt",
|
||||
) -> MagicMock:
|
||||
"""Return a MagicMock mimicking a UserFile ORM instance."""
|
||||
uf = MagicMock()
|
||||
uf.id = uuid4()
|
||||
uf.file_id = file_id
|
||||
uf.name = name
|
||||
uf.status = status
|
||||
uf.token_count = None
|
||||
uf.chunk_count = None
|
||||
uf.last_project_sync_at = None
|
||||
uf.projects = []
|
||||
uf.assistants = []
|
||||
uf.needs_project_sync = True
|
||||
uf.needs_persona_sync = True
|
||||
return uf
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _process_user_file_without_vector_db — direct tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessUserFileWithoutVectorDb:
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_extracts_and_combines_text(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock,
|
||||
) -> None:
|
||||
mock_encode = MagicMock(return_value=[1, 2, 3, 4, 5])
|
||||
mock_get_encode.return_value = mock_encode
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["hello world", "foo bar"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
stored_text = mock_store_plaintext.call_args.kwargs["plaintext_content"]
|
||||
assert "hello world" in stored_text
|
||||
assert "foo bar" in stored_text
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_computes_token_count(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_encode = MagicMock(return_value=list(range(42)))
|
||||
mock_get_encode.return_value = mock_encode
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["some text content"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.token_count == 42
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_token_count_falls_back_to_none_on_error(
|
||||
self,
|
||||
mock_get_llm: MagicMock,
|
||||
mock_get_encode: MagicMock, # noqa: ARG002
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_get_llm.side_effect = RuntimeError("No LLM configured")
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["text"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.token_count is None
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_stores_plaintext(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock,
|
||||
) -> None:
|
||||
mock_get_encode.return_value = MagicMock(return_value=[1])
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["content to store"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
mock_store_plaintext.assert_called_once_with(
|
||||
user_file_id=uf.id,
|
||||
plaintext_content="content to store",
|
||||
)
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_sets_completed_status_and_zero_chunk_count(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_get_encode.return_value = MagicMock(return_value=[1])
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["text"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.status == UserFileStatus.COMPLETED
|
||||
assert uf.chunk_count == 0
|
||||
assert uf.last_project_sync_at is not None
|
||||
db_session.add.assert_called_once_with(uf)
|
||||
db_session.commit.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_preserves_deleting_status(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_get_encode.return_value = MagicMock(return_value=[1])
|
||||
|
||||
uf = _make_user_file(status=UserFileStatus.DELETING)
|
||||
docs = _make_documents(["text"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.status == UserFileStatus.DELETING
|
||||
assert uf.chunk_count == 0
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _process_user_file_impl — branching on DISABLE_VECTOR_DB
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessImplBranching:
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_without_vector_db")
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_with_indexing")
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_calls_without_vector_db_when_disabled(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_with_indexing: MagicMock,
|
||||
mock_without_vdb: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file()
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
connector_mock = MagicMock()
|
||||
connector_mock.load_from_state.return_value = [_make_documents(["hello"])]
|
||||
|
||||
with patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock):
|
||||
_process_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_without_vdb.assert_called_once()
|
||||
mock_with_indexing.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_without_vector_db")
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_with_indexing")
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_calls_with_indexing_when_vector_db_enabled(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_with_indexing: MagicMock,
|
||||
mock_without_vdb: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file()
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
connector_mock = MagicMock()
|
||||
connector_mock.load_from_state.return_value = [_make_documents(["hello"])]
|
||||
|
||||
with patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock):
|
||||
_process_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_with_indexing.assert_called_once()
|
||||
mock_without_vdb.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.run_indexing_pipeline")
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_indexing_pipeline_not_called_when_disabled(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
mock_run_pipeline: MagicMock,
|
||||
) -> None:
|
||||
"""End-to-end: verify run_indexing_pipeline is never invoked."""
|
||||
uf = _make_user_file()
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
connector_mock = MagicMock()
|
||||
connector_mock.load_from_state.return_value = [_make_documents(["content"])]
|
||||
|
||||
with (
|
||||
patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock),
|
||||
patch(f"{LLM_FACTORY_MODULE}.get_default_llm"),
|
||||
patch(
|
||||
f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func",
|
||||
return_value=MagicMock(return_value=[1, 2, 3]),
|
||||
),
|
||||
):
|
||||
_process_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_run_pipeline.assert_not_called()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _delete_user_file_impl — vector DB skip
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteImplNoVectorDb:
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_default_file_store")
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_skips_vector_db_deletion(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_file_store: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.DELETING)
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
mock_get_file_store.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
patch(f"{TASKS_MODULE}.get_all_document_indices") as mock_get_indices,
|
||||
patch(f"{TASKS_MODULE}.get_active_search_settings") as mock_get_ss,
|
||||
patch(f"{TASKS_MODULE}.httpx_init_vespa_pool") as mock_vespa_pool,
|
||||
):
|
||||
_delete_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_indices.assert_not_called()
|
||||
mock_get_ss.assert_not_called()
|
||||
mock_vespa_pool.assert_not_called()
|
||||
|
||||
session.delete.assert_called_once_with(uf)
|
||||
session.commit.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_default_file_store")
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_still_deletes_file_store_and_db_record(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_file_store: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.DELETING)
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
file_store = MagicMock()
|
||||
mock_get_file_store.return_value = file_store
|
||||
|
||||
_delete_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
assert file_store.delete_file.call_count == 2
|
||||
session.delete.assert_called_once_with(uf)
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _project_sync_user_file_impl — vector DB skip
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProjectSyncImplNoVectorDb:
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_skips_vector_db_update(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.COMPLETED)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
with (
|
||||
patch(f"{TASKS_MODULE}.get_all_document_indices") as mock_get_indices,
|
||||
patch(f"{TASKS_MODULE}.get_active_search_settings") as mock_get_ss,
|
||||
patch(f"{TASKS_MODULE}.httpx_init_vespa_pool") as mock_vespa_pool,
|
||||
):
|
||||
_project_sync_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_indices.assert_not_called()
|
||||
mock_get_ss.assert_not_called()
|
||||
mock_vespa_pool.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_still_clears_sync_flags(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.COMPLETED)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
_project_sync_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
assert uf.needs_project_sync is False
|
||||
assert uf.needs_persona_sync is False
|
||||
assert uf.last_project_sync_at is not None
|
||||
session.add.assert_called_once_with(uf)
|
||||
session.commit.assert_called_once()
|
||||
@@ -1,52 +0,0 @@
|
||||
"""Tests for startup validation in no-vector-DB mode.
|
||||
|
||||
Verifies that DISABLE_VECTOR_DB raises RuntimeError when combined with
|
||||
incompatible settings (MULTI_TENANT, ENABLE_CRAFT).
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestValidateNoVectorDbSettings:
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", False)
|
||||
def test_no_error_when_vector_db_enabled(self) -> None:
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", True)
|
||||
@patch("onyx.main.MULTI_TENANT", False)
|
||||
@patch("onyx.server.features.build.configs.ENABLE_CRAFT", False)
|
||||
def test_no_error_when_no_conflicts(self) -> None:
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", True)
|
||||
@patch("onyx.main.MULTI_TENANT", True)
|
||||
def test_raises_on_multi_tenant(self) -> None:
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
with pytest.raises(RuntimeError, match="MULTI_TENANT"):
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", True)
|
||||
@patch("onyx.main.MULTI_TENANT", False)
|
||||
@patch("onyx.server.features.build.configs.ENABLE_CRAFT", True)
|
||||
def test_raises_on_enable_craft(self) -> None:
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
with pytest.raises(RuntimeError, match="ENABLE_CRAFT"):
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", True)
|
||||
@patch("onyx.main.MULTI_TENANT", True)
|
||||
@patch("onyx.server.features.build.configs.ENABLE_CRAFT", True)
|
||||
def test_multi_tenant_checked_before_craft(self) -> None:
|
||||
"""MULTI_TENANT is checked first, so it should be the error raised."""
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
with pytest.raises(RuntimeError, match="MULTI_TENANT"):
|
||||
validate_no_vector_db_settings()
|
||||
@@ -1,196 +0,0 @@
|
||||
"""Tests for tool construction when DISABLE_VECTOR_DB is True.
|
||||
|
||||
Verifies that:
|
||||
- SearchTool.is_available() returns False when vector DB is disabled
|
||||
- OpenURLTool.is_available() returns False when vector DB is disabled
|
||||
- The force-add SearchTool block is suppressed when DISABLE_VECTOR_DB
|
||||
- FileReaderTool.is_available() returns True when vector DB is disabled
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FileReaderTool
|
||||
|
||||
APP_CONFIGS_MODULE = "onyx.configs.app_configs"
|
||||
FILE_READER_MODULE = "onyx.tools.tool_implementations.file_reader.file_reader_tool"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SearchTool.is_available()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearchToolAvailability:
|
||||
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
def test_unavailable_when_vector_db_disabled(self) -> None:
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
assert SearchTool.is_available(MagicMock()) is False
|
||||
|
||||
@patch("onyx.db.connector.check_user_files_exist", return_value=True)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.check_federated_connectors_exist",
|
||||
return_value=False,
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.check_connectors_exist",
|
||||
return_value=False,
|
||||
)
|
||||
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
def test_available_when_vector_db_enabled_and_files_exist(
|
||||
self,
|
||||
mock_connectors: MagicMock, # noqa: ARG002
|
||||
mock_federated: MagicMock, # noqa: ARG002
|
||||
mock_user_files: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
assert SearchTool.is_available(MagicMock()) is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# OpenURLTool.is_available()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpenURLToolAvailability:
|
||||
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
def test_unavailable_when_vector_db_disabled(self) -> None:
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
|
||||
assert OpenURLTool.is_available(MagicMock()) is False
|
||||
|
||||
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
def test_available_when_vector_db_enabled(self) -> None:
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
|
||||
assert OpenURLTool.is_available(MagicMock()) is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# FileReaderTool.is_available()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFileReaderToolAvailability:
|
||||
@patch(f"{FILE_READER_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
def test_available_when_vector_db_disabled(self) -> None:
|
||||
assert FileReaderTool.is_available(MagicMock()) is True
|
||||
|
||||
@patch(f"{FILE_READER_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
def test_unavailable_when_vector_db_enabled(self) -> None:
|
||||
assert FileReaderTool.is_available(MagicMock()) is False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Force-add SearchTool suppression
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestForceAddSearchToolGuard:
|
||||
def test_force_add_block_checks_disable_vector_db(self) -> None:
|
||||
"""The force-add SearchTool block in construct_tools should include
|
||||
`not DISABLE_VECTOR_DB` so that forced search is also suppressed
|
||||
without a vector DB."""
|
||||
import inspect
|
||||
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
|
||||
source = inspect.getsource(construct_tools)
|
||||
assert "DISABLE_VECTOR_DB" in source, (
|
||||
"construct_tools should reference DISABLE_VECTOR_DB "
|
||||
"to suppress force-adding SearchTool"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Persona API — _validate_vector_db_knowledge
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateVectorDbKnowledge:
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
True,
|
||||
)
|
||||
def test_rejects_document_set_ids(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = [1]
|
||||
request.hierarchy_node_ids = []
|
||||
request.document_ids = []
|
||||
|
||||
with __import__("pytest").raises(HTTPException) as exc_info:
|
||||
_validate_vector_db_knowledge(request)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "document sets" in exc_info.value.detail
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
True,
|
||||
)
|
||||
def test_rejects_hierarchy_node_ids(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = []
|
||||
request.hierarchy_node_ids = [1]
|
||||
request.document_ids = []
|
||||
|
||||
with __import__("pytest").raises(HTTPException) as exc_info:
|
||||
_validate_vector_db_knowledge(request)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "hierarchy nodes" in exc_info.value.detail
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
True,
|
||||
)
|
||||
def test_rejects_document_ids(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = []
|
||||
request.hierarchy_node_ids = []
|
||||
request.document_ids = ["doc-abc"]
|
||||
|
||||
with __import__("pytest").raises(HTTPException) as exc_info:
|
||||
_validate_vector_db_knowledge(request)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "documents" in exc_info.value.detail
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
True,
|
||||
)
|
||||
def test_allows_user_files_only(self) -> None:
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = []
|
||||
request.hierarchy_node_ids = []
|
||||
request.document_ids = []
|
||||
|
||||
_validate_vector_db_knowledge(request)
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
False,
|
||||
)
|
||||
def test_allows_everything_when_vector_db_enabled(self) -> None:
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = [1, 2]
|
||||
request.hierarchy_node_ids = [3]
|
||||
request.document_ids = ["doc-x"]
|
||||
|
||||
_validate_vector_db_knowledge(request)
|
||||
@@ -1,237 +0,0 @@
|
||||
"""Tests for the FileReaderTool.
|
||||
|
||||
Verifies:
|
||||
- Tool definition schema is well-formed
|
||||
- File ID validation (allowlist, UUID format)
|
||||
- Character range extraction and clamping
|
||||
- Error handling for missing parameters and non-text files
|
||||
- is_available() reflects DISABLE_VECTOR_DB
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FILE_ID_FIELD
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FileReaderTool
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import MAX_NUM_CHARS
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import NUM_CHARS_FIELD
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import (
|
||||
START_CHAR_FIELD,
|
||||
)
|
||||
|
||||
TOOL_MODULE = "onyx.tools.tool_implementations.file_reader.file_reader_tool"
|
||||
_PLACEMENT = Placement(turn_index=0)
|
||||
|
||||
|
||||
def _make_tool(
|
||||
user_file_ids: list | None = None,
|
||||
chat_file_ids: list | None = None,
|
||||
) -> FileReaderTool:
|
||||
emitter = MagicMock()
|
||||
return FileReaderTool(
|
||||
tool_id=99,
|
||||
emitter=emitter,
|
||||
user_file_ids=user_file_ids or [],
|
||||
chat_file_ids=chat_file_ids or [],
|
||||
)
|
||||
|
||||
|
||||
def _text_file(content: str, filename: str = "test.txt") -> InMemoryChatFile:
|
||||
return InMemoryChatFile(
|
||||
file_id="some-file-id",
|
||||
content=content.encode("utf-8"),
|
||||
file_type=ChatFileType.PLAIN_TEXT,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool metadata
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolMetadata:
|
||||
def test_tool_name(self) -> None:
|
||||
tool = _make_tool()
|
||||
assert tool.name == "read_file"
|
||||
|
||||
def test_tool_definition_schema(self) -> None:
|
||||
tool = _make_tool()
|
||||
defn = tool.tool_definition()
|
||||
assert defn["type"] == "function"
|
||||
func = defn["function"]
|
||||
assert func["name"] == "read_file"
|
||||
props = func["parameters"]["properties"]
|
||||
assert FILE_ID_FIELD in props
|
||||
assert START_CHAR_FIELD in props
|
||||
assert NUM_CHARS_FIELD in props
|
||||
assert func["parameters"]["required"] == [FILE_ID_FIELD]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# File ID validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFileIdValidation:
|
||||
def test_rejects_invalid_uuid(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException, match="Invalid file_id"):
|
||||
tool._validate_file_id("not-a-uuid")
|
||||
|
||||
def test_rejects_file_not_in_allowlist(self) -> None:
|
||||
tool = _make_tool(user_file_ids=[uuid4()])
|
||||
other_id = uuid4()
|
||||
with pytest.raises(ToolCallException, match="not in available files"):
|
||||
tool._validate_file_id(str(other_id))
|
||||
|
||||
def test_accepts_user_file_id(self) -> None:
|
||||
uid = uuid4()
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
assert tool._validate_file_id(str(uid)) == uid
|
||||
|
||||
def test_accepts_chat_file_id(self) -> None:
|
||||
cid = uuid4()
|
||||
tool = _make_tool(chat_file_ids=[cid])
|
||||
assert tool._validate_file_id(str(cid)) == cid
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# run() — character range extraction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRun:
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_returns_full_content_by_default(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
content = "Hello, world!"
|
||||
mock_load_user_file.return_value = _text_file(content)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
resp = tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid)},
|
||||
)
|
||||
assert content in resp.llm_facing_response
|
||||
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_respects_start_char_and_num_chars(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
content = "abcdefghijklmnop"
|
||||
mock_load_user_file.return_value = _text_file(content)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
resp = tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid), START_CHAR_FIELD: 4, NUM_CHARS_FIELD: 6},
|
||||
)
|
||||
assert "efghij" in resp.llm_facing_response
|
||||
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_clamps_num_chars_to_max(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
content = "x" * (MAX_NUM_CHARS + 500)
|
||||
mock_load_user_file.return_value = _text_file(content)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
resp = tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid), NUM_CHARS_FIELD: MAX_NUM_CHARS + 9999},
|
||||
)
|
||||
assert f"Characters 0-{MAX_NUM_CHARS}" in resp.llm_facing_response
|
||||
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_includes_continuation_hint(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
content = "x" * 100
|
||||
mock_load_user_file.return_value = _text_file(content)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
resp = tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid), NUM_CHARS_FIELD: 10},
|
||||
)
|
||||
assert "use start_char=10 to continue reading" in resp.llm_facing_response
|
||||
|
||||
def test_raises_on_missing_file_id(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException, match="Missing required"):
|
||||
tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
)
|
||||
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_raises_on_non_text_file(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
mock_load_user_file.return_value = InMemoryChatFile(
|
||||
file_id="img",
|
||||
content=b"\x89PNG",
|
||||
file_type=ChatFileType.IMAGE,
|
||||
filename="photo.png",
|
||||
)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
with pytest.raises(ToolCallException, match="not a text file"):
|
||||
tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid)},
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# is_available()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsAvailable:
|
||||
@patch(f"{TOOL_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
def test_available_when_vector_db_disabled(self) -> None:
|
||||
assert FileReaderTool.is_available(MagicMock()) is True
|
||||
|
||||
@patch(f"{TOOL_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
def test_unavailable_when_vector_db_enabled(self) -> None:
|
||||
assert FileReaderTool.is_available(MagicMock()) is False
|
||||
@@ -163,3 +163,16 @@ Add clear comments:
|
||||
- Any TODOs you add in the code must be accompanied by either the name/username
|
||||
of the owner of that TODO, or an issue number for an issue referencing that
|
||||
piece of work.
|
||||
- Avoid module-level logic that runs on import, which leads to import-time side
|
||||
effects. Essentially every piece of meaningful logic should exist within some
|
||||
function that has to be explicitly invoked. Acceptable exceptions to this may
|
||||
include loading environment variables or setting up loggers.
|
||||
- If you find yourself needing something like this, you may want that logic to
|
||||
exist in a file dedicated for manual execution (contains `if __name__ ==
|
||||
"__main__":`) which should not be imported by anything else.
|
||||
- Related to the above, do not conflate Python scripts you intend to run from
|
||||
the command line (contains `if __name__ == "__main__":`) with modules you
|
||||
intend to import from elsewhere. If for some unlikely reason they have to be
|
||||
the same file, any logic specific to executing the file (including imports)
|
||||
should be contained in the `if __name__ == "__main__":` block.
|
||||
- Generally these executable files exist in `backend/scripts/`.
|
||||
|
||||
@@ -16,15 +16,12 @@
|
||||
# This overlay:
|
||||
# - Moves Vespa (index), both model servers, and code-interpreter to profiles
|
||||
# so they do not start by default
|
||||
# - Moves the background worker to the "background" profile (the API server
|
||||
# handles all background work via FastAPI BackgroundTasks)
|
||||
# - Makes the depends_on references to removed services optional
|
||||
# - Sets DISABLE_VECTOR_DB=true on the api_server
|
||||
# - Makes the depends_on references to those services optional
|
||||
# - Sets DISABLE_VECTOR_DB=true on backend services
|
||||
#
|
||||
# To selectively bring services back:
|
||||
# --profile vectordb Vespa + indexing model server
|
||||
# --profile inference Inference model server
|
||||
# --profile background Background worker (Celery)
|
||||
# --profile code-interpreter Code interpreter
|
||||
# =============================================================================
|
||||
|
||||
@@ -46,10 +43,20 @@ services:
|
||||
- DISABLE_VECTOR_DB=true
|
||||
- FILE_STORE_BACKEND=postgres
|
||||
|
||||
# Move the background worker to a profile so it does not start by default.
|
||||
# The API server handles all background work in NO_VECTOR_DB mode.
|
||||
background:
|
||||
profiles: ["background"]
|
||||
depends_on:
|
||||
index:
|
||||
condition: service_started
|
||||
required: false
|
||||
indexing_model_server:
|
||||
condition: service_started
|
||||
required: false
|
||||
inference_model_server:
|
||||
condition: service_started
|
||||
required: false
|
||||
environment:
|
||||
- DISABLE_VECTOR_DB=true
|
||||
- FILE_STORE_BACKEND=postgres
|
||||
|
||||
# Move Vespa and indexing model server to a profile so they do not start.
|
||||
index:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_beat.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_beat.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_heavy.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_heavy.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_heavy.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_heavy.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_heavy.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_heavy.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_light.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_light.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_light.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_light.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_light.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_light.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_monitoring.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_monitoring.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_monitoring.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_monitoring.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_monitoring.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_monitoring.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_primary.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_primary.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_primary.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_primary.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_primary.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_primary.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_user_file_processing.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_user_file_processing.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_user_file_processing.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_user_file_processing.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_user_file_processing.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_user_file_processing.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -28,9 +28,7 @@ postgresql:
|
||||
# -- Master toggle for vector database support. When false:
|
||||
# - Sets DISABLE_VECTOR_DB=true on all backend pods
|
||||
# - Skips the indexing model server deployment (embeddings not needed)
|
||||
# - Skips ALL celery worker deployments (beat, primary, light, heavy,
|
||||
# monitoring, user-file-processing, docprocessing, docfetching) — the
|
||||
# API server handles background work via FastAPI BackgroundTasks
|
||||
# - Skips docprocessing and docfetching celery workers
|
||||
# - You should also set vespa.enabled=false and opensearch.enabled=false
|
||||
# to prevent those subcharts from deploying
|
||||
vectorDB:
|
||||
|
||||
@@ -10,7 +10,7 @@ requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"aioboto3==15.1.0",
|
||||
"cohere==5.6.1",
|
||||
"fastapi==0.128.0",
|
||||
"fastapi==0.133.1",
|
||||
"google-cloud-aiplatform==1.121.0",
|
||||
"google-genai==1.52.0",
|
||||
"litellm==1.81.6",
|
||||
|
||||
@@ -51,6 +51,7 @@ func NewRootCommand() *cobra.Command {
|
||||
cmd.AddCommand(NewRunCICommand())
|
||||
cmd.AddCommand(NewScreenshotDiffCommand())
|
||||
cmd.AddCommand(NewWebCommand())
|
||||
cmd.AddCommand(NewWhoisCommand())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
159
tools/ods/cmd/whois.go
Normal file
159
tools/ods/cmd/whois.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/kube"
|
||||
)
|
||||
|
||||
var safeIdentifier = regexp.MustCompile(`^[a-zA-Z0-9_\-]+$`)
|
||||
|
||||
// NewWhoisCommand creates the whois command for looking up users/tenants.
|
||||
func NewWhoisCommand() *cobra.Command {
|
||||
var ctx string
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "whois <email-fragment or tenant-id>",
|
||||
Short: "Look up users and admins by email or tenant ID",
|
||||
Long: `Look up tenant and user information from the data plane PostgreSQL database.
|
||||
|
||||
Requires: AWS SSO login, kubectl access to the EKS cluster.
|
||||
|
||||
Two modes (auto-detected):
|
||||
|
||||
Email fragment:
|
||||
ods whois chris
|
||||
→ Searches user_tenant_mapping for emails matching '%chris%'
|
||||
|
||||
Tenant ID:
|
||||
ods whois tenant_abcd1234-...
|
||||
→ Lists all admin emails in that tenant
|
||||
|
||||
Cluster connection is configured via KUBE_CTX_* environment variables.
|
||||
Each variable is a space-separated tuple: "cluster region namespace"
|
||||
|
||||
export KUBE_CTX_DATA_PLANE="<cluster> <region> <namespace>"
|
||||
export KUBE_CTX_CONTROL_PLANE="<cluster> <region> <namespace>"
|
||||
etc...
|
||||
|
||||
Use -c to select which context (default: data_plane).`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runWhois(args[0], ctx)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVarP(&ctx, "context", "c", "data_plane", "cluster context name (maps to KUBE_CTX_<NAME> env var)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func clusterFromEnv(name string) *kube.Cluster {
|
||||
envKey := "KUBE_CTX_" + strings.ToUpper(name)
|
||||
val := os.Getenv(envKey)
|
||||
if val == "" {
|
||||
log.Fatalf("Environment variable %s is not set.\n\nSet it as a space-separated tuple:\n export %s=\"<cluster> <region> <namespace>\"", envKey, envKey)
|
||||
}
|
||||
|
||||
parts := strings.Fields(val)
|
||||
if len(parts) != 3 {
|
||||
log.Fatalf("%s must be a space-separated tuple of 3 values (cluster region namespace), got: %q", envKey, val)
|
||||
}
|
||||
|
||||
return &kube.Cluster{Name: parts[0], Region: parts[1], Namespace: parts[2]}
|
||||
}
|
||||
|
||||
// queryPod runs a SQL query via pginto on the given pod and returns cleaned output lines.
|
||||
func queryPod(c *kube.Cluster, pod, sql string) []string {
|
||||
raw, err := c.ExecOnPod(pod, "pginto", "-A", "-t", "-F", "\t", "-c", sql)
|
||||
if err != nil {
|
||||
log.Fatalf("Query failed: %v", err)
|
||||
}
|
||||
|
||||
var lines []string
|
||||
for _, line := range strings.Split(strings.TrimSpace(raw), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && !strings.HasPrefix(line, "Connecting to ") {
|
||||
lines = append(lines, line)
|
||||
}
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func runWhois(query string, ctx string) {
|
||||
c := clusterFromEnv(ctx)
|
||||
|
||||
if err := c.EnsureContext(); err != nil {
|
||||
log.Fatalf("Failed to ensure cluster context: %v", err)
|
||||
}
|
||||
|
||||
log.Info("Finding api-server pod...")
|
||||
pod, err := c.FindPod("api-server")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find api-server pod: %v", err)
|
||||
}
|
||||
log.Debugf("Using pod: %s", pod)
|
||||
|
||||
if strings.HasPrefix(query, "tenant_") {
|
||||
findAdminsByTenant(c, pod, query)
|
||||
} else {
|
||||
findByEmail(c, pod, query)
|
||||
}
|
||||
}
|
||||
|
||||
func findByEmail(c *kube.Cluster, pod, fragment string) {
|
||||
fragment = strings.NewReplacer("'", "", `"`, "", `;`, "", `\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(fragment)
|
||||
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT email, tenant_id, active FROM public.user_tenant_mapping WHERE email LIKE '%%%s%%' ORDER BY email;`,
|
||||
fragment,
|
||||
)
|
||||
|
||||
log.Infof("Searching for emails matching '%%%s%%'...", fragment)
|
||||
lines := queryPod(c, pod, sql)
|
||||
if len(lines) == 0 {
|
||||
fmt.Println("No results found.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
_, _ = fmt.Fprintln(w, "EMAIL\tTENANT ID\tACTIVE")
|
||||
_, _ = fmt.Fprintln(w, "-----\t---------\t------")
|
||||
for _, line := range lines {
|
||||
_, _ = fmt.Fprintln(w, line)
|
||||
}
|
||||
_ = w.Flush()
|
||||
}
|
||||
|
||||
func findAdminsByTenant(c *kube.Cluster, pod, tenantID string) {
|
||||
if !safeIdentifier.MatchString(tenantID) {
|
||||
log.Fatalf("Invalid tenant ID: %q (must be alphanumeric, hyphens, underscores only)", tenantID)
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT email FROM "%s"."user" WHERE role = 'ADMIN' AND is_active = true AND email NOT LIKE 'api_key__%%' ORDER BY email;`,
|
||||
tenantID,
|
||||
)
|
||||
|
||||
log.Infof("Fetching admin emails for %s...", tenantID)
|
||||
lines := queryPod(c, pod, sql)
|
||||
if len(lines) == 0 {
|
||||
fmt.Println("No admin users found for this tenant.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("EMAIL")
|
||||
fmt.Println("-----")
|
||||
for _, line := range lines {
|
||||
fmt.Println(line)
|
||||
}
|
||||
}
|
||||
90
tools/ods/internal/kube/kube.go
Normal file
90
tools/ods/internal/kube/kube.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package kube
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Cluster holds the connection info for a Kubernetes cluster.
|
||||
type Cluster struct {
|
||||
Name string
|
||||
Region string
|
||||
Namespace string
|
||||
}
|
||||
|
||||
// EnsureContext makes sure the cluster exists in kubeconfig, calling
|
||||
// aws eks update-kubeconfig only if the context is missing.
|
||||
func (c *Cluster) EnsureContext() error {
|
||||
// Check if context already exists in kubeconfig
|
||||
cmd := exec.Command("kubectl", "config", "get-contexts", c.Name, "--no-headers")
|
||||
if err := cmd.Run(); err == nil {
|
||||
log.Debugf("Context %s already exists, skipping aws eks update-kubeconfig", c.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("Context %s not found, fetching kubeconfig from AWS...", c.Name)
|
||||
cmd = exec.Command("aws", "eks", "update-kubeconfig", "--region", c.Region, "--name", c.Name, "--alias", c.Name)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("aws eks update-kubeconfig failed: %w\n%s", err, string(out))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// kubectlArgs returns common kubectl flags to target this cluster without mutating global context.
|
||||
func (c *Cluster) kubectlArgs() []string {
|
||||
return []string{"--context", c.Name, "--namespace", c.Namespace}
|
||||
}
|
||||
|
||||
// FindPod returns the name of the first Running/Ready pod matching the given substring.
|
||||
func (c *Cluster) FindPod(substring string) (string, error) {
|
||||
args := append(c.kubectlArgs(), "get", "po",
|
||||
"--field-selector", "status.phase=Running",
|
||||
"--no-headers",
|
||||
"-o", "custom-columns=NAME:.metadata.name,READY:.status.conditions[?(@.type=='Ready')].status",
|
||||
)
|
||||
cmd := exec.Command("kubectl", args...)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return "", fmt.Errorf("kubectl get po failed: %w\n%s", err, string(exitErr.Stderr))
|
||||
}
|
||||
return "", fmt.Errorf("kubectl get po failed: %w", err)
|
||||
}
|
||||
|
||||
for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
name, ready := fields[0], fields[1]
|
||||
if strings.Contains(name, substring) && ready == "True" {
|
||||
log.Debugf("Found pod: %s", name)
|
||||
return name, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no ready pod found matching %q", substring)
|
||||
}
|
||||
|
||||
// ExecOnPod runs a command on a pod and returns its stdout.
|
||||
func (c *Cluster) ExecOnPod(pod string, command ...string) (string, error) {
|
||||
args := append(c.kubectlArgs(), "exec", pod, "--")
|
||||
args = append(args, command...)
|
||||
log.Debugf("Running: kubectl %s", strings.Join(args, " "))
|
||||
|
||||
cmd := exec.Command("kubectl", args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("kubectl exec failed: %w\n%s", err, stderr.String())
|
||||
}
|
||||
|
||||
return stdout.String(), nil
|
||||
}
|
||||
9
uv.lock
generated
9
uv.lock
generated
@@ -1688,17 +1688,18 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.128.0"
|
||||
version = "0.133.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "annotated-doc" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "starlette" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "typing-inspection" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/22/6f/0eafed8349eea1fa462238b54a624c8b408cd1ba2795c8e64aa6c34f8ab7/fastapi-0.133.1.tar.gz", hash = "sha256:ed152a45912f102592976fde6cbce7dae1a8a1053da94202e51dd35d184fadd6", size = 378741, upload-time = "2026-02-25T18:18:17.398Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/c9/a175a7779f3599dfa4adfc97a6ce0e157237b3d7941538604aadaf97bfb6/fastapi-0.133.1-py3-none-any.whl", hash = "sha256:658f34ba334605b1617a65adf2ea6461901bdb9af3a3080d63ff791ecf7dc2e2", size = 109029, upload-time = "2026-02-25T18:18:18.578Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4612,7 +4613,7 @@ requires-dist = [
|
||||
{ name = "einops", marker = "extra == 'model-server'", specifier = "==0.8.1" },
|
||||
{ name = "exa-py", marker = "extra == 'backend'", specifier = "==1.15.4" },
|
||||
{ name = "faker", marker = "extra == 'dev'", specifier = "==40.1.2" },
|
||||
{ name = "fastapi", specifier = "==0.128.0" },
|
||||
{ name = "fastapi", specifier = "==0.133.1" },
|
||||
{ name = "fastapi-limiter", marker = "extra == 'backend'", specifier = "==0.1.6" },
|
||||
{ name = "fastapi-users", marker = "extra == 'backend'", specifier = "==15.0.4" },
|
||||
{ name = "fastapi-users-db-sqlalchemy", marker = "extra == 'backend'", specifier = "==7.0.0" },
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import "@opal/components/buttons/Button/styles.css";
|
||||
import "@opal/components/tooltip.css";
|
||||
import {
|
||||
Interactive,
|
||||
type InteractiveBaseProps,
|
||||
type InteractiveContainerWidthVariant,
|
||||
} from "@opal/core";
|
||||
import type { SizeVariant } from "@opal/shared";
|
||||
import { Interactive, type InteractiveBaseProps } from "@opal/core";
|
||||
import type { SizeVariant, WidthVariant } from "@opal/shared";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
@@ -91,7 +87,7 @@ type ButtonProps = InteractiveBaseProps &
|
||||
tooltip?: string;
|
||||
|
||||
/** Width preset. `"auto"` shrink-wraps, `"full"` stretches to parent width. */
|
||||
width?: InteractiveContainerWidthVariant;
|
||||
width?: WidthVariant;
|
||||
|
||||
/** Which side the tooltip appears on. */
|
||||
tooltipSide?: TooltipSide;
|
||||
|
||||
@@ -12,6 +12,5 @@ export {
|
||||
type InteractiveBaseProps,
|
||||
type InteractiveBaseVariantProps,
|
||||
type InteractiveContainerProps,
|
||||
type InteractiveContainerWidthVariant,
|
||||
type InteractiveContainerRoundingVariant,
|
||||
} from "@opal/core/interactive/components";
|
||||
|
||||
@@ -3,7 +3,12 @@ import React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
import { cn } from "@opal/utils";
|
||||
import type { WithoutStyles } from "@opal/types";
|
||||
import { sizeVariants, type SizeVariant } from "@opal/shared";
|
||||
import {
|
||||
sizeVariants,
|
||||
type SizeVariant,
|
||||
widthVariants,
|
||||
type WidthVariant,
|
||||
} from "@opal/shared";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
@@ -39,18 +44,6 @@ type InteractiveBaseVariantProps =
|
||||
selected?: never;
|
||||
};
|
||||
|
||||
/**
|
||||
* Width presets for `Interactive.Container`.
|
||||
*
|
||||
* - `"auto"` — Shrink-wraps to content width (default)
|
||||
* - `"full"` — Stretches to fill the parent's width (`w-full`)
|
||||
*/
|
||||
type InteractiveContainerWidthVariant = "auto" | "full";
|
||||
const interactiveContainerWidthVariants = {
|
||||
auto: "w-auto",
|
||||
full: "w-full",
|
||||
} as const;
|
||||
|
||||
/**
|
||||
* Border-radius presets for `Interactive.Container`.
|
||||
*
|
||||
@@ -345,7 +338,7 @@ interface InteractiveContainerProps
|
||||
*
|
||||
* @default "auto"
|
||||
*/
|
||||
widthVariant?: InteractiveContainerWidthVariant;
|
||||
widthVariant?: WidthVariant;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -413,7 +406,7 @@ function InteractiveContainer({
|
||||
height,
|
||||
minWidth,
|
||||
padding,
|
||||
interactiveContainerWidthVariants[widthVariant],
|
||||
widthVariants[widthVariant],
|
||||
slotClassName
|
||||
),
|
||||
"data-border": border ? ("true" as const) : undefined,
|
||||
@@ -490,6 +483,5 @@ export {
|
||||
type InteractiveBaseVariantProps,
|
||||
type InteractiveBaseSelectVariantProps,
|
||||
type InteractiveContainerProps,
|
||||
type InteractiveContainerWidthVariant,
|
||||
type InteractiveContainerRoundingVariant,
|
||||
};
|
||||
|
||||
@@ -14,6 +14,8 @@ import {
|
||||
} from "@opal/layouts/Content/LabelLayout";
|
||||
import type { TagProps } from "@opal/components/Tag/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { widthVariants, type WidthVariant } from "@opal/shared";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared types
|
||||
@@ -43,6 +45,17 @@ interface ContentBaseProps {
|
||||
|
||||
/** Called when the user commits an edit. */
|
||||
onTitleChange?: (newTitle: string) => void;
|
||||
|
||||
/**
|
||||
* Width preset controlling the component's horizontal size.
|
||||
* Uses the shared `WidthVariant` scale from `@opal/shared`.
|
||||
*
|
||||
* - `"auto"` — Shrink-wraps to content width
|
||||
* - `"full"` — Stretches to fill the parent's width
|
||||
*
|
||||
* @default "auto"
|
||||
*/
|
||||
widthVariant?: WidthVariant;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -87,11 +100,20 @@ type ContentProps = HeadingContentProps | LabelContentProps | BodyContentProps;
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function Content(props: ContentProps) {
|
||||
const { sizePreset = "headline", variant = "heading", ...rest } = props;
|
||||
const {
|
||||
sizePreset = "headline",
|
||||
variant = "heading",
|
||||
widthVariant = "auto",
|
||||
...rest
|
||||
} = props;
|
||||
|
||||
const widthClass = widthVariants[widthVariant];
|
||||
|
||||
let layout: React.ReactNode = null;
|
||||
|
||||
// Heading layout: headline/section presets with heading/section variant
|
||||
if (sizePreset === "headline" || sizePreset === "section") {
|
||||
return (
|
||||
layout = (
|
||||
<HeadingLayout
|
||||
sizePreset={sizePreset}
|
||||
variant={variant as HeadingLayoutProps["variant"]}
|
||||
@@ -101,8 +123,8 @@ function Content(props: ContentProps) {
|
||||
}
|
||||
|
||||
// Label layout: main-content/main-ui/secondary with section variant
|
||||
if (variant === "section" || variant === "heading") {
|
||||
return (
|
||||
else if (variant === "section" || variant === "heading") {
|
||||
layout = (
|
||||
<LabelLayout
|
||||
sizePreset={sizePreset}
|
||||
{...(rest as Omit<LabelLayoutProps, "sizePreset">)}
|
||||
@@ -111,8 +133,8 @@ function Content(props: ContentProps) {
|
||||
}
|
||||
|
||||
// Body layout: main-content/main-ui/secondary with body variant
|
||||
if (variant === "body") {
|
||||
return (
|
||||
else if (variant === "body") {
|
||||
layout = (
|
||||
<BodyLayout
|
||||
sizePreset={sizePreset}
|
||||
{...(rest as Omit<
|
||||
@@ -123,7 +145,17 @@ function Content(props: ContentProps) {
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
// This case should NEVER be hit.
|
||||
if (!layout)
|
||||
throw new Error(
|
||||
`Content: no layout matched for sizePreset="${sizePreset}" variant="${variant}"`
|
||||
);
|
||||
|
||||
// "auto" → return layout directly (a block div with w-auto still
|
||||
// stretches to its parent, defeating shrink-to-content).
|
||||
if (widthVariant === "auto") return layout;
|
||||
|
||||
return <div className={widthClass}>{layout}</div>;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -50,4 +50,31 @@ const sizeVariants = {
|
||||
/** Named size preset key. */
|
||||
type SizeVariant = keyof typeof sizeVariants;
|
||||
|
||||
export { sizeVariants, type SizeVariant };
|
||||
// ---------------------------------------------------------------------------
|
||||
// Width Variants
|
||||
//
|
||||
// A named scale of width presets that map to Tailwind width utility classes.
|
||||
//
|
||||
// Consumers:
|
||||
// - Interactive.Container (widthVariant)
|
||||
// - Button (width)
|
||||
// - Content (widthVariant)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Width-variant scale.
|
||||
*
|
||||
* | Key | Tailwind class |
|
||||
* |--------|----------------|
|
||||
* | `auto` | `w-auto` |
|
||||
* | `full` | `w-full` |
|
||||
*/
|
||||
const widthVariants = {
|
||||
auto: "w-auto",
|
||||
full: "w-full",
|
||||
} as const;
|
||||
|
||||
/** Named width preset key. */
|
||||
type WidthVariant = keyof typeof widthVariants;
|
||||
|
||||
export { sizeVariants, type SizeVariant, widthVariants, type WidthVariant };
|
||||
|
||||
@@ -10,7 +10,7 @@ export default function Main() {
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgMcp}
|
||||
title="MCP Actions"
|
||||
description="Connect MCP (Model Context Protocol) servers to add custom actions and tools for your assistants."
|
||||
description="Connect MCP (Model Context Protocol) servers to add custom actions and tools for your agents."
|
||||
separator
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
|
||||
@@ -10,7 +10,7 @@ export default function Main() {
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgActions}
|
||||
title="OpenAPI Actions"
|
||||
description="Connect OpenAPI servers to add custom actions and tools for your assistants."
|
||||
description="Connect OpenAPI servers to add custom actions and tools for your agents."
|
||||
separator
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
|
||||
@@ -170,7 +170,7 @@ export function PersonasTable({
|
||||
{deleteModalOpen && personaToDelete && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgAlertCircle}
|
||||
title="Delete Assistant"
|
||||
title="Delete Agent"
|
||||
onClose={closeDeleteModal}
|
||||
submit={<Button onClick={handleDeletePersona}>Delete</Button>}
|
||||
>
|
||||
@@ -183,15 +183,15 @@ export function PersonasTable({
|
||||
const isDefault = personaToToggleDefault.is_default_persona;
|
||||
|
||||
const title = isDefault
|
||||
? "Remove Featured Assistant"
|
||||
: "Set Featured Assistant";
|
||||
? "Remove Featured Agent"
|
||||
: "Set Featured Agent";
|
||||
const buttonText = isDefault ? "Remove Feature" : "Set as Featured";
|
||||
const text = isDefault
|
||||
? `Are you sure you want to remove the featured status of ${personaToToggleDefault.name}?`
|
||||
: `Are you sure you want to set the featured status of ${personaToToggleDefault.name}?`;
|
||||
const additionalText = isDefault
|
||||
? `Removing "${personaToToggleDefault.name}" as a featured assistant will not affect its visibility or accessibility.`
|
||||
: `Setting "${personaToToggleDefault.name}" as a featured assistant will make it public and visible to all users. This action cannot be undone.`;
|
||||
? `Removing "${personaToToggleDefault.name}" as a featured agent will not affect its visibility or accessibility.`
|
||||
: `Setting "${personaToToggleDefault.name}" as a featured agent will make it public and visible to all users. This action cannot be undone.`;
|
||||
|
||||
return (
|
||||
<ConfirmationModalLayout
|
||||
@@ -217,7 +217,7 @@ export function PersonasTable({
|
||||
"Name",
|
||||
"Description",
|
||||
"Type",
|
||||
"Featured Assistant",
|
||||
"Featured Agent",
|
||||
"Is Visible",
|
||||
"Delete",
|
||||
]}
|
||||
|
||||
@@ -47,8 +47,8 @@ function MainContent({
|
||||
return (
|
||||
<div>
|
||||
<Text className="mb-2">
|
||||
Assistants are a way to build custom search/question-answering
|
||||
experiences for different use cases.
|
||||
Agents are a way to build custom search/question-answering experiences
|
||||
for different use cases.
|
||||
</Text>
|
||||
<Text className="mt-2">They allow you to customize:</Text>
|
||||
<div className="text-sm">
|
||||
@@ -63,21 +63,21 @@ function MainContent({
|
||||
<div>
|
||||
<Separator />
|
||||
|
||||
<Title>Create an Assistant</Title>
|
||||
<Title>Create an Agent</Title>
|
||||
<CreateButton href="/app/agents/create?admin=true">
|
||||
New Assistant
|
||||
New Agent
|
||||
</CreateButton>
|
||||
|
||||
<Separator />
|
||||
|
||||
<Title>Existing Assistants</Title>
|
||||
<Title>Existing Agents</Title>
|
||||
{totalItems > 0 ? (
|
||||
<>
|
||||
<SubLabel>
|
||||
Assistants will be displayed as options on the Chat / Search
|
||||
interfaces in the order they are displayed below. Assistants
|
||||
marked as hidden will not be displayed. Editable assistants are
|
||||
shown at the top.
|
||||
Agents will be displayed as options on the Chat / Search
|
||||
interfaces in the order they are displayed below. Agents marked as
|
||||
hidden will not be displayed. Editable agents are shown at the
|
||||
top.
|
||||
</SubLabel>
|
||||
<PersonasTable
|
||||
personas={customPersonas}
|
||||
@@ -96,21 +96,21 @@ function MainContent({
|
||||
) : (
|
||||
<div className="mt-6 p-8 border border-border rounded-lg bg-background-weak text-center">
|
||||
<Text className="text-lg font-medium mb-2">
|
||||
No custom assistants yet
|
||||
No custom agents yet
|
||||
</Text>
|
||||
<Text className="text-subtle mb-3">
|
||||
Create your first assistant to:
|
||||
Create your first agent to:
|
||||
</Text>
|
||||
<ul className="text-subtle text-sm list-disc text-left inline-block mb-3">
|
||||
<li>Build department-specific knowledge bases</li>
|
||||
<li>Create specialized research assistants</li>
|
||||
<li>Create specialized research agents</li>
|
||||
<li>Set up compliance and policy advisors</li>
|
||||
</ul>
|
||||
<Text className="text-subtle text-sm mb-4">
|
||||
...and so much more!
|
||||
</Text>
|
||||
<CreateButton href="/app/agents/create?admin=true">
|
||||
Create Your First Assistant
|
||||
Create Your First Agent
|
||||
</CreateButton>
|
||||
</div>
|
||||
)}
|
||||
@@ -128,13 +128,13 @@ export default function Page() {
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle icon={SvgOnyxOctagon} title="Assistants" />
|
||||
<AdminPageTitle icon={SvgOnyxOctagon} title="Agents" />
|
||||
|
||||
{isLoading && <ThreeDotsLoader />}
|
||||
|
||||
{error && (
|
||||
<ErrorCallout
|
||||
errorTitle="Failed to load assistants"
|
||||
errorTitle="Failed to load agents"
|
||||
errorMsg={
|
||||
error?.info?.message ||
|
||||
error?.info?.detail ||
|
||||
|
||||
@@ -156,7 +156,7 @@ export const SlackChannelConfigCreationForm = ({
|
||||
is: "assistant",
|
||||
then: (schema) =>
|
||||
schema.required(
|
||||
"A persona is required when using the'Assistant' knowledge source"
|
||||
"An agent is required when using the 'Agent' knowledge source"
|
||||
),
|
||||
}),
|
||||
standard_answer_categories: Yup.array(),
|
||||
|
||||
@@ -224,14 +224,14 @@ export function SlackChannelConfigFormFields({
|
||||
<RadioGroupItemField
|
||||
value="assistant"
|
||||
id="assistant"
|
||||
label="Search Assistant"
|
||||
label="Search Agent"
|
||||
sublabel="Control both the documents and the prompt to use for answering questions"
|
||||
/>
|
||||
<RadioGroupItemField
|
||||
value="non_search_assistant"
|
||||
id="non_search_assistant"
|
||||
label="Non-Search Assistant"
|
||||
sublabel="Chat with an assistant that does not use documents"
|
||||
label="Non-Search Agent"
|
||||
sublabel="Chat with an agent that does not use documents"
|
||||
/>
|
||||
</RadioGroup>
|
||||
</div>
|
||||
@@ -327,15 +327,15 @@ export function SlackChannelConfigFormFields({
|
||||
<div className="mt-4">
|
||||
<SubLabel>
|
||||
<>
|
||||
Select the search-enabled assistant OnyxBot will use while
|
||||
answering questions in Slack.
|
||||
Select the search-enabled agent OnyxBot will use while answering
|
||||
questions in Slack.
|
||||
{syncEnabledAssistants.length > 0 && (
|
||||
<>
|
||||
<br />
|
||||
<span className="text-sm text-text-dark/80">
|
||||
Note: Some of your assistants have auto-synced connectors
|
||||
in their document sets. You cannot select these assistants
|
||||
as they will not be able to answer questions in Slack.{" "}
|
||||
Note: Some of your agents have auto-synced connectors in
|
||||
their document sets. You cannot select these agents as
|
||||
they will not be able to answer questions in Slack.{" "}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
@@ -349,7 +349,7 @@ export function SlackChannelConfigFormFields({
|
||||
{viewSyncEnabledAssistants
|
||||
? "Hide un-selectable "
|
||||
: "View all "}
|
||||
assistants
|
||||
agents
|
||||
</button>
|
||||
</span>
|
||||
</>
|
||||
@@ -367,7 +367,7 @@ export function SlackChannelConfigFormFields({
|
||||
{viewSyncEnabledAssistants && syncEnabledAssistants.length > 0 && (
|
||||
<div className="mt-4">
|
||||
<p className="text-sm text-text-dark/80">
|
||||
Un-selectable assistants:
|
||||
Un-selectable agents:
|
||||
</p>
|
||||
<div className="mb-3 mt-2 flex gap-2 flex-wrap text-sm">
|
||||
{syncEnabledAssistants.map(
|
||||
@@ -394,15 +394,15 @@ export function SlackChannelConfigFormFields({
|
||||
<div className="mt-4">
|
||||
<SubLabel>
|
||||
<>
|
||||
Select the non-search assistant OnyxBot will use while answering
|
||||
Select the non-search agent OnyxBot will use while answering
|
||||
questions in Slack.
|
||||
{syncEnabledAssistants.length > 0 && (
|
||||
<>
|
||||
<br />
|
||||
<span className="text-sm text-text-dark/80">
|
||||
Note: Some of your assistants have auto-synced connectors
|
||||
in their document sets. You cannot select these assistants
|
||||
as they will not be able to answer questions in Slack.{" "}
|
||||
Note: Some of your agents have auto-synced connectors in
|
||||
their document sets. You cannot select these agents as
|
||||
they will not be able to answer questions in Slack.{" "}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
@@ -416,7 +416,7 @@ export function SlackChannelConfigFormFields({
|
||||
{viewSyncEnabledAssistants
|
||||
? "Hide un-selectable "
|
||||
: "View all "}
|
||||
assistants
|
||||
agents
|
||||
</button>
|
||||
</span>
|
||||
</>
|
||||
@@ -524,7 +524,7 @@ export function SlackChannelConfigFormFields({
|
||||
name="is_ephemeral"
|
||||
label="Respond to user in a private (ephemeral) message"
|
||||
tooltip="If set, OnyxBot will respond only to the user in a private (ephemeral) message. If you also
|
||||
chose 'Search' Assistant above, selecting this option will make documents that are private to the user
|
||||
chose 'Search' Agent above, selecting this option will make documents that are private to the user
|
||||
available for their queries."
|
||||
/>
|
||||
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import CodeInterpreterPage from "@/refresh-pages/admin/CodeInterpreterPage";
|
||||
|
||||
export default function Page() {
|
||||
return <CodeInterpreterPage />;
|
||||
}
|
||||
@@ -39,10 +39,10 @@ export function AdvancedOptions({
|
||||
agents={agents}
|
||||
isLoading={agentsLoading}
|
||||
error={agentsError}
|
||||
label="Assistant Whitelist"
|
||||
subtext="Restrict this provider to specific assistants."
|
||||
label="Agent Whitelist"
|
||||
subtext="Restrict this provider to specific agents."
|
||||
disabled={formikProps.values.is_public}
|
||||
disabledMessage="This LLM Provider is public and available to all assistants."
|
||||
disabledMessage="This LLM Provider is public and available to all agents."
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
|
||||
@@ -299,11 +299,11 @@ export default function Page({ params }: Props) {
|
||||
});
|
||||
refreshGuild();
|
||||
toast.success(
|
||||
personaId ? "Default assistant updated" : "Default assistant cleared"
|
||||
personaId ? "Default agent updated" : "Default agent cleared"
|
||||
);
|
||||
} catch (err) {
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to update assistant"
|
||||
err instanceof Error ? err.message : "Failed to update agent"
|
||||
);
|
||||
} finally {
|
||||
setIsUpdating(false);
|
||||
@@ -355,7 +355,7 @@ export default function Page({ params }: Props) {
|
||||
<InputSelect.Trigger placeholder="Select agent" />
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item value="default">
|
||||
Default Assistant
|
||||
Default Agent
|
||||
</InputSelect.Item>
|
||||
{personas.map((persona) => (
|
||||
<InputSelect.Item
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useState } from "react";
|
||||
import { FiDownload } from "react-icons/fi";
|
||||
import { memo, useState } from "react";
|
||||
import { SvgDownload } from "@opal/icons";
|
||||
import { ImageShape } from "@/app/app/services/streamingModels";
|
||||
import { FullImageModal } from "@/app/app/components/files/images/FullImageModal";
|
||||
import { buildImgUrl } from "@/app/app/components/files/images/utils";
|
||||
@@ -24,17 +24,22 @@ const SHAPE_CLASSES: Record<ImageShape, { container: string; image: string }> =
|
||||
},
|
||||
};
|
||||
|
||||
// Used to stop image flashing as images are loaded and response continues
|
||||
const loadedImages = new Set<string>();
|
||||
|
||||
interface InMessageImageProps {
|
||||
fileId: string;
|
||||
fileName?: string;
|
||||
shape?: ImageShape;
|
||||
}
|
||||
|
||||
export function InMessageImage({
|
||||
export const InMessageImage = memo(function InMessageImage({
|
||||
fileId,
|
||||
fileName,
|
||||
shape = DEFAULT_SHAPE,
|
||||
}: InMessageImageProps) {
|
||||
const [fullImageShowing, setFullImageShowing] = useState(false);
|
||||
const [imageLoaded, setImageLoaded] = useState(false);
|
||||
const [imageLoaded, setImageLoaded] = useState(loadedImages.has(fileId));
|
||||
|
||||
const normalizedShape = SHAPE_CLASSES[shape] ? shape : DEFAULT_SHAPE;
|
||||
const { container: shapeContainerClasses, image: shapeImageClasses } =
|
||||
@@ -49,7 +54,7 @@ export function InMessageImage({
|
||||
const url = window.URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = `image-${fileId}.png`; // You can adjust the filename/extension as needed
|
||||
a.download = fileName || `image-${fileId}.png`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
window.URL.revokeObjectURL(url);
|
||||
@@ -76,7 +81,10 @@ export function InMessageImage({
|
||||
width={1200}
|
||||
height={1200}
|
||||
alt="Chat Message Image"
|
||||
onLoad={() => setImageLoaded(true)}
|
||||
onLoad={() => {
|
||||
loadedImages.add(fileId);
|
||||
setImageLoaded(true);
|
||||
}}
|
||||
className={cn(
|
||||
"object-contain object-left overflow-hidden rounded-lg w-full h-full transition-opacity duration-300 cursor-pointer",
|
||||
shapeImageClasses,
|
||||
@@ -94,7 +102,7 @@ export function InMessageImage({
|
||||
)}
|
||||
>
|
||||
<Button
|
||||
icon={FiDownload}
|
||||
icon={SvgDownload}
|
||||
tooltip="Download"
|
||||
onClick={handleDownload}
|
||||
/>
|
||||
@@ -102,4 +110,4 @@ export function InMessageImage({
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,3 +1,21 @@
|
||||
const CHAT_FILE_URL_REGEX = /\/api\/chat\/file\/([^/?#]+)/;
|
||||
const IMAGE_EXTENSIONS = /\.(png|jpe?g|gif|webp|svg|bmp|ico|tiff?)$/i;
|
||||
|
||||
export function buildImgUrl(fileId: string) {
|
||||
return `/api/chat/file/${fileId}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* If `href` points to a chat file and `linkText` ends with an image extension,
|
||||
* returns the file ID. Otherwise returns null.
|
||||
*/
|
||||
export function extractChatImageFileId(
|
||||
href: string | undefined,
|
||||
linkText: string
|
||||
): string | null {
|
||||
if (!href) return null;
|
||||
const match = CHAT_FILE_URL_REGEX.exec(href);
|
||||
if (!match?.[1]) return null;
|
||||
if (!IMAGE_EXTENSIONS.test(linkText)) return null;
|
||||
return match[1];
|
||||
}
|
||||
|
||||
@@ -47,6 +47,8 @@ export interface RendererResult {
|
||||
|
||||
// Whether this renderer supports collapsible mode (collapse button shown only when true)
|
||||
supportsCollapsible?: boolean;
|
||||
/** Whether the step should remain collapsible even in single-step timelines */
|
||||
alwaysCollapsible?: boolean;
|
||||
/** Whether the result should be wrapped by timeline UI or rendered as-is */
|
||||
timelineLayout?: TimelineLayout;
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ import {
|
||||
import { extractCodeText, preprocessLaTeX } from "@/app/app/message/codeUtils";
|
||||
import { CodeBlock } from "@/app/app/message/CodeBlock";
|
||||
import { transformLinkUri, cn } from "@/lib/utils";
|
||||
import { InMessageImage } from "@/app/app/components/files/images/InMessageImage";
|
||||
import { extractChatImageFileId } from "@/app/app/components/files/images/utils";
|
||||
|
||||
/**
|
||||
* Processes content for markdown rendering by handling code blocks and LaTeX
|
||||
@@ -58,17 +60,31 @@ export const useMarkdownComponents = (
|
||||
);
|
||||
|
||||
const anchorCallback = useCallback(
|
||||
(props: any) => (
|
||||
<MemoizedAnchor
|
||||
updatePresentingDocument={state?.setPresentingDocument || (() => {})}
|
||||
docs={state?.docs || []}
|
||||
userFiles={state?.userFiles || []}
|
||||
citations={state?.citations}
|
||||
href={props.href}
|
||||
>
|
||||
{props.children}
|
||||
</MemoizedAnchor>
|
||||
),
|
||||
(props: any) => {
|
||||
const imageFileId = extractChatImageFileId(
|
||||
props.href,
|
||||
String(props.children ?? "")
|
||||
);
|
||||
if (imageFileId) {
|
||||
return (
|
||||
<InMessageImage
|
||||
fileId={imageFileId}
|
||||
fileName={String(props.children ?? "")}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<MemoizedAnchor
|
||||
updatePresentingDocument={state?.setPresentingDocument || (() => {})}
|
||||
docs={state?.docs || []}
|
||||
userFiles={state?.userFiles || []}
|
||||
citations={state?.citations}
|
||||
href={props.href}
|
||||
>
|
||||
{props.children}
|
||||
</MemoizedAnchor>
|
||||
);
|
||||
},
|
||||
[
|
||||
state?.docs,
|
||||
state?.userFiles,
|
||||
|
||||
@@ -50,7 +50,9 @@ export function TimelineStepComposer({
|
||||
header={result.status}
|
||||
isExpanded={result.isExpanded}
|
||||
onToggle={result.onToggle}
|
||||
collapsible={collapsible && !isSingleStep}
|
||||
collapsible={
|
||||
collapsible && (!isSingleStep || !!result.alwaysCollapsible)
|
||||
}
|
||||
supportsCollapsible={result.supportsCollapsible}
|
||||
isLastStep={index === results.length - 1 && isLastStep}
|
||||
isFirstStep={index === 0 && isFirstStep}
|
||||
|
||||
@@ -54,7 +54,7 @@ export function TimelineRow({
|
||||
isHover={isHover}
|
||||
/>
|
||||
)}
|
||||
<div className="flex-1">{children}</div>
|
||||
<div className="flex-1 min-w-0">{children}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
{stdout && (
|
||||
<div className="rounded-md bg-background-neutral-02 p-3">
|
||||
<div className="text-xs font-semibold mb-1 text-text-03">Output:</div>
|
||||
<pre className="text-sm whitespace-pre-wrap font-mono text-text-01">
|
||||
<pre className="text-sm whitespace-pre-wrap font-mono text-text-01 overflow-x-auto">
|
||||
{stdout}
|
||||
</pre>
|
||||
</div>
|
||||
@@ -150,7 +150,7 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
<div className="text-xs font-semibold mb-1 text-status-error-05">
|
||||
Error:
|
||||
</div>
|
||||
<pre className="text-sm whitespace-pre-wrap font-mono text-status-error-05">
|
||||
<pre className="text-sm whitespace-pre-wrap font-mono text-status-error-05 overflow-x-auto">
|
||||
{stderr}
|
||||
</pre>
|
||||
</div>
|
||||
@@ -181,6 +181,7 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
status,
|
||||
content,
|
||||
supportsCollapsible: true,
|
||||
alwaysCollapsible: true,
|
||||
},
|
||||
]);
|
||||
}
|
||||
@@ -191,6 +192,7 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
icon: SvgTerminal,
|
||||
status,
|
||||
supportsCollapsible: true,
|
||||
alwaysCollapsible: true,
|
||||
content: (
|
||||
<FadingEdgeContainer
|
||||
direction="bottom"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
:root {
|
||||
--app-page-main-content-width: 52.5rem;
|
||||
--block-width-form-input-min: 10rem;
|
||||
}
|
||||
|
||||
@@ -427,7 +427,7 @@ export const GroupDisplay = ({
|
||||
|
||||
<Separator />
|
||||
|
||||
<h2 className="text-xl font-bold mt-8 mb-2">Assistants</h2>
|
||||
<h2 className="text-xl font-bold mt-8 mb-2">Agents</h2>
|
||||
|
||||
<div>
|
||||
{userGroup.document_sets.length > 0 ? (
|
||||
@@ -445,7 +445,7 @@ export const GroupDisplay = ({
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<Text>No Assistants in this group...</Text>
|
||||
<Text>No Agents in this group...</Text>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -152,14 +152,14 @@ export function PersonaMessagesChart({
|
||||
} else if (selectedPersonaId === undefined) {
|
||||
content = (
|
||||
<div className="h-80 text-text-500 flex flex-col">
|
||||
<p className="m-auto">Select an assistant to view analytics</p>
|
||||
<p className="m-auto">Select an agent to view analytics</p>
|
||||
</div>
|
||||
);
|
||||
} else if (!personaMessagesData?.length) {
|
||||
content = (
|
||||
<div className="h-80 text-text-500 flex flex-col">
|
||||
<p className="m-auto">
|
||||
No data found for selected assistant in the specified time range
|
||||
No data found for selected agent in the specified time range
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
@@ -178,11 +178,9 @@ export function PersonaMessagesChart({
|
||||
|
||||
return (
|
||||
<CardSection className="mt-8">
|
||||
<Title>Assistant Analytics</Title>
|
||||
<Title>Agent Analytics</Title>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Text>
|
||||
Messages and unique users per day for the selected assistant
|
||||
</Text>
|
||||
<Text>Messages and unique users per day for the selected agent</Text>
|
||||
<div className="flex items-center gap-4">
|
||||
<Select
|
||||
value={selectedPersonaId?.toString() ?? ""}
|
||||
@@ -191,14 +189,14 @@ export function PersonaMessagesChart({
|
||||
}}
|
||||
>
|
||||
<SelectTrigger className="flex w-full max-w-xs">
|
||||
<SelectValue placeholder="Select an assistant to display" />
|
||||
<SelectValue placeholder="Select an agent to display" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<div className="flex items-center px-2 pb-2 sticky top-0 bg-background border-b">
|
||||
<Search className="h-4 w-4 mr-2 shrink-0 opacity-50" />
|
||||
<input
|
||||
className="flex h-8 w-full rounded-sm bg-transparent py-3 text-sm outline-none placeholder:text-muted-foreground disabled:cursor-not-allowed disabled:opacity-50"
|
||||
placeholder="Search assistants..."
|
||||
placeholder="Search agents..."
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
|
||||
@@ -146,7 +146,7 @@ export function AssistantStats({ assistantId }: { assistantId: number }) {
|
||||
return (
|
||||
<Card className="w-full">
|
||||
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
|
||||
<p className="text-base font-normal text-2xl">Assistant Analytics</p>
|
||||
<p className="text-base font-normal text-2xl">Agent Analytics</p>
|
||||
<AdminDateRangeSelector
|
||||
value={dateRange}
|
||||
onValueChange={setDateRange}
|
||||
|
||||
@@ -12,17 +12,17 @@ export default function NoAssistantModal() {
|
||||
return (
|
||||
<Modal open>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header icon={SvgUser} title="No Assistant Available" />
|
||||
<Modal.Header icon={SvgUser} title="No Agent Available" />
|
||||
<Modal.Body>
|
||||
<Text as="p">
|
||||
You currently have no assistant configured. To use this feature, you
|
||||
You currently have no agent configured. To use this feature, you
|
||||
need to take action.
|
||||
</Text>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<Text as="p">
|
||||
As an administrator, you can create a new assistant by visiting
|
||||
the admin panel.
|
||||
As an administrator, you can create a new agent by visiting the
|
||||
admin panel.
|
||||
</Text>
|
||||
<Button className="w-full" href="/admin/assistants">
|
||||
Go to Admin Panel
|
||||
@@ -30,8 +30,7 @@ export default function NoAssistantModal() {
|
||||
</>
|
||||
) : (
|
||||
<Text as="p">
|
||||
Please contact your administrator to configure an assistant for
|
||||
you.
|
||||
Please contact your administrator to configure an agent for you.
|
||||
</Text>
|
||||
)}
|
||||
</Modal.Body>
|
||||
|
||||
44
web/src/hooks/useCodeInterpreter.ts
Normal file
44
web/src/hooks/useCodeInterpreter.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
const HEALTH_ENDPOINT = "/api/admin/code-interpreter/health";
|
||||
const STATUS_ENDPOINT = "/api/admin/code-interpreter";
|
||||
|
||||
interface CodeInterpreterHealth {
|
||||
healthy: boolean;
|
||||
}
|
||||
|
||||
interface CodeInterpreterStatus {
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
export default function useCodeInterpreter() {
|
||||
const {
|
||||
data: healthData,
|
||||
error: healthError,
|
||||
isLoading: isHealthLoading,
|
||||
mutate: refetchHealth,
|
||||
} = useSWR<CodeInterpreterHealth>(HEALTH_ENDPOINT, errorHandlingFetcher, {
|
||||
refreshInterval: 30000,
|
||||
});
|
||||
|
||||
const {
|
||||
data: statusData,
|
||||
error: statusError,
|
||||
isLoading: isStatusLoading,
|
||||
mutate: refetchStatus,
|
||||
} = useSWR<CodeInterpreterStatus>(STATUS_ENDPOINT, errorHandlingFetcher);
|
||||
|
||||
function refetch() {
|
||||
refetchHealth();
|
||||
refetchStatus();
|
||||
}
|
||||
|
||||
return {
|
||||
isHealthy: healthData?.healthy ?? false,
|
||||
isEnabled: statusData?.enabled ?? false,
|
||||
isLoading: isHealthLoading || isStatusLoading,
|
||||
error: healthError || statusError,
|
||||
refetch,
|
||||
};
|
||||
}
|
||||
15
web/src/lib/admin/code-interpreter/svc.ts
Normal file
15
web/src/lib/admin/code-interpreter/svc.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
const UPDATE_ENDPOINT = "/api/admin/code-interpreter";
|
||||
|
||||
interface CodeInterpreterUpdateRequest {
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
export async function updateCodeInterpreter(
|
||||
request: CodeInterpreterUpdateRequest
|
||||
): Promise<Response> {
|
||||
return fetch(UPDATE_ENDPOINT, {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(request),
|
||||
});
|
||||
}
|
||||
@@ -153,7 +153,7 @@ function InputSelectRoot({
|
||||
);
|
||||
|
||||
return (
|
||||
<div className={cn("w-full relative")}>
|
||||
<div className="w-full min-w-[var(--block-width-form-input-min)] relative">
|
||||
<InputSelectContext.Provider value={contextValue}>
|
||||
<SelectPrimitive.Root
|
||||
{...(isControlled ? { value: currentValue } : { defaultValue })}
|
||||
|
||||
@@ -425,7 +425,7 @@ export default function AgentsNavigationPage() {
|
||||
>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgOnyxOctagon}
|
||||
title="Agents & Assistants"
|
||||
title="Agents"
|
||||
description="Customize AI behavior and knowledge for you and your team's use cases."
|
||||
rightChildren={
|
||||
<Button
|
||||
|
||||
241
web/src/refresh-pages/admin/CodeInterpreterPage.tsx
Normal file
241
web/src/refresh-pages/admin/CodeInterpreterPage.tsx
Normal file
@@ -0,0 +1,241 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState } from "react";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { Card, type CardProps } from "@/refresh-components/cards";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgCheckCircle,
|
||||
SvgRefreshCw,
|
||||
SvgTerminal,
|
||||
SvgUnplug,
|
||||
SvgXOctagon,
|
||||
} from "@opal/icons";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { Button } from "@opal/components";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import useCodeInterpreter from "@/hooks/useCodeInterpreter";
|
||||
import { updateCodeInterpreter } from "@/lib/admin/code-interpreter/svc";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
interface CodeInterpreterCardProps {
|
||||
variant?: CardProps["variant"];
|
||||
title: string;
|
||||
middleText?: string;
|
||||
strikethrough?: boolean;
|
||||
rightContent: React.ReactNode;
|
||||
}
|
||||
|
||||
function CodeInterpreterCard({
|
||||
variant,
|
||||
title,
|
||||
middleText,
|
||||
strikethrough,
|
||||
rightContent,
|
||||
}: CodeInterpreterCardProps) {
|
||||
return (
|
||||
// TODO (@raunakab): Allow Content to accept strikethrough and middleText
|
||||
<Card variant={variant} padding={0.5}>
|
||||
<ContentAction
|
||||
icon={SvgTerminal}
|
||||
title={middleText ? `${title} ${middleText}` : title}
|
||||
description="Built-in Python runtime"
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
rightChildren={rightContent}
|
||||
/>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
function CheckingStatus() {
|
||||
return (
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="end"
|
||||
alignItems="center"
|
||||
gap={0.25}
|
||||
padding={0.5}
|
||||
>
|
||||
<Text mainUiAction text03>
|
||||
Checking...
|
||||
</Text>
|
||||
<SimpleLoader />
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
interface ConnectionStatusProps {
|
||||
healthy: boolean;
|
||||
isLoading: boolean;
|
||||
}
|
||||
|
||||
function ConnectionStatus({ healthy, isLoading }: ConnectionStatusProps) {
|
||||
if (isLoading) {
|
||||
return <CheckingStatus />;
|
||||
}
|
||||
|
||||
const label = healthy ? "Connected" : "Connection Lost";
|
||||
const Icon = healthy ? SvgCheckCircle : SvgXOctagon;
|
||||
const iconColor = healthy ? "text-status-success-05" : "text-status-error-05";
|
||||
|
||||
return (
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="end"
|
||||
alignItems="center"
|
||||
gap={0.25}
|
||||
padding={0.5}
|
||||
>
|
||||
<Text mainUiAction text03>
|
||||
{label}
|
||||
</Text>
|
||||
<Icon size={16} className={iconColor} />
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
interface ActionButtonsProps {
|
||||
onDisconnect: () => void;
|
||||
onRefresh: () => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
function ActionButtons({
|
||||
onDisconnect,
|
||||
onRefresh,
|
||||
disabled,
|
||||
}: ActionButtonsProps) {
|
||||
return (
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="end"
|
||||
alignItems="center"
|
||||
gap={0.25}
|
||||
padding={0.25}
|
||||
>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgUnplug}
|
||||
onClick={onDisconnect}
|
||||
tooltip="Disconnect"
|
||||
disabled={disabled}
|
||||
/>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
icon={SvgRefreshCw}
|
||||
onClick={onRefresh}
|
||||
tooltip="Refresh"
|
||||
disabled={disabled}
|
||||
/>
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
export default function CodeInterpreterPage() {
|
||||
const { isHealthy, isEnabled, isLoading, refetch } = useCodeInterpreter();
|
||||
const [showDisconnectModal, setShowDisconnectModal] = useState(false);
|
||||
const [isReconnecting, setIsReconnecting] = useState(false);
|
||||
|
||||
async function handleToggle(enabled: boolean) {
|
||||
const action = enabled ? "reconnect" : "disconnect";
|
||||
setIsReconnecting(enabled);
|
||||
try {
|
||||
const response = await updateCodeInterpreter({ enabled });
|
||||
if (!response.ok) {
|
||||
toast.error(`Failed to ${action} Code Interpreter`);
|
||||
return;
|
||||
}
|
||||
setShowDisconnectModal(false);
|
||||
refetch();
|
||||
} finally {
|
||||
setIsReconnecting(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgTerminal}
|
||||
title="Code Interpreter"
|
||||
description="Safe and sandboxed Python runtime available to your LLM. See docs for more details."
|
||||
separator
|
||||
/>
|
||||
|
||||
<SettingsLayouts.Body>
|
||||
{isEnabled || isLoading ? (
|
||||
<CodeInterpreterCard
|
||||
title="Code Interpreter"
|
||||
variant={isHealthy ? "primary" : "secondary"}
|
||||
strikethrough={!isHealthy}
|
||||
rightContent={
|
||||
<Section
|
||||
flexDirection="column"
|
||||
justifyContent="center"
|
||||
alignItems="end"
|
||||
gap={0}
|
||||
padding={0}
|
||||
>
|
||||
<ConnectionStatus healthy={isHealthy} isLoading={isLoading} />
|
||||
<ActionButtons
|
||||
onDisconnect={() => setShowDisconnectModal(true)}
|
||||
onRefresh={refetch}
|
||||
disabled={isLoading}
|
||||
/>
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
<CodeInterpreterCard
|
||||
variant="secondary"
|
||||
title="Code Interpreter"
|
||||
middleText="(Disconnected)"
|
||||
strikethrough={true}
|
||||
rightContent={
|
||||
<Section flexDirection="row" alignItems="center" padding={0.5}>
|
||||
{isReconnecting ? (
|
||||
<CheckingStatus />
|
||||
) : (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
rightIcon={SvgArrowExchange}
|
||||
onClick={() => handleToggle(true)}
|
||||
>
|
||||
Reconnect
|
||||
</Button>
|
||||
)}
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</SettingsLayouts.Body>
|
||||
|
||||
{showDisconnectModal && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgUnplug}
|
||||
title="Disconnect Code Interpreter"
|
||||
onClose={() => setShowDisconnectModal(false)}
|
||||
submit={
|
||||
<Button variant="danger" onClick={() => handleToggle(false)}>
|
||||
Disconnect
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
<Text as="p" text03>
|
||||
All running sessions connected to{" "}
|
||||
<Text as="span" mainContentEmphasis text03>
|
||||
Code Interpreter
|
||||
</Text>{" "}
|
||||
will stop working. Note that this will not remove any data from your
|
||||
runtime. You can reconnect to this runtime later if needed.
|
||||
</Text>
|
||||
</ConfirmationModalLayout>
|
||||
)}
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
@@ -119,7 +119,7 @@ export default function NewTenantModal({
|
||||
: `Your request to join ${tenantInfo.number_of_users} other users of ${APP_DOMAIN} has been approved.`;
|
||||
|
||||
const description = isInvite
|
||||
? `By accepting this invitation, you will join the existing ${APP_DOMAIN} team and lose access to your current team. Note: you will lose access to your current assistants, prompts, chats, and connected sources.`
|
||||
? `By accepting this invitation, you will join the existing ${APP_DOMAIN} team and lose access to your current team. Note: you will lose access to your current agents, prompts, chats, and connected sources.`
|
||||
: `To finish joining your team, please reauthenticate with ${user?.email}.`;
|
||||
|
||||
return (
|
||||
|
||||
@@ -50,6 +50,7 @@ import {
|
||||
SvgPaintBrush,
|
||||
SvgDiscordMono,
|
||||
SvgWallet,
|
||||
SvgTerminal,
|
||||
} from "@opal/icons";
|
||||
import SvgMcp from "@opal/icons/mcp";
|
||||
import UserAvatarPopover from "@/sections/sidebar/UserAvatarPopover";
|
||||
@@ -91,7 +92,7 @@ const custom_assistants_items = (
|
||||
) => {
|
||||
const items = [
|
||||
{
|
||||
name: "Assistants",
|
||||
name: "Agents",
|
||||
icon: SvgOnyxOctagon,
|
||||
link: "/admin/assistants",
|
||||
},
|
||||
@@ -165,7 +166,7 @@ const collections = (
|
||||
]
|
||||
: []),
|
||||
{
|
||||
name: "Custom Assistants",
|
||||
name: "Custom Agents",
|
||||
items: custom_assistants_items(isCurator, enableEnterprise),
|
||||
},
|
||||
...(isCurator && enableEnterprise
|
||||
@@ -207,6 +208,11 @@ const collections = (
|
||||
icon: SvgImage,
|
||||
link: "/admin/configuration/image-generation",
|
||||
},
|
||||
{
|
||||
name: "Code Interpreter",
|
||||
icon: SvgTerminal,
|
||||
link: "/admin/configuration/code-interpreter",
|
||||
},
|
||||
...(!enableCloud && vectorDbEnabled
|
||||
? [
|
||||
{
|
||||
|
||||
@@ -13,6 +13,7 @@ interface LogoSectionProps {
|
||||
function LogoSection({ folded, onFoldClick }: LogoSectionProps) {
|
||||
const settings = useSettingsContext();
|
||||
const applicationName = settings.enterpriseSettings?.application_name;
|
||||
const logoDisplayStyle = settings.enterpriseSettings?.logo_display_style;
|
||||
|
||||
const logo = useCallback(
|
||||
(className?: string) => <Logo folded={folded} className={className} />,
|
||||
@@ -43,7 +44,7 @@ function LogoSection({ folded, onFoldClick }: LogoSectionProps) {
|
||||
>
|
||||
{folded === undefined ? (
|
||||
<div className="p-1">{logo()}</div>
|
||||
) : folded ? (
|
||||
) : folded && logoDisplayStyle !== "name_only" ? (
|
||||
<>
|
||||
<div className="group-hover/SidebarWrapper:hidden pt-1.5">
|
||||
{logo()}
|
||||
@@ -52,6 +53,8 @@ function LogoSection({ folded, onFoldClick }: LogoSectionProps) {
|
||||
{closeButton(false)}
|
||||
</div>
|
||||
</>
|
||||
) : folded ? (
|
||||
<div className="flex w-full justify-center">{closeButton(false)}</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="p-1"> {logo()}</div>
|
||||
|
||||
@@ -29,12 +29,12 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
|
||||
pageTitle: "Add Connector",
|
||||
},
|
||||
{
|
||||
name: "Custom Assistants - Assistants",
|
||||
name: "Custom Agents - Agents",
|
||||
path: "assistants",
|
||||
pageTitle: "Assistants",
|
||||
pageTitle: "Agents",
|
||||
options: {
|
||||
paragraphText:
|
||||
"Assistants are a way to build custom search/question-answering experiences for different use cases.",
|
||||
"Agents are a way to build custom search/question-answering experiences for different use cases.",
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -52,7 +52,7 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Custom Assistants - Slack Bots",
|
||||
name: "Custom Agents - Slack Bots",
|
||||
path: "bots",
|
||||
pageTitle: "Slack Bots",
|
||||
options: {
|
||||
@@ -61,7 +61,7 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Custom Assistants - Standard Answers",
|
||||
name: "Custom Agents - Standard Answers",
|
||||
path: "standard-answer",
|
||||
pageTitle: "Standard Answers",
|
||||
},
|
||||
@@ -101,12 +101,12 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
|
||||
pageTitle: "Search Settings",
|
||||
},
|
||||
{
|
||||
name: "Custom Assistants - MCP Actions",
|
||||
name: "Custom Agents - MCP Actions",
|
||||
path: "actions/mcp",
|
||||
pageTitle: "MCP Actions",
|
||||
},
|
||||
{
|
||||
name: "Custom Assistants - OpenAPI Actions",
|
||||
name: "Custom Agents - OpenAPI Actions",
|
||||
path: "actions/open-api",
|
||||
pageTitle: "OpenAPI Actions",
|
||||
},
|
||||
|
||||
268
web/tests/e2e/admin/code-interpreter/code_interpreter.spec.ts
Normal file
268
web/tests/e2e/admin/code-interpreter/code_interpreter.spec.ts
Normal file
@@ -0,0 +1,268 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import type { Page } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
|
||||
const CODE_INTERPRETER_URL = "/admin/configuration/code-interpreter";
|
||||
const API_STATUS_URL = "**/api/admin/code-interpreter";
|
||||
const API_HEALTH_URL = "**/api/admin/code-interpreter/health";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Intercept the status (GET /) and health (GET /health) endpoints with the
|
||||
* given values so the page renders deterministically.
|
||||
*
|
||||
* Also handles PUT requests — by default they succeed (200). Pass
|
||||
* `putStatus` to simulate failures.
|
||||
*/
|
||||
async function mockCodeInterpreterApi(
|
||||
page: Page,
|
||||
opts: { enabled: boolean; healthy: boolean; putStatus?: number }
|
||||
) {
|
||||
const putStatus = opts.putStatus ?? 200;
|
||||
|
||||
await page.route(API_HEALTH_URL, async (route) => {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ healthy: opts.healthy }),
|
||||
});
|
||||
});
|
||||
|
||||
await page.route(API_STATUS_URL, async (route) => {
|
||||
if (route.request().method() === "PUT") {
|
||||
await route.fulfill({
|
||||
status: putStatus,
|
||||
contentType: "application/json",
|
||||
body:
|
||||
putStatus >= 400
|
||||
? JSON.stringify({ detail: "Server Error" })
|
||||
: JSON.stringify(null),
|
||||
});
|
||||
} else {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ enabled: opts.enabled }),
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* The disconnect icon button is an icon-only opal Button whose tooltip text
|
||||
* is not exposed as an accessible name. Locate it by finding the first
|
||||
* icon-only button (no label span) inside the card area.
|
||||
*/
|
||||
function getDisconnectIconButton(page: Page) {
|
||||
return page
|
||||
.locator("button:has(.opal-button):not(:has(.opal-button-label))")
|
||||
.first();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
test.describe("Code Interpreter Admin Page", () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin");
|
||||
});
|
||||
|
||||
test("page loads with header and description", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
|
||||
/^Code Interpreter/,
|
||||
{ timeout: 10000 }
|
||||
);
|
||||
|
||||
await expect(page.getByText("Built-in Python runtime")).toBeVisible();
|
||||
});
|
||||
|
||||
test("shows Connected status when enabled and healthy", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
|
||||
});
|
||||
|
||||
test("shows Connection Lost when enabled but unhealthy", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: false });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connection Lost")).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
});
|
||||
|
||||
test("shows Reconnect button when disabled", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: false, healthy: false });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
await expect(page.getByText("(Disconnected)")).toBeVisible();
|
||||
});
|
||||
|
||||
test("disconnect flow opens modal and sends PUT request", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// Click the disconnect icon button
|
||||
await getDisconnectIconButton(page).click();
|
||||
|
||||
// Modal should appear
|
||||
await expect(page.getByText("Disconnect Code Interpreter")).toBeVisible();
|
||||
await expect(
|
||||
page.getByText("All running sessions connected to")
|
||||
).toBeVisible();
|
||||
|
||||
// Click the danger Disconnect button in the modal
|
||||
const modal = page.getByRole("dialog");
|
||||
await modal.getByRole("button", { name: "Disconnect" }).click();
|
||||
|
||||
// Modal should close after successful disconnect
|
||||
await expect(page.getByText("Disconnect Code Interpreter")).not.toBeVisible(
|
||||
{ timeout: 5000 }
|
||||
);
|
||||
});
|
||||
|
||||
test("disconnect modal can be closed without disconnecting", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// Open modal
|
||||
await getDisconnectIconButton(page).click();
|
||||
await expect(page.getByText("Disconnect Code Interpreter")).toBeVisible();
|
||||
|
||||
// Close modal via Cancel button
|
||||
const modal = page.getByRole("dialog");
|
||||
await modal.getByRole("button", { name: "Cancel" }).click();
|
||||
|
||||
// Modal should be gone, page still shows Connected
|
||||
await expect(
|
||||
page.getByText("Disconnect Code Interpreter")
|
||||
).not.toBeVisible();
|
||||
await expect(page.getByText("Connected")).toBeVisible();
|
||||
});
|
||||
|
||||
test("reconnect flow sends PUT with enabled=true", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: false, healthy: false });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
// Intercept the PUT and verify the payload
|
||||
const putPromise = page.waitForRequest(
|
||||
(req) =>
|
||||
req.url().includes("/api/admin/code-interpreter") &&
|
||||
req.method() === "PUT"
|
||||
);
|
||||
|
||||
await page.getByRole("button", { name: "Reconnect" }).click();
|
||||
|
||||
const putReq = await putPromise;
|
||||
expect(putReq.postDataJSON()).toEqual({ enabled: true });
|
||||
});
|
||||
|
||||
test("shows Checking... while reconnect is in progress", async ({ page }) => {
|
||||
// Use a single route handler that delays PUT responses
|
||||
await page.route(API_HEALTH_URL, async (route) => {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ healthy: false }),
|
||||
});
|
||||
});
|
||||
|
||||
await page.route(API_STATUS_URL, async (route) => {
|
||||
if (route.request().method() === "PUT") {
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000));
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify(null),
|
||||
});
|
||||
} else {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ enabled: false }),
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
await page.getByRole("button", { name: "Reconnect" }).click();
|
||||
|
||||
// Should show Checking... while the request is in flight
|
||||
await expect(page.getByText("Checking...")).toBeVisible({ timeout: 3000 });
|
||||
});
|
||||
|
||||
test("shows error toast when disconnect fails", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, {
|
||||
enabled: true,
|
||||
healthy: true,
|
||||
putStatus: 500,
|
||||
});
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// Open modal and click disconnect
|
||||
await getDisconnectIconButton(page).click();
|
||||
const modal = page.getByRole("dialog");
|
||||
await modal.getByRole("button", { name: "Disconnect" }).click();
|
||||
|
||||
// Error toast should appear
|
||||
await expect(
|
||||
page.getByText("Failed to disconnect Code Interpreter")
|
||||
).toBeVisible({ timeout: 5000 });
|
||||
});
|
||||
|
||||
test("shows error toast when reconnect fails", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, {
|
||||
enabled: false,
|
||||
healthy: false,
|
||||
putStatus: 500,
|
||||
});
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
await page.getByRole("button", { name: "Reconnect" }).click();
|
||||
|
||||
// Error toast should appear
|
||||
await expect(
|
||||
page.getByText("Failed to reconnect Code Interpreter")
|
||||
).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Reconnect button should reappear (not stuck in Checking...)
|
||||
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -46,7 +46,7 @@ test.skip("User changes password and logs in with new password", async ({
|
||||
|
||||
// Verify successful login
|
||||
await expect(page).toHaveURL("http://localhost:3000/app");
|
||||
await expect(page.getByText("Explore Assistants")).toBeVisible();
|
||||
await expect(page.getByText("Explore Agents")).toBeVisible();
|
||||
});
|
||||
|
||||
test.use({ storageState: "admin2_auth.json" });
|
||||
@@ -115,5 +115,5 @@ test.skip("Admin resets own password and logs in with new password", async ({
|
||||
|
||||
// Verify successful login
|
||||
await expect(page).toHaveURL("http://localhost:3000/app");
|
||||
await expect(page.getByText("Explore Assistants")).toBeVisible();
|
||||
await expect(page.getByText("Explore Agents")).toBeVisible();
|
||||
});
|
||||
|
||||
@@ -16,7 +16,7 @@ import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
|
||||
|
||||
// Tool-related test selectors now imported from shared utils
|
||||
|
||||
test.describe("Default Assistant Tests", () => {
|
||||
test.describe("Default Agent Tests", () => {
|
||||
let imageGenConfigId: string | null = null;
|
||||
|
||||
test.beforeAll(async ({ browser }) => {
|
||||
@@ -69,7 +69,7 @@ test.describe("Default Assistant Tests", () => {
|
||||
});
|
||||
|
||||
test.describe("Greeting Message Display", () => {
|
||||
test("should display greeting message when opening new chat with default assistant", async ({
|
||||
test("should display greeting message when opening new chat with default agent", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Look for greeting message - should be one from the predefined list
|
||||
@@ -95,23 +95,21 @@ test.describe("Default Assistant Tests", () => {
|
||||
expect(GREETING_MESSAGES).toContain(greetingAfterReload?.trim());
|
||||
});
|
||||
|
||||
test("greeting should only appear for default assistant", async ({
|
||||
page,
|
||||
}) => {
|
||||
// First verify greeting appears for default assistant
|
||||
test("greeting should only appear for default agent", async ({ page }) => {
|
||||
// First verify greeting appears for default agent
|
||||
const greetingElement = await page.waitForSelector(
|
||||
'[data-testid="onyx-logo"]',
|
||||
{ timeout: 5000 }
|
||||
);
|
||||
expect(greetingElement).toBeTruthy();
|
||||
|
||||
// Create a custom assistant to test non-default behavior
|
||||
// Create a custom agent to test non-default behavior
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
await page.locator('input[name="name"]').fill("Custom Test Assistant");
|
||||
await page.locator('input[name="name"]').fill("Custom Test Agent");
|
||||
await page
|
||||
.locator('textarea[name="description"]')
|
||||
.fill("Test Description");
|
||||
@@ -120,17 +118,17 @@ test.describe("Default Assistant Tests", () => {
|
||||
.fill("Test Instructions");
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
|
||||
// Wait for assistant to be created and selected
|
||||
await verifyAssistantIsChosen(page, "Custom Test Assistant");
|
||||
// Wait for agent to be created and selected
|
||||
await verifyAssistantIsChosen(page, "Custom Test Agent");
|
||||
|
||||
// Greeting should NOT appear for custom assistant
|
||||
// Greeting should NOT appear for custom agent
|
||||
const customGreeting = await page.$('[data-testid="onyx-logo"]');
|
||||
expect(customGreeting).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
test.describe("Default Assistant Branding", () => {
|
||||
test("should display Onyx logo for default assistant", async ({ page }) => {
|
||||
test.describe("Default Agent Branding", () => {
|
||||
test("should display Onyx logo for default agent", async ({ page }) => {
|
||||
// Look for Onyx logo
|
||||
const logoElement = await page.waitForSelector(
|
||||
'[data-testid="onyx-logo"]',
|
||||
@@ -138,23 +136,23 @@ test.describe("Default Assistant Tests", () => {
|
||||
);
|
||||
expect(logoElement).toBeTruthy();
|
||||
|
||||
// Should NOT show assistant name for default assistant
|
||||
// Should NOT show agent name for default agent
|
||||
const assistantNameElement = await page.$(
|
||||
'[data-testid="assistant-name-display"]'
|
||||
);
|
||||
expect(assistantNameElement).toBeNull();
|
||||
});
|
||||
|
||||
test("custom assistants should show name and icon instead of logo", async ({
|
||||
test("custom agents should show name and icon instead of logo", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Create a custom assistant
|
||||
// Create a custom agent
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
await page.locator('input[name="name"]').fill("Custom Assistant");
|
||||
await page.locator('input[name="name"]').fill("Custom Agent");
|
||||
await page
|
||||
.locator('textarea[name="description"]')
|
||||
.fill("Test Description");
|
||||
@@ -163,16 +161,16 @@ test.describe("Default Assistant Tests", () => {
|
||||
.fill("Test Instructions");
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
|
||||
// Wait for assistant to be created and selected
|
||||
await verifyAssistantIsChosen(page, "Custom Assistant");
|
||||
// Wait for agent to be created and selected
|
||||
await verifyAssistantIsChosen(page, "Custom Agent");
|
||||
|
||||
// Should show assistant name and icon, not Onyx logo
|
||||
// Should show agent name and icon, not Onyx logo
|
||||
const assistantNameElement = await page.waitForSelector(
|
||||
'[data-testid="assistant-name-display"]',
|
||||
{ timeout: 5000 }
|
||||
);
|
||||
const nameText = await assistantNameElement.textContent();
|
||||
expect(nameText).toContain("Custom Assistant");
|
||||
expect(nameText).toContain("Custom Agent");
|
||||
|
||||
// Onyx logo should NOT be shown
|
||||
const logoElement = await page.$('[data-testid="onyx-logo"]');
|
||||
@@ -181,10 +179,8 @@ test.describe("Default Assistant Tests", () => {
|
||||
});
|
||||
|
||||
test.describe("Starter Messages", () => {
|
||||
test("default assistant should NOT have starter messages", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Check that starter messages container does not exist for default assistant
|
||||
test("default agent should NOT have starter messages", async ({ page }) => {
|
||||
// Check that starter messages container does not exist for default agent
|
||||
const starterMessagesContainer = await page.$(
|
||||
'[data-testid="starter-messages"]'
|
||||
);
|
||||
@@ -195,18 +191,14 @@ test.describe("Default Assistant Tests", () => {
|
||||
expect(starterButtons.length).toBe(0);
|
||||
});
|
||||
|
||||
test("custom assistants should display starter messages", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Create a custom assistant with starter messages
|
||||
test("custom agents should display starter messages", async ({ page }) => {
|
||||
// Create a custom agent with starter messages
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.fill("Test Assistant with Starters");
|
||||
await page.locator('input[name="name"]').fill("Test Agent with Starters");
|
||||
await page
|
||||
.locator('textarea[name="description"]')
|
||||
.fill("Test Description");
|
||||
@@ -219,9 +211,9 @@ test.describe("Default Assistant Tests", () => {
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
|
||||
// Wait for assistant to be created and selected
|
||||
await verifyAssistantIsChosen(page, "Test Assistant with Starters");
|
||||
await verifyAssistantIsChosen(page, "Test Agent with Starters");
|
||||
|
||||
// Starter messages container might exist but be empty for custom assistants
|
||||
// Starter messages container might exist but be empty for custom agents
|
||||
const starterMessagesContainer = await page.$(
|
||||
'[data-testid="starter-messages"]'
|
||||
);
|
||||
@@ -230,24 +222,22 @@ test.describe("Default Assistant Tests", () => {
|
||||
const starterButtons = await page.$$(
|
||||
'[data-testid^="starter-message-"]'
|
||||
);
|
||||
// Custom assistant without configured starter messages should have none
|
||||
// Custom agent without configured starter messages should have none
|
||||
expect(starterButtons.length).toBe(0);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
test.describe("Assistant Selection", () => {
|
||||
test("default assistant should be selected for new chats", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Verify the input placeholder indicates default assistant (Onyx)
|
||||
test.describe("Agent Selection", () => {
|
||||
test("default agent should be selected for new chats", async ({ page }) => {
|
||||
// Verify the input placeholder indicates default agent (Onyx)
|
||||
await verifyDefaultAssistantIsChosen(page);
|
||||
});
|
||||
|
||||
test("default assistant should NOT appear in assistant selector", async ({
|
||||
test("default agent should NOT appear in agent selector", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Open assistant selector
|
||||
// Open agent selector
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
|
||||
// Wait for modal or assistant list to appear
|
||||
@@ -256,13 +246,13 @@ test.describe("Default Assistant Tests", () => {
|
||||
.getByLabel("AgentsPage/new-agent-button")
|
||||
.waitFor({ state: "visible", timeout: 5000 });
|
||||
|
||||
// Look for default assistant by name - it should NOT be there
|
||||
// Look for default agent by name - it should NOT be there
|
||||
const assistantElements = await page.$$('[data-testid^="assistant-"]');
|
||||
const assistantTexts = await Promise.all(
|
||||
assistantElements.map((el) => el.textContent())
|
||||
);
|
||||
|
||||
// Check that "Assistant" (the default assistant name) is not in the list
|
||||
// Check that the default agent is not in the list
|
||||
const hasDefaultAssistant = assistantTexts.some(
|
||||
(text) =>
|
||||
text?.includes("Assistant") &&
|
||||
@@ -275,16 +265,16 @@ test.describe("Default Assistant Tests", () => {
|
||||
await page.keyboard.press("Escape");
|
||||
});
|
||||
|
||||
test("should be able to switch from default to custom assistant", async ({
|
||||
test("should be able to switch from default to custom agent", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Create a custom assistant
|
||||
// Create a custom agent
|
||||
await page.getByTestId("AppSidebar/more-agents").click();
|
||||
await page.getByLabel("AgentsPage/new-agent-button").click();
|
||||
await page
|
||||
.locator('input[name="name"]')
|
||||
.waitFor({ state: "visible", timeout: 10000 });
|
||||
await page.locator('input[name="name"]').fill("Switch Test Assistant");
|
||||
await page.locator('input[name="name"]').fill("Switch Test Agent");
|
||||
await page
|
||||
.locator('textarea[name="description"]')
|
||||
.fill("Test Description");
|
||||
@@ -293,13 +283,13 @@ test.describe("Default Assistant Tests", () => {
|
||||
.fill("Test Instructions");
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
|
||||
// Verify switched to custom assistant
|
||||
await verifyAssistantIsChosen(page, "Switch Test Assistant");
|
||||
// Verify switched to custom agent
|
||||
await verifyAssistantIsChosen(page, "Switch Test Agent");
|
||||
|
||||
// Start new chat to go back to default
|
||||
await startNewChat(page);
|
||||
|
||||
// Should be back to default assistant
|
||||
// Should be back to default agent
|
||||
await verifyDefaultAssistantIsChosen(page);
|
||||
});
|
||||
});
|
||||
@@ -379,7 +369,7 @@ test.describe("Default Assistant Tests", () => {
|
||||
);
|
||||
}
|
||||
|
||||
// Enable the tools in default assistant config via API
|
||||
// Enable the tools in default agent config via API
|
||||
// Get current tools to find their IDs
|
||||
const toolsListResp = await page.request.get(
|
||||
"http://localhost:3000/api/tool"
|
||||
@@ -542,7 +532,7 @@ test.describe("Default Assistant Tests", () => {
|
||||
});
|
||||
});
|
||||
|
||||
test.describe("End-to-End Default Assistant Flow", () => {
|
||||
test.describe("End-to-End Default Agent Flow", () => {
|
||||
let imageGenConfigId: string | null = null;
|
||||
|
||||
test.beforeAll(async ({ browser }) => {
|
||||
@@ -584,7 +574,7 @@ test.describe("End-to-End Default Assistant Flow", () => {
|
||||
}
|
||||
});
|
||||
|
||||
test("complete user journey with default assistant", async ({ page }) => {
|
||||
test("complete user journey with default agent", async ({ page }) => {
|
||||
// Clear cookies and log in as a random user
|
||||
await page.context().clearCookies();
|
||||
await loginAsRandomUser(page);
|
||||
@@ -611,7 +601,7 @@ test.describe("End-to-End Default Assistant Flow", () => {
|
||||
// Start a new chat
|
||||
await startNewChat(page);
|
||||
|
||||
// Verify we're back to default assistant with greeting
|
||||
// Verify we're back to default agent with greeting
|
||||
await expect(page.locator('[data-testid="onyx-logo"]')).toBeVisible();
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user