mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-19 06:32:42 +00:00
Compare commits
29 Commits
dane/index
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
710b39074f | ||
|
|
8fe2f67d38 | ||
|
|
f00aaf9fc0 | ||
|
|
5b2426b002 | ||
|
|
ba6ab0245b | ||
|
|
b64ebb57e1 | ||
|
|
2fcfdbabde | ||
|
|
ea1a2749c1 | ||
|
|
73c4e22588 | ||
|
|
fceaac6e13 | ||
|
|
e8bf45cfd2 | ||
|
|
13ff648fcd | ||
|
|
ae8268afb1 | ||
|
|
b338bd9e97 | ||
|
|
0dcc90a042 | ||
|
|
0f6a6693d3 | ||
|
|
e32cc450b2 | ||
|
|
732fb71edf | ||
|
|
ca3320c0e0 | ||
|
|
d7c554aca7 | ||
|
|
69e5c19695 | ||
|
|
b4ce1c7a97 | ||
|
|
cd64a91154 | ||
|
|
c282cdc096 | ||
|
|
b1de1c59b6 | ||
|
|
64d484039f | ||
|
|
0530095b71 | ||
|
|
23280d5b91 | ||
|
|
229442679c |
@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.hooks",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
@@ -361,6 +362,19 @@ if not MULTI_TENANT:
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
if not MULTI_TENANT and HOOK_ENABLED:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
|
||||
"schedule": timedelta(days=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def generate_cloud_tasks(
|
||||
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
|
||||
|
||||
35
backend/onyx/background/celery/tasks/hooks/tasks.py
Normal file
35
backend/onyx/background/celery/tasks/hooks/tasks.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from celery import shared_task
|
||||
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.hook import cleanup_old_execution_logs__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_HOOK_EXECUTION_LOG_RETENTION_DAYS: int = 30
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
)
|
||||
def hook_execution_log_cleanup_task(*, tenant_id: str) -> None: # noqa: ARG001
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
deleted: int = cleanup_old_execution_logs__no_commit(
|
||||
db_session=db_session,
|
||||
max_age_days=_HOOK_EXECUTION_LOG_RETENTION_DAYS,
|
||||
)
|
||||
db_session.commit()
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"Deleted {deleted} hook execution log(s) older than "
|
||||
f"{_HOOK_EXECUTION_LOG_RETENTION_DAYS} days."
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to clean up hook execution logs")
|
||||
raise
|
||||
@@ -24,6 +24,7 @@ from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_DELETE_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
@@ -33,6 +34,7 @@ from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_DELETE_MAX_QUEUE_DEPTH
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
@@ -91,6 +93,17 @@ def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a delete_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_DELETE_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_DELETE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
|
||||
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
|
||||
return celery_get_queue_length(
|
||||
@@ -546,7 +559,23 @@ def process_single_user_file(
|
||||
ignore_result=True,
|
||||
)
|
||||
def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with DELETING status and enqueue per-file tasks."""
|
||||
"""Scan for user files with DELETING status and enqueue per-file tasks.
|
||||
|
||||
Three mechanisms prevent queue runaway (mirrors check_user_file_processing):
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_DELETE_MAX_QUEUE_DEPTH items we skip this beat cycle entirely.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_DELETE_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still DELETING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_DELETE_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
"""
|
||||
task_logger.info("check_for_user_file_delete - Starting")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
@@ -555,8 +584,23 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
if not lock.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
# NOTE: must use the broker's Redis client (not redis_client) because
|
||||
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
|
||||
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_for_user_file_delete - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_DELETE_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -568,23 +612,40 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
.all()
|
||||
)
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_delete_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_DELETE_TASK_EXPIRES,
|
||||
nx=True,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_DELETE_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"check_for_user_file_delete - Error enqueuing deletes - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_for_user_file_delete - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"check_for_user_file_delete - Enqueued {enqueued} tasks, skipped_guard={skipped_guard} for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -602,6 +663,9 @@ def delete_user_file_impl(
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
# Clear the queued guard so the beat can re-enqueue if deletion fails
|
||||
# and the file remains in DELETING status.
|
||||
redis_client.delete(_user_file_delete_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
|
||||
4
backend/onyx/cache/postgres_backend.py
vendored
4
backend/onyx/cache/postgres_backend.py
vendored
@@ -297,7 +297,9 @@ class PostgresCacheBackend(CacheBackend):
|
||||
|
||||
def _lock_id_for(self, name: str) -> int:
|
||||
"""Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
|
||||
h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
|
||||
h = hashlib.md5(
|
||||
f"{self._tenant_id}:{name}".encode(), usedforsecurity=False
|
||||
).digest()
|
||||
return struct.unpack("q", h[:8])[0]
|
||||
|
||||
|
||||
|
||||
@@ -278,14 +278,17 @@ USING_AWS_MANAGED_OPENSEARCH = (
|
||||
OPENSEARCH_PROFILING_DISABLED = (
|
||||
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Whether to disable match highlights for OpenSearch. Defaults to True for now
|
||||
# as we investigate query performance.
|
||||
OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED = (
|
||||
os.environ.get("OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED", "true").lower() == "true"
|
||||
)
|
||||
# When enabled, OpenSearch returns detailed score breakdowns for each hit.
|
||||
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
|
||||
# Seems for Hybrid Search in practice, the impact is actually more like 1000x slower.
|
||||
OPENSEARCH_EXPLAIN_ENABLED = (
|
||||
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
|
||||
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
|
||||
# existing indices need reindexing after a change.
|
||||
@@ -318,8 +321,16 @@ VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE = int(
|
||||
os.environ.get("OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE") or 500
|
||||
)
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = int(
|
||||
os.environ.get("OPENSEARCH_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES") or 0
|
||||
# If set, will override the default number of shards and replicas for the index.
|
||||
OPENSEARCH_INDEX_NUM_SHARDS: int | None = (
|
||||
int(os.environ["OPENSEARCH_INDEX_NUM_SHARDS"])
|
||||
if os.environ.get("OPENSEARCH_INDEX_NUM_SHARDS", None) is not None
|
||||
else None
|
||||
)
|
||||
OPENSEARCH_INDEX_NUM_REPLICAS: int | None = (
|
||||
int(os.environ["OPENSEARCH_INDEX_NUM_REPLICAS"])
|
||||
if os.environ.get("OPENSEARCH_INDEX_NUM_REPLICAS", None) is not None
|
||||
else None
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
|
||||
@@ -177,6 +177,14 @@ USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file-delete task is valid before workers discard it.
|
||||
# Mirrors the processing task expiry to prevent indefinite queue growth when
|
||||
# files are stuck in DELETING status and the beat keeps re-enqueuing them.
|
||||
CELERY_USER_FILE_DELETE_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Max queue depth before the delete beat stops enqueuing more delete tasks.
|
||||
USER_FILE_DELETE_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
@@ -469,6 +477,9 @@ class OnyxRedisLocks:
|
||||
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
|
||||
# Short-lived key set when a delete task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a delete task is already queued.
|
||||
USER_FILE_DELETE_QUEUED_PREFIX = "da_lock:user_file_delete_queued"
|
||||
|
||||
# Release notes
|
||||
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
|
||||
@@ -597,6 +608,9 @@ class OnyxCeleryTask:
|
||||
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
|
||||
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
|
||||
|
||||
# Hook execution log retention
|
||||
HOOK_EXECUTION_LOG_CLEANUP_TASK = "hook_execution_log_cleanup_task"
|
||||
|
||||
# Sandbox cleanup
|
||||
CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes"
|
||||
CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots"
|
||||
|
||||
233
backend/onyx/db/hook.py
Normal file
233
backend/onyx/db/hook.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.models import Hook
|
||||
from onyx.db.models import HookExecutionLog
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
# ── Hook CRUD ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_hook_by_id(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
include_deleted: bool = False,
|
||||
include_creator: bool = False,
|
||||
) -> Hook | None:
|
||||
stmt = select(Hook).where(Hook.id == hook_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Hook.deleted.is_(False))
|
||||
if include_creator:
|
||||
stmt = stmt.options(selectinload(Hook.creator))
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def get_non_deleted_hook_by_hook_point(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
include_creator: bool = False,
|
||||
) -> Hook | None:
|
||||
stmt = (
|
||||
select(Hook).where(Hook.hook_point == hook_point).where(Hook.deleted.is_(False))
|
||||
)
|
||||
if include_creator:
|
||||
stmt = stmt.options(selectinload(Hook.creator))
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def get_hooks(
|
||||
*,
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
include_creator: bool = False,
|
||||
) -> list[Hook]:
|
||||
stmt = select(Hook)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Hook.deleted.is_(False))
|
||||
if include_creator:
|
||||
stmt = stmt.options(selectinload(Hook.creator))
|
||||
stmt = stmt.order_by(Hook.hook_point, Hook.created_at.desc())
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def create_hook__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
name: str,
|
||||
hook_point: HookPoint,
|
||||
endpoint_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
fail_strategy: HookFailStrategy,
|
||||
timeout_seconds: float,
|
||||
is_active: bool = False,
|
||||
creator_id: UUID | None = None,
|
||||
) -> Hook:
|
||||
"""Create a new hook for the given hook point.
|
||||
|
||||
At most one non-deleted hook per hook point is allowed. Raises
|
||||
OnyxError(CONFLICT) if a hook already exists, including under concurrent
|
||||
duplicate creates where the partial unique index fires an IntegrityError.
|
||||
"""
|
||||
existing = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if existing:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONFLICT,
|
||||
f"A hook for '{hook_point.value}' already exists (id={existing.id}).",
|
||||
)
|
||||
|
||||
hook = Hook(
|
||||
name=name,
|
||||
hook_point=hook_point,
|
||||
endpoint_url=endpoint_url,
|
||||
api_key=api_key,
|
||||
fail_strategy=fail_strategy,
|
||||
timeout_seconds=timeout_seconds,
|
||||
is_active=is_active,
|
||||
creator_id=creator_id,
|
||||
)
|
||||
# Use a savepoint so that a failed insert only rolls back this operation,
|
||||
# not the entire outer transaction.
|
||||
savepoint = db_session.begin_nested()
|
||||
try:
|
||||
db_session.add(hook)
|
||||
savepoint.commit()
|
||||
except IntegrityError as exc:
|
||||
savepoint.rollback()
|
||||
if "ix_hook_one_non_deleted_per_point" in str(exc.orig):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONFLICT,
|
||||
f"A hook for '{hook_point.value}' already exists.",
|
||||
)
|
||||
raise # re-raise unrelated integrity errors (FK violations, etc.)
|
||||
return hook
|
||||
|
||||
|
||||
def update_hook__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
name: str | None = None,
|
||||
endpoint_url: str | None | UnsetType = UNSET,
|
||||
api_key: str | None | UnsetType = UNSET,
|
||||
fail_strategy: HookFailStrategy | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
is_active: bool | None = None,
|
||||
is_reachable: bool | None = None,
|
||||
include_creator: bool = False,
|
||||
) -> Hook:
|
||||
"""Update hook fields.
|
||||
|
||||
Sentinel conventions:
|
||||
- endpoint_url, api_key: pass UNSET to leave unchanged; pass None to clear.
|
||||
- name, fail_strategy, timeout_seconds, is_active, is_reachable: pass None to leave unchanged.
|
||||
"""
|
||||
hook = get_hook_by_id(
|
||||
db_session=db_session, hook_id=hook_id, include_creator=include_creator
|
||||
)
|
||||
if hook is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
|
||||
|
||||
if name is not None:
|
||||
hook.name = name
|
||||
if not isinstance(endpoint_url, UnsetType):
|
||||
hook.endpoint_url = endpoint_url
|
||||
if not isinstance(api_key, UnsetType):
|
||||
hook.api_key = api_key # type: ignore[assignment] # EncryptedString coerces str → SensitiveValue at the ORM level
|
||||
if fail_strategy is not None:
|
||||
hook.fail_strategy = fail_strategy
|
||||
if timeout_seconds is not None:
|
||||
hook.timeout_seconds = timeout_seconds
|
||||
if is_active is not None:
|
||||
hook.is_active = is_active
|
||||
if is_reachable is not None:
|
||||
hook.is_reachable = is_reachable
|
||||
|
||||
db_session.flush()
|
||||
return hook
|
||||
|
||||
|
||||
def delete_hook__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
) -> None:
|
||||
hook = get_hook_by_id(db_session=db_session, hook_id=hook_id)
|
||||
if hook is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
|
||||
|
||||
hook.deleted = True
|
||||
hook.is_active = False
|
||||
db_session.flush()
|
||||
|
||||
|
||||
# ── HookExecutionLog CRUD ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def create_hook_execution_log__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
is_success: bool,
|
||||
error_message: str | None = None,
|
||||
status_code: int | None = None,
|
||||
duration_ms: int | None = None,
|
||||
) -> HookExecutionLog:
|
||||
log = HookExecutionLog(
|
||||
hook_id=hook_id,
|
||||
is_success=is_success,
|
||||
error_message=error_message,
|
||||
status_code=status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
db_session.add(log)
|
||||
db_session.flush()
|
||||
return log
|
||||
|
||||
|
||||
def get_hook_execution_logs(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
limit: int,
|
||||
) -> list[HookExecutionLog]:
|
||||
stmt = (
|
||||
select(HookExecutionLog)
|
||||
.where(HookExecutionLog.hook_id == hook_id)
|
||||
.order_by(HookExecutionLog.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def cleanup_old_execution_logs__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
max_age_days: int,
|
||||
) -> int:
|
||||
"""Delete execution logs older than max_age_days. Returns the number of rows deleted."""
|
||||
cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
days=max_age_days
|
||||
)
|
||||
result: CursorResult = db_session.execute( # type: ignore[assignment]
|
||||
delete(HookExecutionLog)
|
||||
.where(HookExecutionLog.created_at < cutoff)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
return result.rowcount
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -144,6 +145,7 @@ def upload_files_to_user_files_with_indexing(
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.configs.app_configs import OPENSEARCH_HOST
|
||||
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.search import DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -56,8 +57,8 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
|
||||
# Maps schema property name to a list of highlighted snippets with match
|
||||
# terms wrapped in tags (e.g. "something <hi>keyword</hi> other thing").
|
||||
match_highlights: dict[str, list[str]] = {}
|
||||
# Score explanation from OpenSearch when "explain": true is set in the query.
|
||||
# Contains detailed breakdown of how the score was calculated.
|
||||
# Score explanation from OpenSearch when "explain": true is set in the
|
||||
# query. Contains detailed breakdown of how the score was calculated.
|
||||
explanation: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@@ -833,9 +834,13 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search(
|
||||
self, body: dict[str, Any], search_pipeline_id: str | None
|
||||
) -> list[SearchHit[DocumentChunk]]:
|
||||
) -> list[SearchHit[DocumentChunkWithoutVectors]]:
|
||||
"""Searches the index.
|
||||
|
||||
NOTE: Does not return vector fields. In order to take advantage of
|
||||
performance benefits, the search body should exclude the schema's vector
|
||||
fields.
|
||||
|
||||
TODO(andrei): Ideally we could check that every field in the body is
|
||||
present in the index, to avoid a class of runtime bugs that could easily
|
||||
be caught during development. Or change the function signature to accept
|
||||
@@ -883,7 +888,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
raise_on_timeout=True,
|
||||
)
|
||||
|
||||
search_hits: list[SearchHit[DocumentChunk]] = []
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = []
|
||||
for hit in hits:
|
||||
document_chunk_source: dict[str, Any] | None = hit.get("_source")
|
||||
if not document_chunk_source:
|
||||
@@ -893,8 +898,10 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
document_chunk_score = hit.get("_score", None)
|
||||
match_highlights: dict[str, list[str]] = hit.get("highlight", {})
|
||||
explanation: dict[str, Any] | None = hit.get("_explanation", None)
|
||||
search_hit = SearchHit[DocumentChunk](
|
||||
document_chunk=DocumentChunk.model_validate(document_chunk_source),
|
||||
search_hit = SearchHit[DocumentChunkWithoutVectors](
|
||||
document_chunk=DocumentChunkWithoutVectors.model_validate(
|
||||
document_chunk_source
|
||||
),
|
||||
score=document_chunk_score,
|
||||
match_highlights=match_highlights,
|
||||
explanation=explanation,
|
||||
|
||||
@@ -3,10 +3,6 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
from onyx.configs.app_configs import (
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_MAX_CHUNK_SIZE = 512
|
||||
|
||||
@@ -41,10 +37,10 @@ M = 32 # Set relatively high for better accuracy.
|
||||
# we have a much higher chance of all 10 of the final desired docs showing up
|
||||
# and getting scored. In worse situations, the final 10 docs don't even show up
|
||||
# as the final 10 (worse than just a miss at the reranking step).
|
||||
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
if OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES > 0
|
||||
else 750
|
||||
# Defaults to 100 for now. Initially this defaulted to 750 but we were seeing
|
||||
# poor search performance.
|
||||
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
|
||||
os.environ.get("DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES", 100)
|
||||
)
|
||||
|
||||
# Number of vectors to examine to decide the top k neighbors for the HNSW
|
||||
@@ -54,7 +50,7 @@ DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
|
||||
# larger than k, you can provide the size parameter to limit the final number of
|
||||
# results to k." from
|
||||
# https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search
|
||||
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
EF_SEARCH = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
|
||||
|
||||
|
||||
class HybridSearchSubqueryConfiguration(Enum):
|
||||
|
||||
@@ -47,6 +47,7 @@ from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
|
||||
from onyx.document_index.opensearch.schema import DocumentSchema
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
|
||||
@@ -117,7 +118,7 @@ def set_cluster_state(client: OpenSearchClient) -> None:
|
||||
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
chunk: DocumentChunkWithoutVectors,
|
||||
score: float | None,
|
||||
highlights: dict[str, list[str]],
|
||||
) -> InferenceChunkUncleaned:
|
||||
@@ -880,7 +881,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
results: list[InferenceChunk] = []
|
||||
for chunk_request in chunk_requests:
|
||||
search_hits: list[SearchHit[DocumentChunk]] = []
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = []
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id=chunk_request.document_id,
|
||||
tenant_state=self._tenant_state,
|
||||
@@ -944,7 +945,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
include_hidden=False,
|
||||
)
|
||||
normalization_pipeline_name, _ = get_normalization_pipeline_name_and_config()
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=normalization_pipeline_name,
|
||||
)
|
||||
@@ -976,7 +977,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
num_to_retrieve=num_to_retrieve,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
|
||||
@@ -11,6 +11,8 @@ from pydantic import model_serializer
|
||||
from pydantic import model_validator
|
||||
from pydantic import SerializerFunctionWrapHandler
|
||||
|
||||
from onyx.configs.app_configs import OPENSEARCH_INDEX_NUM_REPLICAS
|
||||
from onyx.configs.app_configs import OPENSEARCH_INDEX_NUM_SHARDS
|
||||
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
@@ -100,9 +102,9 @@ def set_or_convert_timezone_to_utc(value: datetime) -> datetime:
|
||||
return value
|
||||
|
||||
|
||||
class DocumentChunk(BaseModel):
|
||||
class DocumentChunkWithoutVectors(BaseModel):
|
||||
"""
|
||||
Represents a chunk of a document in the OpenSearch index.
|
||||
Represents a chunk of a document in the OpenSearch index without vectors.
|
||||
|
||||
The names of these fields are based on the OpenSearch schema. Changes to the
|
||||
schema require changes here. See get_document_schema.
|
||||
@@ -124,9 +126,7 @@ class DocumentChunk(BaseModel):
|
||||
|
||||
# Either both should be None or both should be non-None.
|
||||
title: str | None = None
|
||||
title_vector: list[float] | None = None
|
||||
content: str
|
||||
content_vector: list[float]
|
||||
|
||||
source_type: str
|
||||
# A list of key-value pairs separated by INDEX_SEPARATOR. See
|
||||
@@ -176,19 +176,9 @@ class DocumentChunk(BaseModel):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"DocumentChunk(document_id={self.document_id}, chunk_index={self.chunk_index}, "
|
||||
f"content length={len(self.content)}, content vector length={len(self.content_vector)}, "
|
||||
f"tenant_id={self.tenant_id.tenant_id})"
|
||||
f"content length={len(self.content)}, tenant_id={self.tenant_id.tenant_id})."
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_title_and_title_vector_are_consistent(self) -> Self:
|
||||
# title and title_vector should both either be None or not.
|
||||
if self.title is not None and self.title_vector is None:
|
||||
raise ValueError("Bug: Title vector must not be None if title is not None.")
|
||||
if self.title_vector is not None and self.title is None:
|
||||
raise ValueError("Bug: Title must not be None if title vector is not None.")
|
||||
return self
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize_model(
|
||||
self, handler: SerializerFunctionWrapHandler
|
||||
@@ -305,6 +295,35 @@ class DocumentChunk(BaseModel):
|
||||
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
|
||||
|
||||
|
||||
class DocumentChunk(DocumentChunkWithoutVectors):
|
||||
"""Represents a chunk of a document in the OpenSearch index.
|
||||
|
||||
The names of these fields are based on the OpenSearch schema. Changes to the
|
||||
schema require changes here. See get_document_schema.
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
|
||||
title_vector: list[float] | None = None
|
||||
content_vector: list[float]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"DocumentChunk(document_id={self.document_id}, chunk_index={self.chunk_index}, "
|
||||
f"content length={len(self.content)}, content vector length={len(self.content_vector)}, "
|
||||
f"tenant_id={self.tenant_id.tenant_id})"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_title_and_title_vector_are_consistent(self) -> Self:
|
||||
# title and title_vector should both either be None or not.
|
||||
if self.title is not None and self.title_vector is None:
|
||||
raise ValueError("Bug: Title vector must not be None if title is not None.")
|
||||
if self.title_vector is not None and self.title is None:
|
||||
raise ValueError("Bug: Title must not be None if title vector is not None.")
|
||||
return self
|
||||
|
||||
|
||||
class DocumentSchema:
|
||||
"""
|
||||
Represents the schema and indexing strategies of the OpenSearch index.
|
||||
@@ -516,78 +535,35 @@ class DocumentSchema:
|
||||
|
||||
return schema
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings() -> dict[str, Any]:
|
||||
"""
|
||||
Standard settings for reasonable local index and search performance.
|
||||
"""
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 1,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch.
|
||||
|
||||
Our AWS-managed OpenSearch cluster has 3 data nodes in 3 availability
|
||||
zones.
|
||||
- We use 3 shards to distribute load across all data nodes.
|
||||
- We use 2 replicas to ensure each shard has a copy in each
|
||||
availability zone. This is a hard requirement from AWS. The number
|
||||
of data copies, including the primary (not a replica) copy, must be
|
||||
divisible by the number of AZs.
|
||||
"""
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": 3,
|
||||
"number_of_replicas": 2,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch in multi-tenant cloud.
|
||||
|
||||
324 shards very roughly targets a storage load of ~30Gb per shard, which
|
||||
according to AWS OpenSearch documentation is within a good target range.
|
||||
|
||||
As documented above we need 2 replicas for a total of 3 copies of the
|
||||
data because the cluster is configured with 3-AZ awareness.
|
||||
"""
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": 324,
|
||||
"number_of_replicas": 2,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_based_on_environment() -> dict[str, Any]:
|
||||
"""
|
||||
Returns the index settings based on the environment.
|
||||
"""
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
# NOTE: The number of data copies, including the primary (not a
|
||||
# replica) copy, must be divisible by the number of AZs.
|
||||
if MULTI_TENANT:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()
|
||||
)
|
||||
number_of_shards = 324
|
||||
number_of_replicas = 2
|
||||
else:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
|
||||
)
|
||||
number_of_shards = 3
|
||||
number_of_replicas = 2
|
||||
else:
|
||||
return DocumentSchema.get_index_settings()
|
||||
number_of_shards = 1
|
||||
number_of_replicas = 1
|
||||
|
||||
if OPENSEARCH_INDEX_NUM_SHARDS is not None:
|
||||
number_of_shards = OPENSEARCH_INDEX_NUM_SHARDS
|
||||
if OPENSEARCH_INDEX_NUM_REPLICAS is not None:
|
||||
number_of_replicas = OPENSEARCH_INDEX_NUM_REPLICAS
|
||||
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": number_of_shards,
|
||||
"number_of_replicas": number_of_replicas,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ from uuid import UUID
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
|
||||
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
|
||||
from onyx.configs.app_configs import OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED
|
||||
from onyx.configs.app_configs import OPENSEARCH_PROFILING_DISABLED
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
@@ -15,7 +16,7 @@ from onyx.context.search.models import Tag
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import ASSUMED_DOCUMENT_AGE_DAYS
|
||||
from onyx.document_index.opensearch.constants import (
|
||||
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
|
||||
)
|
||||
from onyx.document_index.opensearch.constants import (
|
||||
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW,
|
||||
@@ -235,9 +236,17 @@ class DocumentQuery:
|
||||
# returning some number of results less than the index max allowed
|
||||
# return size.
|
||||
"size": DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW,
|
||||
"_source": get_full_document,
|
||||
# By default exclude retrieving the vector fields in order to save
|
||||
# on retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
}
|
||||
if not get_full_document:
|
||||
# If we explicitly do not want the underlying document, we will only
|
||||
# retrieve IDs.
|
||||
final_get_ids_query["_source"] = False
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_get_ids_query["profile"] = True
|
||||
|
||||
@@ -332,7 +341,7 @@ class DocumentQuery:
|
||||
|
||||
# TODO(andrei, yuhong): We can tune this more dynamically based on
|
||||
# num_hits.
|
||||
max_results_per_subquery = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
max_results_per_subquery = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
|
||||
|
||||
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
|
||||
query_text, query_vector, vector_candidates=max_results_per_subquery
|
||||
@@ -356,9 +365,6 @@ class DocumentQuery:
|
||||
attached_document_ids=index_filters.attached_document_ids,
|
||||
hierarchy_node_ids=index_filters.hierarchy_node_ids,
|
||||
)
|
||||
match_highlights_configuration = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
hybrid_search_query: dict[str, Any] = {
|
||||
@@ -385,10 +391,19 @@ class DocumentQuery:
|
||||
final_hybrid_search_body: dict[str, Any] = {
|
||||
"query": hybrid_search_query,
|
||||
"size": num_hits,
|
||||
"highlight": match_highlights_configuration,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
|
||||
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
|
||||
final_hybrid_search_body["highlight"] = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# Explain is for scoring breakdowns.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_hybrid_search_body["explain"] = True
|
||||
@@ -446,6 +461,11 @@ class DocumentQuery:
|
||||
},
|
||||
"size": num_to_retrieve,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_random_search_query["profile"] = True
|
||||
@@ -460,7 +480,7 @@ class DocumentQuery:
|
||||
# search. This is higher than the number of results because the scoring
|
||||
# is hybrid. For a detailed breakdown, see where the default value is
|
||||
# set.
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Returns subqueries for hybrid search.
|
||||
|
||||
@@ -546,7 +566,7 @@ class DocumentQuery:
|
||||
@staticmethod
|
||||
def _get_title_vector_similarity_search_query(
|
||||
query_vector: list[float],
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"knn": {
|
||||
@@ -560,7 +580,7 @@ class DocumentQuery:
|
||||
@staticmethod
|
||||
def _get_content_vector_similarity_search_query(
|
||||
query_vector: list[float],
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"knn": {
|
||||
|
||||
@@ -88,9 +88,13 @@ def summarize_image_with_error_handling(
|
||||
try:
|
||||
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
|
||||
except UnsupportedImageFormatError:
|
||||
magic_hex = image_data[:8].hex() if image_data else "empty"
|
||||
logger.info(
|
||||
"Skipping image summarization due to unsupported MIME type for %s",
|
||||
"Skipping image summarization due to unsupported MIME type "
|
||||
"for %s (magic_bytes=%s, size=%d bytes)",
|
||||
context_name,
|
||||
magic_hex,
|
||||
len(image_data),
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -134,9 +138,23 @@ def _summarize_image(
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Summarization failed. Messages: {messages}"
|
||||
error_msg = error_msg[:1024]
|
||||
raise ValueError(error_msg) from e
|
||||
# Extract structured details from LiteLLM exceptions when available,
|
||||
# rather than dumping the full messages payload (which contains base64
|
||||
# image data and produces enormous, unreadable error logs).
|
||||
str_e = str(e)
|
||||
if len(str_e) > 512:
|
||||
str_e = str_e[:512] + "... (truncated)"
|
||||
parts = [f"Summarization failed: {type(e).__name__}: {str_e}"]
|
||||
status_code = getattr(e, "status_code", None)
|
||||
llm_provider = getattr(e, "llm_provider", None)
|
||||
model = getattr(e, "model", None)
|
||||
if status_code is not None:
|
||||
parts.append(f"status_code={status_code}")
|
||||
if llm_provider is not None:
|
||||
parts.append(f"llm_provider={llm_provider}")
|
||||
if model is not None:
|
||||
parts.append(f"model={model}")
|
||||
raise ValueError(" | ".join(parts)) from e
|
||||
|
||||
|
||||
def _encode_image_for_llm_prompt(image_data: bytes) -> str:
|
||||
|
||||
127
backend/onyx/hooks/models.py
Normal file
127
backend/onyx/hooks/models.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
from pydantic import SecretStr
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
|
||||
NonEmptySecretStr = Annotated[SecretStr, Field(min_length=1)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookCreateRequest(BaseModel):
|
||||
name: str = Field(min_length=1)
|
||||
hook_point: HookPoint
|
||||
endpoint_url: str = Field(min_length=1)
|
||||
api_key: NonEmptySecretStr | None = None
|
||||
fail_strategy: HookFailStrategy | None = None # if None, uses HookPointSpec default
|
||||
timeout_seconds: float | None = Field(
|
||||
default=None, gt=0
|
||||
) # if None, uses HookPointSpec default
|
||||
|
||||
@field_validator("name", "endpoint_url")
|
||||
@classmethod
|
||||
def no_whitespace_only(cls, v: str) -> str:
|
||||
if not v.strip():
|
||||
raise ValueError("cannot be whitespace-only.")
|
||||
return v
|
||||
|
||||
|
||||
class HookUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
endpoint_url: str | None = None
|
||||
api_key: NonEmptySecretStr | None = None
|
||||
fail_strategy: HookFailStrategy | None = (
|
||||
None # if None in model_fields_set, reset to spec default
|
||||
)
|
||||
timeout_seconds: float | None = Field(
|
||||
default=None, gt=0
|
||||
) # if None in model_fields_set, reset to spec default
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_at_least_one_field(self) -> "HookUpdateRequest":
|
||||
if not self.model_fields_set:
|
||||
raise ValueError("At least one field must be provided for an update.")
|
||||
if "name" in self.model_fields_set and not (self.name or "").strip():
|
||||
raise ValueError("name cannot be cleared.")
|
||||
if (
|
||||
"endpoint_url" in self.model_fields_set
|
||||
and not (self.endpoint_url or "").strip()
|
||||
):
|
||||
raise ValueError("endpoint_url cannot be cleared.")
|
||||
return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookPointMetaResponse(BaseModel):
|
||||
hook_point: HookPoint
|
||||
display_name: str
|
||||
description: str
|
||||
docs_url: str | None
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
default_timeout_seconds: float
|
||||
default_fail_strategy: HookFailStrategy
|
||||
fail_hard_description: str
|
||||
|
||||
|
||||
class HookResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
hook_point: HookPoint
|
||||
# Nullable to match the DB column — endpoint_url is required on creation but
|
||||
# future hook point types may not use an external endpoint (e.g. built-in handlers).
|
||||
endpoint_url: str | None
|
||||
fail_strategy: HookFailStrategy
|
||||
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
|
||||
is_active: bool
|
||||
creator_email: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class HookValidateResponse(BaseModel):
|
||||
success: bool
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookHealthStatus(str, Enum):
|
||||
healthy = "healthy" # green — reachable, no failures in last 1h
|
||||
degraded = "degraded" # yellow — reachable, failures in last 1h
|
||||
unreachable = "unreachable" # red — is_reachable=false or null
|
||||
|
||||
|
||||
class HookFailureRecord(BaseModel):
|
||||
error_message: str | None = None
|
||||
status_code: int | None = None
|
||||
duration_ms: int | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class HookHealthResponse(BaseModel):
|
||||
status: HookHealthStatus
|
||||
recent_failures: list[HookFailureRecord] = Field(
|
||||
default_factory=list,
|
||||
description="Last 10 failures, newest first",
|
||||
max_length=10,
|
||||
)
|
||||
0
backend/onyx/hooks/points/__init__.py
Normal file
0
backend/onyx/hooks/points/__init__.py
Normal file
59
backend/onyx/hooks/points/base.py
Normal file
59
backend/onyx/hooks/points/base.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
|
||||
|
||||
_REQUIRED_ATTRS = (
|
||||
"hook_point",
|
||||
"display_name",
|
||||
"description",
|
||||
"default_timeout_seconds",
|
||||
"fail_hard_description",
|
||||
"default_fail_strategy",
|
||||
)
|
||||
|
||||
|
||||
class HookPointSpec(ABC):
|
||||
"""Static metadata and contract for a pipeline hook point.
|
||||
|
||||
This is NOT a regular class meant for direct instantiation by callers.
|
||||
Each concrete subclass represents exactly one hook point and is instantiated
|
||||
once at startup, registered in onyx.hooks.registry._REGISTRY. No caller
|
||||
should ever create instances directly — use get_hook_point_spec() or
|
||||
get_all_specs() from the registry instead.
|
||||
|
||||
Each hook point is a concrete subclass of this class. Onyx engineers
|
||||
own these definitions — customers never touch this code.
|
||||
|
||||
Subclasses must define all attributes as class-level constants.
|
||||
"""
|
||||
|
||||
hook_point: HookPoint
|
||||
display_name: str
|
||||
description: str
|
||||
default_timeout_seconds: float
|
||||
fail_hard_description: str
|
||||
default_fail_strategy: HookFailStrategy
|
||||
docs_url: str | None = None
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Skip intermediate abstract subclasses — they may still be partially defined.
|
||||
if getattr(cls, "__abstractmethods__", None):
|
||||
return
|
||||
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
|
||||
if missing:
|
||||
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the request payload sent to the customer's endpoint."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the expected response from the customer's endpoint."""
|
||||
29
backend/onyx/hooks/points/document_ingestion.py
Normal file
29
backend/onyx/hooks/points/document_ingestion.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
class DocumentIngestionSpec(HookPointSpec):
|
||||
"""Hook point that runs during document ingestion.
|
||||
|
||||
# TODO(@Bo-Onyx): define call site, input/output schema, and timeout budget.
|
||||
"""
|
||||
|
||||
hook_point = HookPoint.DOCUMENT_INGESTION
|
||||
display_name = "Document Ingestion"
|
||||
description = "Runs during document ingestion. Allows filtering or transforming documents before indexing."
|
||||
default_timeout_seconds = 30.0
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define input schema
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define output schema
|
||||
return {"type": "object", "properties": {}}
|
||||
83
backend/onyx/hooks/points/query_processing.py
Normal file
83
backend/onyx/hooks/points/query_processing.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
class QueryProcessingSpec(HookPointSpec):
|
||||
"""Hook point that runs on every user query before it enters the pipeline.
|
||||
|
||||
Call site: inside handle_stream_message_objects() in
|
||||
backend/onyx/chat/process_message.py, immediately after message_text is
|
||||
assigned from the request and before create_new_chat_message() saves it.
|
||||
|
||||
This is the earliest possible point in the query pipeline:
|
||||
- Raw query — unmodified, exactly as the user typed it
|
||||
- No side effects yet — message has not been saved to DB
|
||||
- User identity is available for user-specific logic
|
||||
|
||||
Supported use cases:
|
||||
- Query rejection: block queries based on content or user context
|
||||
- Query rewriting: normalize, expand, or modify the query
|
||||
- PII removal: scrub sensitive data before the LLM sees it
|
||||
- Access control: reject queries from certain users or groups
|
||||
- Query auditing: log or track queries based on business rules
|
||||
"""
|
||||
|
||||
hook_point = HookPoint.QUERY_PROCESSING
|
||||
display_name = "Query Processing"
|
||||
description = (
|
||||
"Runs on every user query before it enters the pipeline. "
|
||||
"Allows rewriting, filtering, or rejecting queries."
|
||||
)
|
||||
default_timeout_seconds = 5.0 # user is actively waiting — keep tight
|
||||
fail_hard_description = (
|
||||
"The query will be blocked and the user will see an error message."
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The raw query string exactly as the user typed it.",
|
||||
},
|
||||
"user_email": {
|
||||
"type": ["string", "null"],
|
||||
"description": "Email of the user submitting the query, or null if unauthenticated.",
|
||||
},
|
||||
"chat_session_id": {
|
||||
"type": "string",
|
||||
"description": "UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires.",
|
||||
},
|
||||
},
|
||||
"required": ["query", "user_email", "chat_session_id"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"The (optionally modified) query to use. "
|
||||
"Set to null to reject the query."
|
||||
),
|
||||
},
|
||||
"rejection_message": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"Message shown to the user when query is null. "
|
||||
"Falls back to a generic message if not provided."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
45
backend/onyx/hooks/registry.py
Normal file
45
backend/onyx/hooks/registry.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionSpec
|
||||
from onyx.hooks.points.query_processing import QueryProcessingSpec
|
||||
|
||||
# Internal: use `monkeypatch.setattr(registry_module, "_REGISTRY", {...})` to override in tests.
|
||||
_REGISTRY: dict[HookPoint, HookPointSpec] = {
|
||||
HookPoint.DOCUMENT_INGESTION: DocumentIngestionSpec(),
|
||||
HookPoint.QUERY_PROCESSING: QueryProcessingSpec(),
|
||||
}
|
||||
|
||||
|
||||
def validate_registry() -> None:
|
||||
"""Assert that every HookPoint enum value has a registered spec.
|
||||
|
||||
Call once at application startup (e.g. from the FastAPI lifespan hook).
|
||||
Raises RuntimeError if any hook point is missing a spec.
|
||||
"""
|
||||
missing = set(HookPoint) - set(_REGISTRY)
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"Hook point(s) have no registered spec: {missing}. "
|
||||
"Add an entry to onyx.hooks.registry._REGISTRY."
|
||||
)
|
||||
|
||||
|
||||
def get_hook_point_spec(hook_point: HookPoint) -> HookPointSpec:
|
||||
"""Returns the spec for a given hook point.
|
||||
|
||||
Raises ValueError if the hook point has no registered spec — this is a
|
||||
programmer error; every HookPoint enum value must have a corresponding spec
|
||||
in _REGISTRY.
|
||||
"""
|
||||
try:
|
||||
return _REGISTRY[hook_point]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"No spec registered for hook point {hook_point!r}. "
|
||||
"Add an entry to onyx.hooks.registry._REGISTRY."
|
||||
)
|
||||
|
||||
|
||||
def get_all_specs() -> list[HookPointSpec]:
|
||||
"""Returns the specs for all registered hook points."""
|
||||
return list(_REGISTRY.values())
|
||||
@@ -395,6 +395,12 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
llm = get_default_llm_with_vision()
|
||||
|
||||
if not llm:
|
||||
if get_image_extraction_and_analysis_enabled():
|
||||
logger.warning(
|
||||
"Image analysis is enabled but no vision-capable LLM is "
|
||||
"available — images will not be summarized. Configure a "
|
||||
"vision model in the admin LLM settings."
|
||||
)
|
||||
# Even without LLM, we still convert to IndexingDocument with base Sections
|
||||
return [
|
||||
IndexingDocument(
|
||||
|
||||
@@ -168,10 +168,23 @@ def get_default_llm_with_vision(
|
||||
if model_supports_image_input(
|
||||
default_model.name, default_model.llm_provider.provider
|
||||
):
|
||||
logger.info(
|
||||
"Using default vision model: %s (provider=%s)",
|
||||
default_model.name,
|
||||
default_model.llm_provider.provider,
|
||||
)
|
||||
return create_vision_llm(
|
||||
LLMProviderView.from_model(default_model.llm_provider),
|
||||
default_model.name,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Default vision model %s (provider=%s) does not support "
|
||||
"image input — falling back to searching all providers",
|
||||
default_model.name,
|
||||
default_model.llm_provider.provider,
|
||||
)
|
||||
|
||||
# Fall back to searching all providers
|
||||
models = fetch_existing_models(
|
||||
db_session=db_session,
|
||||
@@ -179,6 +192,10 @@ def get_default_llm_with_vision(
|
||||
)
|
||||
|
||||
if not models:
|
||||
logger.warning(
|
||||
"No LLM models with VISION or CHAT flow type found — "
|
||||
"image summarization will be disabled"
|
||||
)
|
||||
return None
|
||||
|
||||
for model in models:
|
||||
@@ -200,11 +217,25 @@ def get_default_llm_with_vision(
|
||||
|
||||
for model in sorted_models:
|
||||
if model_supports_image_input(model.name, model.llm_provider.provider):
|
||||
logger.info(
|
||||
"Using fallback vision model: %s (provider=%s)",
|
||||
model.name,
|
||||
model.llm_provider.provider,
|
||||
)
|
||||
return create_vision_llm(
|
||||
provider_map[model.llm_provider_id],
|
||||
model.name,
|
||||
)
|
||||
|
||||
checked_models = [
|
||||
f"{m.name} (provider={m.llm_provider.provider})" for m in sorted_models
|
||||
]
|
||||
logger.warning(
|
||||
"No vision-capable model found among %d candidates: %s — "
|
||||
"image summarization will be disabled",
|
||||
len(sorted_models),
|
||||
", ".join(checked_models),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.error_handling.exceptions import register_onyx_exception_handlers
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.hooks.registry import validate_registry
|
||||
from onyx.server.api_key.api import router as api_key_router
|
||||
from onyx.server.auth_check import check_router_auth
|
||||
from onyx.server.documents.cc_pair import router as cc_pair_router
|
||||
@@ -308,6 +309,7 @@ def validate_no_vector_db_settings() -> None:
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
validate_no_vector_db_settings()
|
||||
validate_cache_backend_settings()
|
||||
validate_registry()
|
||||
|
||||
# Set recursion limit
|
||||
if SYSTEM_RECURSION_LIMIT is not None:
|
||||
|
||||
@@ -479,7 +479,9 @@ def is_zip_file(file: UploadFile) -> bool:
|
||||
|
||||
|
||||
def upload_files(
|
||||
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
|
||||
files: list[UploadFile],
|
||||
file_origin: FileOrigin = FileOrigin.CONNECTOR,
|
||||
unzip: bool = True,
|
||||
) -> FileUploadResponse:
|
||||
|
||||
# Skip directories and known macOS metadata entries
|
||||
@@ -502,31 +504,46 @@ def upload_files(
|
||||
if seen_zip:
|
||||
raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL)
|
||||
seen_zip = True
|
||||
|
||||
# Validate the zip by opening it (catches corrupt/non-zip files)
|
||||
with zipfile.ZipFile(file.file, "r") as zf:
|
||||
zip_metadata_file_id = save_zip_metadata_to_file_store(
|
||||
zf, file_store
|
||||
)
|
||||
for file_info in zf.namelist():
|
||||
if zf.getinfo(file_info).is_dir():
|
||||
continue
|
||||
|
||||
if not should_process_file(file_info):
|
||||
continue
|
||||
|
||||
sub_file_bytes = zf.read(file_info)
|
||||
|
||||
mime_type, __ = mimetypes.guess_type(file_info)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(sub_file_bytes),
|
||||
display_name=os.path.basename(file_info),
|
||||
file_origin=file_origin,
|
||||
file_type=mime_type,
|
||||
if unzip:
|
||||
zip_metadata_file_id = save_zip_metadata_to_file_store(
|
||||
zf, file_store
|
||||
)
|
||||
deduped_file_paths.append(file_id)
|
||||
deduped_file_names.append(os.path.basename(file_info))
|
||||
for file_info in zf.namelist():
|
||||
if zf.getinfo(file_info).is_dir():
|
||||
continue
|
||||
|
||||
if not should_process_file(file_info):
|
||||
continue
|
||||
|
||||
sub_file_bytes = zf.read(file_info)
|
||||
|
||||
mime_type, __ = mimetypes.guess_type(file_info)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(sub_file_bytes),
|
||||
display_name=os.path.basename(file_info),
|
||||
file_origin=file_origin,
|
||||
file_type=mime_type,
|
||||
)
|
||||
deduped_file_paths.append(file_id)
|
||||
deduped_file_names.append(os.path.basename(file_info))
|
||||
continue
|
||||
|
||||
# Store the zip as-is (unzip=False)
|
||||
file.file.seek(0)
|
||||
file_id = file_store.save_file(
|
||||
content=file.file,
|
||||
display_name=file.filename,
|
||||
file_origin=file_origin,
|
||||
file_type=file.content_type or "application/zip",
|
||||
)
|
||||
deduped_file_paths.append(file_id)
|
||||
deduped_file_names.append(file.filename)
|
||||
continue
|
||||
|
||||
# Since we can't render docx files in the UI,
|
||||
@@ -613,9 +630,10 @@ def _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_files_api(
|
||||
files: list[UploadFile],
|
||||
unzip: bool = True,
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
) -> FileUploadResponse:
|
||||
return upload_files(files, FileOrigin.OTHER)
|
||||
return upload_files(files, FileOrigin.OTHER, unzip=unzip)
|
||||
|
||||
|
||||
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@@ -74,7 +74,7 @@ def make_structured_onyx_request_id(prefix: str, request_url: str) -> str:
|
||||
|
||||
def _make_onyx_request_id(prefix: str, hash_input: str) -> str:
|
||||
"""helper function to return an id given a string input"""
|
||||
hash_obj = hashlib.md5(hash_input.encode("utf-8"))
|
||||
hash_obj = hashlib.md5(hash_input.encode("utf-8"), usedforsecurity=False)
|
||||
hash_bytes = hash_obj.digest()[:6] # Truncate to 6 bytes
|
||||
|
||||
# 6 bytes becomes 8 bytes. we shouldn't need to strip but just in case
|
||||
|
||||
@@ -752,7 +752,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.8.0
|
||||
pypdf==6.9.1
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
|
||||
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
External dependency unit tests for user file delete queue protections.
|
||||
|
||||
Verifies that the three mechanisms added to check_for_user_file_delete work
|
||||
correctly:
|
||||
|
||||
1. Queue depth backpressure – when the broker queue exceeds
|
||||
USER_FILE_DELETE_MAX_QUEUE_DEPTH, no new tasks are enqueued.
|
||||
|
||||
2. Per-file Redis guard key – if the guard key for a file already exists in
|
||||
Redis, that file is skipped even though it is still in DELETING status.
|
||||
|
||||
3. Task expiry – every send_task call carries expires=
|
||||
CELERY_USER_FILE_DELETE_TASK_EXPIRES so that stale queued tasks are
|
||||
discarded by workers automatically.
|
||||
|
||||
Also verifies that delete_user_file_impl clears the guard key the moment
|
||||
it is picked up by a worker.
|
||||
|
||||
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
|
||||
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
|
||||
on the task class so no real broker is needed.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_delete_lock_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_delete_queued_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_for_user_file_delete,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file_delete,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_USER_FILE_DELETE_TASK_EXPIRES
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import USER_FILE_DELETE_MAX_QUEUE_DEPTH
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_QUEUE_LEN = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
|
||||
)
|
||||
|
||||
|
||||
def _create_deleting_user_file(db_session: Session, user_id: object) -> UserFile:
|
||||
"""Insert a UserFile in DELETING status and return it."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.DELETING,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on *task*'s class so that ``self.app``
|
||||
inside the task function returns *mock_app*.
|
||||
|
||||
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteQueueDepthBackpressure:
|
||||
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
|
||||
|
||||
def test_no_tasks_enqueued_when_queue_over_limit(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When the queue depth exceeds the limit the beat cycle is skipped."""
|
||||
user = create_test_user(db_session, "del_bp_user")
|
||||
_create_deleting_user_file(db_session, user.id)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with (
|
||||
_patch_task_app(check_for_user_file_delete, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=USER_FILE_DELETE_MAX_QUEUE_DEPTH + 1),
|
||||
):
|
||||
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
mock_app.send_task.assert_not_called()
|
||||
|
||||
|
||||
class TestDeletePerFileGuardKey:
|
||||
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
|
||||
|
||||
def test_guarded_file_not_re_enqueued(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file whose guard key is already set in Redis is skipped."""
|
||||
user = create_test_user(db_session, "del_guard_user")
|
||||
uf = _create_deleting_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_delete_queued_key(uf.id)
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_DELETE_TASK_EXPIRES, 1)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_for_user_file_delete, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# send_task must not have been called with this specific file's ID
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
kwargs = call.kwargs.get("kwargs", {})
|
||||
assert kwargs.get("user_file_id") != str(
|
||||
uf.id
|
||||
), f"File {uf.id} should have been skipped because its guard key exists"
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
def test_guard_key_exists_in_redis_after_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a file is enqueued its guard key is present in Redis with a TTL."""
|
||||
user = create_test_user(db_session, "del_guard_set_user")
|
||||
uf = _create_deleting_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_delete_queued_key(uf.id)
|
||||
redis_client.delete(guard_key) # clean slate
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_for_user_file_delete, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
assert redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be set in Redis after enqueue"
|
||||
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
|
||||
assert (
|
||||
0 < ttl <= CELERY_USER_FILE_DELETE_TASK_EXPIRES
|
||||
), f"Guard key TTL {ttl}s is outside the expected range (0, {CELERY_USER_FILE_DELETE_TASK_EXPIRES}]"
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestDeleteTaskExpiry:
|
||||
"""Protection 3: every send_task call includes an expires value."""
|
||||
|
||||
def test_send_task_called_with_expires(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""send_task is called with the correct queue, task name, and expires."""
|
||||
user = create_test_user(db_session, "del_expires_user")
|
||||
uf = _create_deleting_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_delete_queued_key(uf.id)
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_for_user_file_delete, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_for_user_file_delete.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# At least one task should have been submitted (for our file)
|
||||
assert (
|
||||
mock_app.send_task.call_count >= 1
|
||||
), "Expected at least one task to be submitted"
|
||||
|
||||
# Every submitted task must carry expires
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
assert call.args[0] == OnyxCeleryTask.DELETE_SINGLE_USER_FILE
|
||||
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_DELETE
|
||||
assert (
|
||||
call.kwargs.get("expires") == CELERY_USER_FILE_DELETE_TASK_EXPIRES
|
||||
), "Task must be submitted with the correct expires value to prevent stale task accumulation"
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestDeleteWorkerClearsGuardKey:
|
||||
"""process_single_user_file_delete removes the guard key when it picks up a task."""
|
||||
|
||||
def test_guard_key_deleted_on_pickup(
|
||||
self,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""The guard key is deleted before the worker does any real work.
|
||||
|
||||
We simulate an already-locked file so delete_user_file_impl returns
|
||||
early – but crucially, after the guard key deletion.
|
||||
"""
|
||||
user_file_id = str(uuid4())
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_delete_queued_key(user_file_id)
|
||||
|
||||
# Simulate the guard key set when the beat enqueued the task
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_DELETE_TASK_EXPIRES, 1)
|
||||
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
|
||||
|
||||
# Hold the per-file delete lock so the worker exits early without
|
||||
# touching the database or file store.
|
||||
lock_key = _user_file_delete_lock_key(user_file_id)
|
||||
delete_lock = redis_client.lock(lock_key, timeout=10)
|
||||
acquired = delete_lock.acquire(blocking=False)
|
||||
assert acquired, "Should be able to acquire the delete lock for this test"
|
||||
|
||||
try:
|
||||
process_single_user_file_delete.run(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
finally:
|
||||
if delete_lock.owned():
|
||||
delete_lock.release()
|
||||
|
||||
assert not redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be deleted when the worker picks up the task"
|
||||
@@ -297,6 +297,10 @@ def index_batch_params(
|
||||
class TestDocumentIndexOld:
|
||||
"""Tests the old DocumentIndex interface."""
|
||||
|
||||
# TODO(ENG-3864)(andrei): Re-enable this test.
|
||||
@pytest.mark.xfail(
|
||||
reason="Flaky test: Retrieved chunks vary non-deterministically before and after changing user projects and personas. Likely a timing issue with the index being updated."
|
||||
)
|
||||
def test_update_single_can_clear_user_projects_and_personas(
|
||||
self,
|
||||
document_indices: list[DocumentIndex],
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
)
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
|
||||
from onyx.document_index.opensearch.schema import DocumentSchema
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.search import DocumentQuery
|
||||
@@ -96,6 +97,23 @@ def _patch_hybrid_search_normalization_pipeline(
|
||||
)
|
||||
|
||||
|
||||
def _patch_opensearch_match_highlights_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch, disabled: bool
|
||||
) -> None:
|
||||
"""
|
||||
Patches OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED wherever necessary for this
|
||||
test file.
|
||||
"""
|
||||
monkeypatch.setattr(
|
||||
"onyx.configs.app_configs.OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED",
|
||||
disabled,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"onyx.document_index.opensearch.search.OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED",
|
||||
disabled,
|
||||
)
|
||||
|
||||
|
||||
def _create_test_document_chunk(
|
||||
document_id: str,
|
||||
content: str,
|
||||
@@ -226,7 +244,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
|
||||
# Under test.
|
||||
# Should not raise.
|
||||
@@ -242,7 +260,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
@@ -271,7 +289,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
@@ -285,7 +303,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
|
||||
# Under test and postcondition.
|
||||
# Should return False before creation.
|
||||
@@ -305,7 +323,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
@@ -340,7 +358,7 @@ class TestOpenSearchClient:
|
||||
},
|
||||
},
|
||||
}
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=initial_mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
@@ -383,7 +401,7 @@ class TestOpenSearchClient:
|
||||
"test_field": {"type": "keyword"},
|
||||
},
|
||||
}
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=initial_mappings, settings=settings)
|
||||
|
||||
# Under test and postcondition.
|
||||
@@ -418,7 +436,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
# Create once - should succeed.
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
@@ -461,7 +479,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
doc = _create_test_document_chunk(
|
||||
@@ -489,7 +507,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
docs = [
|
||||
@@ -520,7 +538,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
doc = _create_test_document_chunk(
|
||||
@@ -548,7 +566,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
original_doc = _create_test_document_chunk(
|
||||
@@ -583,7 +601,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=False
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test and postcondition.
|
||||
@@ -602,7 +620,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
doc = _create_test_document_chunk(
|
||||
@@ -638,7 +656,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
@@ -659,7 +677,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index multiple documents.
|
||||
@@ -735,7 +753,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Create a document to update.
|
||||
@@ -784,7 +802,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test and postcondition.
|
||||
@@ -804,11 +822,12 @@ class TestOpenSearchClient:
|
||||
"""Tests all hybrid search configurations and pipelines."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
# Index documents.
|
||||
docs = {
|
||||
@@ -881,8 +900,12 @@ class TestOpenSearchClient:
|
||||
)
|
||||
# Make sure the chunk contents are preserved.
|
||||
for i, chunk in enumerate(results):
|
||||
assert (
|
||||
chunk.document_chunk == docs[chunk.document_chunk.document_id]
|
||||
expected = docs[chunk.document_chunk.document_id]
|
||||
assert chunk.document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(expected, k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
@@ -906,7 +929,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
# Note no documents were indexed.
|
||||
|
||||
@@ -942,12 +965,13 @@ class TestOpenSearchClient:
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
|
||||
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
|
||||
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_x.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
@@ -1038,7 +1062,12 @@ class TestOpenSearchClient:
|
||||
# ordered; we're just assuming which doc will be the first result here.
|
||||
assert results[0].document_chunk.document_id == "public-doc"
|
||||
# Make sure the chunk contents are preserved.
|
||||
assert results[0].document_chunk == docs["public-doc"]
|
||||
assert results[0].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["public-doc"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
# Make sure score reporting seems reasonable (it should not be None
|
||||
# or 0).
|
||||
assert results[0].score
|
||||
@@ -1046,7 +1075,12 @@ class TestOpenSearchClient:
|
||||
assert results[0].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
# Same for the second result.
|
||||
assert results[1].document_chunk.document_id == "private-doc-user-a"
|
||||
assert results[1].document_chunk == docs["private-doc-user-a"]
|
||||
assert results[1].document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(docs["private-doc-user-a"], k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
assert results[1].score
|
||||
assert results[1].match_highlights.get(CONTENT_FIELD_NAME, [])
|
||||
|
||||
@@ -1062,11 +1096,12 @@ class TestOpenSearchClient:
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
_patch_opensearch_match_highlights_disabled(monkeypatch, False)
|
||||
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_x.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with varying relevance to the query.
|
||||
@@ -1193,7 +1228,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_x.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Although very unlikely in practice, let's use the same doc ID just to
|
||||
@@ -1286,7 +1321,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Don't index any documents.
|
||||
@@ -1313,7 +1348,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index chunks for two different documents.
|
||||
@@ -1381,7 +1416,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
@@ -1458,7 +1493,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index docs with various ages.
|
||||
@@ -1550,7 +1585,7 @@ class TestOpenSearchClient:
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index chunks for two different documents, one hidden one not.
|
||||
@@ -1599,4 +1634,9 @@ class TestOpenSearchClient:
|
||||
for result in results:
|
||||
# Note each result must be from doc 1, which is not hidden.
|
||||
expected_result = doc1_chunks[result.document_chunk.chunk_index]
|
||||
assert result.document_chunk == expected_result
|
||||
assert result.document_chunk == DocumentChunkWithoutVectors(
|
||||
**{
|
||||
k: getattr(expected_result, k)
|
||||
for k in DocumentChunkWithoutVectors.model_fields
|
||||
}
|
||||
)
|
||||
|
||||
@@ -31,7 +31,6 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
|
||||
)
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
from onyx.configs.constants import SOURCE_TYPE
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import OpenSearchDocumentMigrationRecord
|
||||
@@ -44,6 +43,7 @@ from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.search import DocumentQuery
|
||||
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
@@ -70,6 +70,7 @@ from onyx.document_index.vespa_constants import SOURCE_LINKS
|
||||
from onyx.document_index.vespa_constants import TITLE
|
||||
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
from tests.external_dependency_unit.full_setup import ensure_full_deployment_setup
|
||||
|
||||
@@ -78,24 +79,22 @@ CHUNK_COUNT = 5
|
||||
|
||||
|
||||
def _get_document_chunks_from_opensearch(
|
||||
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
document_id: str,
|
||||
tenant_state: TenantState,
|
||||
) -> list[DocumentChunk]:
|
||||
opensearch_client.refresh_index()
|
||||
filters = IndexFilters(access_control_list=None, tenant_id=current_tenant_id)
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id=document_id,
|
||||
tenant_state=TenantState(tenant_id=current_tenant_id, multitenant=False),
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
)
|
||||
search_hits = opensearch_client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
return [search_hit.document_chunk for search_hit in search_hits]
|
||||
results: list[DocumentChunk] = []
|
||||
for i in range(CHUNK_COUNT):
|
||||
document_chunk_id: str = get_opensearch_doc_chunk_id(
|
||||
tenant_state=tenant_state,
|
||||
document_id=document_id,
|
||||
chunk_index=i,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
)
|
||||
result = opensearch_client.get_document(document_chunk_id)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
|
||||
def _delete_document_chunks_from_opensearch(
|
||||
@@ -452,10 +451,13 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
for chunks in document_chunks.values():
|
||||
all_chunks.extend(chunks)
|
||||
vespa_document_index.index_raw_chunks(all_chunks)
|
||||
tenant_state = TenantState(
|
||||
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
|
||||
)
|
||||
|
||||
# Under test.
|
||||
result = migrate_chunks_from_vespa_to_opensearch_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
tenant_id=tenant_state.tenant_id
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
@@ -477,7 +479,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
# Verify chunks were indexed in OpenSearch.
|
||||
for document in test_documents:
|
||||
opensearch_chunks = _get_document_chunks_from_opensearch(
|
||||
opensearch_client, document.id, get_current_tenant_id()
|
||||
opensearch_client, document.id, tenant_state
|
||||
)
|
||||
assert len(opensearch_chunks) == CHUNK_COUNT
|
||||
opensearch_chunks.sort(key=lambda x: x.chunk_index)
|
||||
@@ -522,6 +524,9 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
for chunks in document_chunks.values():
|
||||
all_chunks.extend(chunks)
|
||||
vespa_document_index.index_raw_chunks(all_chunks)
|
||||
tenant_state = TenantState(
|
||||
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
|
||||
)
|
||||
|
||||
# Run the initial batch. To simulate partial progress we will mock the
|
||||
# redis lock to return True for the first invocation of .owned() and
|
||||
@@ -536,7 +541,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
return_value=mock_redis_client,
|
||||
):
|
||||
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
tenant_id=tenant_state.tenant_id
|
||||
)
|
||||
|
||||
assert result_1 is True
|
||||
@@ -559,7 +564,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
# Under test.
|
||||
# Run the remainder of the migration.
|
||||
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
tenant_id=tenant_state.tenant_id
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
@@ -583,7 +588,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
# Verify chunks were indexed in OpenSearch.
|
||||
for document in test_documents:
|
||||
opensearch_chunks = _get_document_chunks_from_opensearch(
|
||||
opensearch_client, document.id, get_current_tenant_id()
|
||||
opensearch_client, document.id, tenant_state
|
||||
)
|
||||
assert len(opensearch_chunks) == CHUNK_COUNT
|
||||
opensearch_chunks.sort(key=lambda x: x.chunk_index)
|
||||
@@ -630,6 +635,9 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
for chunks in document_chunks.values():
|
||||
all_chunks.extend(chunks)
|
||||
vespa_document_index.index_raw_chunks(all_chunks)
|
||||
tenant_state = TenantState(
|
||||
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
|
||||
)
|
||||
|
||||
# Run the initial batch. To simulate partial progress we will mock the
|
||||
# redis lock to return True for the first invocation of .owned() and
|
||||
@@ -646,7 +654,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
return_value=mock_redis_client,
|
||||
):
|
||||
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
tenant_id=tenant_state.tenant_id
|
||||
)
|
||||
|
||||
assert result_1 is True
|
||||
@@ -691,7 +699,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
),
|
||||
):
|
||||
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
tenant_id=tenant_state.tenant_id
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
@@ -728,7 +736,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
),
|
||||
):
|
||||
result_3 = migrate_chunks_from_vespa_to_opensearch_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
tenant_id=tenant_state.tenant_id
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
@@ -752,7 +760,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
# Verify chunks were indexed in OpenSearch.
|
||||
for document in test_documents:
|
||||
opensearch_chunks = _get_document_chunks_from_opensearch(
|
||||
opensearch_client, document.id, get_current_tenant_id()
|
||||
opensearch_client, document.id, tenant_state
|
||||
)
|
||||
assert len(opensearch_chunks) == CHUNK_COUNT
|
||||
opensearch_chunks.sort(key=lambda x: x.chunk_index)
|
||||
@@ -840,24 +848,25 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
chunk["content"] = (
|
||||
f"Different content {chunk[CHUNK_ID]} for {test_documents[0].id}"
|
||||
)
|
||||
tenant_state = TenantState(
|
||||
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
|
||||
)
|
||||
chunks_for_document_in_opensearch, _ = (
|
||||
transform_vespa_chunks_to_opensearch_chunks(
|
||||
document_in_opensearch,
|
||||
TenantState(tenant_id=get_current_tenant_id(), multitenant=False),
|
||||
tenant_state,
|
||||
{},
|
||||
)
|
||||
)
|
||||
opensearch_client.bulk_index_documents(
|
||||
documents=chunks_for_document_in_opensearch,
|
||||
tenant_state=TenantState(
|
||||
tenant_id=get_current_tenant_id(), multitenant=False
|
||||
),
|
||||
tenant_state=tenant_state,
|
||||
update_if_exists=True,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
result = migrate_chunks_from_vespa_to_opensearch_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
tenant_id=tenant_state.tenant_id
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
@@ -878,7 +887,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
# Verify chunks were indexed in OpenSearch.
|
||||
for document in test_documents:
|
||||
opensearch_chunks = _get_document_chunks_from_opensearch(
|
||||
opensearch_client, document.id, get_current_tenant_id()
|
||||
opensearch_client, document.id, tenant_state
|
||||
)
|
||||
assert len(opensearch_chunks) == CHUNK_COUNT
|
||||
opensearch_chunks.sort(key=lambda x: x.chunk_index)
|
||||
@@ -922,11 +931,14 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
for chunks in document_chunks.values():
|
||||
all_chunks.extend(chunks)
|
||||
vespa_document_index.index_raw_chunks(all_chunks)
|
||||
tenant_state = TenantState(
|
||||
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
|
||||
)
|
||||
|
||||
# Under test.
|
||||
# First run.
|
||||
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
tenant_id=tenant_state.tenant_id
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
@@ -947,7 +959,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
# Verify chunks were indexed in OpenSearch.
|
||||
for document in test_documents:
|
||||
opensearch_chunks = _get_document_chunks_from_opensearch(
|
||||
opensearch_client, document.id, get_current_tenant_id()
|
||||
opensearch_client, document.id, tenant_state
|
||||
)
|
||||
assert len(opensearch_chunks) == CHUNK_COUNT
|
||||
opensearch_chunks.sort(key=lambda x: x.chunk_index)
|
||||
@@ -960,7 +972,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
# Under test.
|
||||
# Second run.
|
||||
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
tenant_id=tenant_state.tenant_id
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
@@ -982,7 +994,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
# Verify chunks were indexed in OpenSearch.
|
||||
for document in test_documents:
|
||||
opensearch_chunks = _get_document_chunks_from_opensearch(
|
||||
opensearch_client, document.id, get_current_tenant_id()
|
||||
opensearch_client, document.id, tenant_state
|
||||
)
|
||||
assert len(opensearch_chunks) == CHUNK_COUNT
|
||||
opensearch_chunks.sort(key=lambda x: x.chunk_index)
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
Unit test verifying that the upload API path sends tasks with expires=.
|
||||
|
||||
The upload_files_to_user_files_with_indexing function must include expires=
|
||||
on every send_task call to prevent phantom task accumulation if the worker
|
||||
is down or slow.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import upload_files_to_user_files_with_indexing
|
||||
|
||||
|
||||
def _make_mock_user_file() -> MagicMock:
|
||||
uf = MagicMock(spec=UserFile)
|
||||
uf.id = str(uuid4())
|
||||
return uf
|
||||
|
||||
|
||||
@patch("onyx.db.projects.get_current_tenant_id", return_value="test_tenant")
|
||||
@patch("onyx.db.projects.create_user_files")
|
||||
@patch(
|
||||
"onyx.background.celery.versioned_apps.client.app",
|
||||
new_callable=MagicMock,
|
||||
)
|
||||
def test_send_task_includes_expires(
|
||||
mock_client_app: MagicMock,
|
||||
mock_create: MagicMock,
|
||||
mock_tenant: MagicMock, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Every send_task call from the upload path must include expires=."""
|
||||
user_files = [_make_mock_user_file(), _make_mock_user_file()]
|
||||
mock_create.return_value = MagicMock(
|
||||
user_files=user_files,
|
||||
rejected_files=[],
|
||||
id_to_temp_id={},
|
||||
)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_db_session = MagicMock()
|
||||
|
||||
upload_files_to_user_files_with_indexing(
|
||||
files=[],
|
||||
project_id=None,
|
||||
user=mock_user,
|
||||
temp_id_map=None,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert mock_client_app.send_task.call_count == len(user_files)
|
||||
|
||||
for call in mock_client_app.send_task.call_args_list:
|
||||
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
|
||||
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
|
||||
assert (
|
||||
call.kwargs.get("expires") == CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
), "send_task must include expires= to prevent phantom task accumulation"
|
||||
@@ -1,249 +0,0 @@
|
||||
"""Unit tests for OpenSearchDocumentIndex.index().
|
||||
|
||||
These tests mock the OpenSearch client and verify the buffered
|
||||
flush-by-document logic, DocumentInsertionRecord construction, and
|
||||
delete-before-insert semantics.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentSource
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
doc_id: str,
|
||||
chunk_id: int = 0,
|
||||
content: str = "test content",
|
||||
chunk_count: int | None = None,
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test_doc",
|
||||
sections=[TextSection(text=content, link=None)],
|
||||
source=DocumentSource.NOT_APPLICABLE,
|
||||
metadata={},
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
index_chunk = IndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb=content[:50],
|
||||
content=content,
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=doc,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.1] * 10,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=[0.1] * 10,
|
||||
)
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=index_chunk,
|
||||
access=access,
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
|
||||
def _make_indexing_metadata(
|
||||
doc_ids: list[str],
|
||||
old_counts: list[int],
|
||||
new_counts: list[int],
|
||||
) -> IndexingMetadata:
|
||||
return IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
doc_id: IndexingMetadata.ChunkCounts(
|
||||
old_chunk_cnt=old,
|
||||
new_chunk_cnt=new,
|
||||
)
|
||||
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _make_os_index(mock_client: MagicMock) -> OpenSearchDocumentIndex:
|
||||
"""Create an OpenSearchDocumentIndex with a mocked client."""
|
||||
with patch.object(
|
||||
OpenSearchDocumentIndex,
|
||||
"__init__",
|
||||
lambda _self, *_a, **_kw: None,
|
||||
):
|
||||
idx = OpenSearchDocumentIndex.__new__(OpenSearchDocumentIndex)
|
||||
|
||||
idx._index_name = "test_index"
|
||||
idx._tenant_state = TenantState(tenant_id="test_tenant", multitenant=False)
|
||||
idx._client = mock_client
|
||||
return idx
|
||||
|
||||
|
||||
def test_index_single_new_doc() -> None:
|
||||
"""Indexing a single new document returns one record with already_existed=False."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.bulk_index_documents.return_value = None
|
||||
|
||||
idx = _make_os_index(mock_client)
|
||||
|
||||
# Patch delete to return 0 (no existing chunks)
|
||||
with patch.object(idx, "delete", return_value=0) as mock_delete:
|
||||
chunk = _make_chunk("doc1")
|
||||
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[1])
|
||||
|
||||
results = idx.index(chunks=[chunk], indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].document_id == "doc1"
|
||||
assert results[0].already_existed is False
|
||||
mock_delete.assert_called_once()
|
||||
mock_client.bulk_index_documents.assert_called_once()
|
||||
|
||||
|
||||
def test_index_existing_doc_already_existed_true() -> None:
|
||||
"""Re-indexing a doc with previous chunks returns already_existed=True."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.bulk_index_documents.return_value = None
|
||||
|
||||
idx = _make_os_index(mock_client)
|
||||
|
||||
with patch.object(idx, "delete", return_value=5):
|
||||
chunk = _make_chunk("doc1")
|
||||
metadata = _make_indexing_metadata(["doc1"], old_counts=[5], new_counts=[1])
|
||||
|
||||
results = idx.index(chunks=[chunk], indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].already_existed is True
|
||||
|
||||
|
||||
def test_index_multiple_docs_flushed_separately() -> None:
|
||||
"""Chunks from different documents are flushed in separate bulk calls."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.bulk_index_documents.return_value = None
|
||||
|
||||
idx = _make_os_index(mock_client)
|
||||
|
||||
with patch.object(idx, "delete", return_value=0):
|
||||
chunks = [
|
||||
_make_chunk("doc1", chunk_id=0),
|
||||
_make_chunk("doc1", chunk_id=1),
|
||||
_make_chunk("doc2", chunk_id=0),
|
||||
]
|
||||
metadata = _make_indexing_metadata(
|
||||
["doc1", "doc2"], old_counts=[0, 0], new_counts=[2, 1]
|
||||
)
|
||||
|
||||
results = idx.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
result_map = {r.document_id: r.already_existed for r in results}
|
||||
assert len(result_map) == 2
|
||||
assert result_map["doc1"] is False
|
||||
assert result_map["doc2"] is False
|
||||
# Two separate flushes: one for doc1 (2 chunks), one for doc2 (1 chunk)
|
||||
assert mock_client.bulk_index_documents.call_count == 2
|
||||
|
||||
|
||||
def test_index_deletes_before_inserting() -> None:
|
||||
"""For each document, delete is called before bulk_index_documents."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.bulk_index_documents.return_value = None
|
||||
|
||||
call_order: list[str] = []
|
||||
|
||||
idx = _make_os_index(mock_client)
|
||||
|
||||
def track_delete(*_args: object, **_kwargs: object) -> int:
|
||||
call_order.append("delete")
|
||||
return 3
|
||||
|
||||
def track_bulk(*_args: object, **_kwargs: object) -> None:
|
||||
call_order.append("bulk_index")
|
||||
|
||||
mock_client.bulk_index_documents.side_effect = track_bulk
|
||||
|
||||
with patch.object(idx, "delete", side_effect=track_delete):
|
||||
chunk = _make_chunk("doc1")
|
||||
metadata = _make_indexing_metadata(["doc1"], old_counts=[3], new_counts=[1])
|
||||
|
||||
idx.index(chunks=[chunk], indexing_metadata=metadata)
|
||||
|
||||
assert call_order == ["delete", "bulk_index"]
|
||||
|
||||
|
||||
def test_index_delete_called_once_per_doc() -> None:
|
||||
"""Delete is called only once per document, even with multiple chunks."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.bulk_index_documents.return_value = None
|
||||
|
||||
idx = _make_os_index(mock_client)
|
||||
|
||||
with patch.object(idx, "delete", return_value=0) as mock_delete:
|
||||
# 3 chunks, all same doc — should only delete once
|
||||
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(3)]
|
||||
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[3])
|
||||
|
||||
idx.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
mock_delete.assert_called_once()
|
||||
|
||||
|
||||
def test_index_flushes_on_doc_boundary() -> None:
|
||||
"""When doc ID changes in the stream, the previous doc's chunks are flushed."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.bulk_index_documents.return_value = None
|
||||
|
||||
idx = _make_os_index(mock_client)
|
||||
|
||||
bulk_call_chunk_counts: list[int] = []
|
||||
|
||||
def track_bulk(documents: list[object], **_kwargs: object) -> None:
|
||||
bulk_call_chunk_counts.append(len(documents))
|
||||
|
||||
mock_client.bulk_index_documents.side_effect = track_bulk
|
||||
|
||||
with patch.object(idx, "delete", return_value=0):
|
||||
chunks = [
|
||||
_make_chunk("doc1", chunk_id=0),
|
||||
_make_chunk("doc1", chunk_id=1),
|
||||
_make_chunk("doc1", chunk_id=2),
|
||||
_make_chunk("doc2", chunk_id=0),
|
||||
_make_chunk("doc2", chunk_id=1),
|
||||
]
|
||||
metadata = _make_indexing_metadata(
|
||||
["doc1", "doc2"], old_counts=[0, 0], new_counts=[3, 2]
|
||||
)
|
||||
|
||||
idx.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
# First flush: 3 chunks for doc1, second flush: 2 chunks for doc2
|
||||
assert bulk_call_chunk_counts == [3, 2]
|
||||
@@ -1,373 +0,0 @@
|
||||
"""Unit tests for VespaDocumentIndex.index().
|
||||
|
||||
These tests mock all external I/O (HTTP calls, thread pools) and verify
|
||||
the streaming logic, ID cleaning/mapping, and DocumentInsertionRecord
|
||||
construction.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentSource
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
doc_id: str,
|
||||
chunk_id: int = 0,
|
||||
content: str = "test content",
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test_doc",
|
||||
sections=[TextSection(text=content, link=None)],
|
||||
source=DocumentSource.NOT_APPLICABLE,
|
||||
metadata={},
|
||||
)
|
||||
index_chunk = IndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb=content[:50],
|
||||
content=content,
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=doc,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.1] * 10,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=index_chunk,
|
||||
access=access,
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
|
||||
def _make_indexing_metadata(
|
||||
doc_ids: list[str],
|
||||
old_counts: list[int],
|
||||
new_counts: list[int],
|
||||
) -> IndexingMetadata:
|
||||
return IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
doc_id: IndexingMetadata.ChunkCounts(
|
||||
old_chunk_cnt=old,
|
||||
new_chunk_cnt=new,
|
||||
)
|
||||
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _stub_enrich(
|
||||
doc_id: str,
|
||||
old_chunk_cnt: int,
|
||||
) -> EnrichedDocumentIndexingInfo:
|
||||
"""Build an EnrichedDocumentIndexingInfo that says 'no chunks to delete'
|
||||
when old_chunk_cnt == 0, or 'has existing chunks' otherwise."""
|
||||
return EnrichedDocumentIndexingInfo(
|
||||
doc_id=doc_id,
|
||||
chunk_start_index=0,
|
||||
old_version=False,
|
||||
chunk_end_index=old_chunk_cnt,
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
|
||||
return_value=[],
|
||||
)
|
||||
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
|
||||
def test_index_single_new_doc(
|
||||
mock_enrich: MagicMock,
|
||||
mock_get_chunk_ids: MagicMock, # noqa: ARG001
|
||||
mock_delete: MagicMock, # noqa: ARG001
|
||||
mock_batch_index: MagicMock,
|
||||
) -> None:
|
||||
"""Indexing a single new document returns one record with already_existed=False."""
|
||||
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
|
||||
|
||||
index = VespaDocumentIndex(
|
||||
index_name="test_index",
|
||||
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=MagicMock(),
|
||||
)
|
||||
|
||||
chunk = _make_chunk("doc1")
|
||||
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[1])
|
||||
|
||||
results = index.index(chunks=[chunk], indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].document_id == "doc1"
|
||||
assert results[0].already_existed is False
|
||||
|
||||
# batch_index_vespa_chunks should be called once with a single cleaned chunk
|
||||
mock_batch_index.assert_called_once()
|
||||
call_kwargs = mock_batch_index.call_args
|
||||
indexed_chunks = call_kwargs.kwargs["chunks"]
|
||||
assert len(indexed_chunks) == 1
|
||||
assert indexed_chunks[0].source_document.id == "doc1"
|
||||
assert call_kwargs.kwargs["index_name"] == "test_index"
|
||||
assert call_kwargs.kwargs["multitenant"] is False
|
||||
|
||||
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
|
||||
return_value=[],
|
||||
)
|
||||
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
|
||||
def test_index_existing_doc_already_existed_true(
|
||||
mock_enrich: MagicMock,
|
||||
mock_get_chunk_ids: MagicMock,
|
||||
mock_delete: MagicMock,
|
||||
mock_batch_index: MagicMock,
|
||||
) -> None:
|
||||
"""Re-indexing a doc with previous chunks deletes old chunks, indexes
|
||||
new ones, and returns already_existed=True."""
|
||||
fake_chunk_ids = [uuid4(), uuid4()]
|
||||
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=5)
|
||||
mock_get_chunk_ids.return_value = fake_chunk_ids
|
||||
|
||||
index = VespaDocumentIndex(
|
||||
index_name="test_index",
|
||||
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=MagicMock(),
|
||||
)
|
||||
|
||||
chunk = _make_chunk("doc1")
|
||||
metadata = _make_indexing_metadata(["doc1"], old_counts=[5], new_counts=[1])
|
||||
|
||||
results = index.index(chunks=[chunk], indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].already_existed is True
|
||||
|
||||
# Old chunks should be deleted
|
||||
mock_delete.assert_called_once()
|
||||
delete_kwargs = mock_delete.call_args.kwargs
|
||||
assert delete_kwargs["doc_chunk_ids"] == fake_chunk_ids
|
||||
assert delete_kwargs["index_name"] == "test_index"
|
||||
|
||||
# New chunk should be indexed
|
||||
mock_batch_index.assert_called_once()
|
||||
indexed_chunks = mock_batch_index.call_args.kwargs["chunks"]
|
||||
assert len(indexed_chunks) == 1
|
||||
assert indexed_chunks[0].source_document.id == "doc1"
|
||||
|
||||
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
|
||||
return_value=[],
|
||||
)
|
||||
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
|
||||
def test_index_multiple_docs(
|
||||
mock_enrich: MagicMock,
|
||||
mock_get_chunk_ids: MagicMock, # noqa: ARG001
|
||||
mock_delete: MagicMock, # noqa: ARG001
|
||||
mock_batch_index: MagicMock,
|
||||
) -> None:
|
||||
"""Indexing multiple documents returns one record per unique document."""
|
||||
mock_enrich.side_effect = [
|
||||
_stub_enrich("doc1", old_chunk_cnt=0),
|
||||
_stub_enrich("doc2", old_chunk_cnt=3),
|
||||
]
|
||||
|
||||
index = VespaDocumentIndex(
|
||||
index_name="test_index",
|
||||
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=MagicMock(),
|
||||
)
|
||||
|
||||
chunks = [
|
||||
_make_chunk("doc1", chunk_id=0),
|
||||
_make_chunk("doc1", chunk_id=1),
|
||||
_make_chunk("doc2", chunk_id=0),
|
||||
]
|
||||
metadata = _make_indexing_metadata(
|
||||
["doc1", "doc2"], old_counts=[0, 3], new_counts=[2, 1]
|
||||
)
|
||||
|
||||
results = index.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
result_map = {r.document_id: r.already_existed for r in results}
|
||||
assert len(result_map) == 2
|
||||
assert result_map["doc1"] is False
|
||||
assert result_map["doc2"] is True
|
||||
|
||||
# All 3 chunks fit in one batch (BATCH_SIZE=128), so one call
|
||||
mock_batch_index.assert_called_once()
|
||||
indexed_chunks = mock_batch_index.call_args.kwargs["chunks"]
|
||||
assert len(indexed_chunks) == 3
|
||||
indexed_doc_ids = [c.source_document.id for c in indexed_chunks]
|
||||
assert indexed_doc_ids == ["doc1", "doc1", "doc2"]
|
||||
|
||||
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
|
||||
return_value=[],
|
||||
)
|
||||
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
|
||||
def test_index_cleans_doc_ids(
|
||||
mock_enrich: MagicMock,
|
||||
mock_get_chunk_ids: MagicMock, # noqa: ARG001
|
||||
mock_delete: MagicMock, # noqa: ARG001
|
||||
mock_batch_index: MagicMock,
|
||||
) -> None:
|
||||
"""Documents with invalid Vespa characters get cleaned IDs, but
|
||||
the returned DocumentInsertionRecord uses the original ID."""
|
||||
doc_id_with_quote = "doc'1"
|
||||
mock_enrich.return_value = _stub_enrich(doc_id_with_quote, old_chunk_cnt=0)
|
||||
|
||||
index = VespaDocumentIndex(
|
||||
index_name="test_index",
|
||||
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=MagicMock(),
|
||||
)
|
||||
|
||||
chunk = _make_chunk(doc_id_with_quote)
|
||||
metadata = _make_indexing_metadata(
|
||||
[doc_id_with_quote], old_counts=[0], new_counts=[1]
|
||||
)
|
||||
|
||||
results = index.index(chunks=[chunk], indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
# The returned ID should be the original (unclean) ID
|
||||
assert results[0].document_id == doc_id_with_quote
|
||||
|
||||
# The chunk passed to batch_index_vespa_chunks should have the cleaned ID
|
||||
indexed_chunks = mock_batch_index.call_args.kwargs["chunks"]
|
||||
assert len(indexed_chunks) == 1
|
||||
assert indexed_chunks[0].source_document.id == "doc_1" # quote replaced with _
|
||||
|
||||
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
|
||||
return_value=[],
|
||||
)
|
||||
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
|
||||
def test_index_deduplicates_doc_ids_in_results(
|
||||
mock_enrich: MagicMock,
|
||||
mock_get_chunk_ids: MagicMock, # noqa: ARG001
|
||||
mock_delete: MagicMock, # noqa: ARG001
|
||||
mock_batch_index: MagicMock,
|
||||
) -> None:
|
||||
"""Multiple chunks from the same document produce only one
|
||||
DocumentInsertionRecord."""
|
||||
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
|
||||
|
||||
index = VespaDocumentIndex(
|
||||
index_name="test_index",
|
||||
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=MagicMock(),
|
||||
)
|
||||
|
||||
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(5)]
|
||||
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[5])
|
||||
|
||||
results = index.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].document_id == "doc1"
|
||||
|
||||
# All 5 chunks should be passed to batch_index_vespa_chunks
|
||||
mock_batch_index.assert_called_once()
|
||||
indexed_chunks = mock_batch_index.call_args.kwargs["chunks"]
|
||||
assert len(indexed_chunks) == 5
|
||||
assert all(c.source_document.id == "doc1" for c in indexed_chunks)
|
||||
|
||||
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
|
||||
return_value=[],
|
||||
)
|
||||
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.BATCH_SIZE",
|
||||
3,
|
||||
)
|
||||
def test_index_respects_batch_size(
|
||||
mock_enrich: MagicMock,
|
||||
mock_get_chunk_ids: MagicMock, # noqa: ARG001
|
||||
mock_delete: MagicMock, # noqa: ARG001
|
||||
mock_batch_index: MagicMock,
|
||||
) -> None:
|
||||
"""When chunks exceed BATCH_SIZE, batch_index_vespa_chunks is called
|
||||
multiple times with correctly sized batches."""
|
||||
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
|
||||
|
||||
index = VespaDocumentIndex(
|
||||
index_name="test_index",
|
||||
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=MagicMock(),
|
||||
)
|
||||
|
||||
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(7)]
|
||||
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[7])
|
||||
|
||||
results = index.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
# With BATCH_SIZE=3 and 7 chunks: batches of 3, 3, 1
|
||||
assert mock_batch_index.call_count == 3
|
||||
batch_sizes = [len(c.kwargs["chunks"]) for c in mock_batch_index.call_args_list]
|
||||
assert batch_sizes == [3, 3, 1]
|
||||
|
||||
# Verify all chunks are accounted for and in order
|
||||
all_indexed = [
|
||||
chunk for c in mock_batch_index.call_args_list for chunk in c.kwargs["chunks"]
|
||||
]
|
||||
assert len(all_indexed) == 7
|
||||
assert [c.chunk_id for c in all_indexed] == list(range(7))
|
||||
0
backend/tests/unit/onyx/file_processing/__init__.py
Normal file
0
backend/tests/unit/onyx/file_processing/__init__.py
Normal file
45
backend/tests/unit/onyx/file_processing/fixtures/empty.pdf
Normal file
45
backend/tests/unit/onyx/file_processing/fixtures/empty.pdf
Normal file
@@ -0,0 +1,45 @@
|
||||
%PDF-1.3
|
||||
%<25><><EFBFBD><EFBFBD>
|
||||
1 0 obj
|
||||
<<
|
||||
/Producer (pypdf)
|
||||
>>
|
||||
endobj
|
||||
2 0 obj
|
||||
<<
|
||||
/Type /Pages
|
||||
/Count 1
|
||||
/Kids [ 4 0 R ]
|
||||
>>
|
||||
endobj
|
||||
3 0 obj
|
||||
<<
|
||||
/Type /Catalog
|
||||
/Pages 2 0 R
|
||||
>>
|
||||
endobj
|
||||
4 0 obj
|
||||
<<
|
||||
/Type /Page
|
||||
/Resources <<
|
||||
>>
|
||||
/MediaBox [ 0.0 0.0 200 200 ]
|
||||
/Parent 2 0 R
|
||||
>>
|
||||
endobj
|
||||
xref
|
||||
0 5
|
||||
0000000000 65535 f
|
||||
0000000015 00000 n
|
||||
0000000054 00000 n
|
||||
0000000113 00000 n
|
||||
0000000162 00000 n
|
||||
trailer
|
||||
<<
|
||||
/Size 5
|
||||
/Root 3 0 R
|
||||
/Info 1 0 R
|
||||
>>
|
||||
startxref
|
||||
256
|
||||
%%EOF
|
||||
BIN
backend/tests/unit/onyx/file_processing/fixtures/encrypted.pdf
Normal file
BIN
backend/tests/unit/onyx/file_processing/fixtures/encrypted.pdf
Normal file
Binary file not shown.
@@ -0,0 +1,89 @@
|
||||
%PDF-1.3
|
||||
%<25><><EFBFBD><EFBFBD>
|
||||
1 0 obj
|
||||
<<
|
||||
/Producer (pypdf)
|
||||
>>
|
||||
endobj
|
||||
2 0 obj
|
||||
<<
|
||||
/Type /Pages
|
||||
/Count 2
|
||||
/Kids [ 4 0 R 6 0 R ]
|
||||
>>
|
||||
endobj
|
||||
3 0 obj
|
||||
<<
|
||||
/Type /Catalog
|
||||
/Pages 2 0 R
|
||||
>>
|
||||
endobj
|
||||
4 0 obj
|
||||
<<
|
||||
/Type /Page
|
||||
/Resources <<
|
||||
/Font <<
|
||||
/F1 <<
|
||||
/Type /Font
|
||||
/Subtype /Type1
|
||||
/BaseFont /Helvetica
|
||||
>>
|
||||
>>
|
||||
>>
|
||||
/MediaBox [ 0.0 0.0 200 200 ]
|
||||
/Contents 5 0 R
|
||||
/Parent 2 0 R
|
||||
>>
|
||||
endobj
|
||||
5 0 obj
|
||||
<<
|
||||
/Length 47
|
||||
>>
|
||||
stream
|
||||
BT /F1 12 Tf 50 150 Td (Page one content) Tj ET
|
||||
endstream
|
||||
endobj
|
||||
6 0 obj
|
||||
<<
|
||||
/Type /Page
|
||||
/Resources <<
|
||||
/Font <<
|
||||
/F1 <<
|
||||
/Type /Font
|
||||
/Subtype /Type1
|
||||
/BaseFont /Helvetica
|
||||
>>
|
||||
>>
|
||||
>>
|
||||
/MediaBox [ 0.0 0.0 200 200 ]
|
||||
/Contents 7 0 R
|
||||
/Parent 2 0 R
|
||||
>>
|
||||
endobj
|
||||
7 0 obj
|
||||
<<
|
||||
/Length 47
|
||||
>>
|
||||
stream
|
||||
BT /F1 12 Tf 50 150 Td (Page two content) Tj ET
|
||||
endstream
|
||||
endobj
|
||||
xref
|
||||
0 8
|
||||
0000000000 65535 f
|
||||
0000000015 00000 n
|
||||
0000000054 00000 n
|
||||
0000000119 00000 n
|
||||
0000000168 00000 n
|
||||
0000000349 00000 n
|
||||
0000000446 00000 n
|
||||
0000000627 00000 n
|
||||
trailer
|
||||
<<
|
||||
/Size 8
|
||||
/Root 3 0 R
|
||||
/Info 1 0 R
|
||||
>>
|
||||
startxref
|
||||
724
|
||||
%%EOF
|
||||
62
backend/tests/unit/onyx/file_processing/fixtures/simple.pdf
Normal file
62
backend/tests/unit/onyx/file_processing/fixtures/simple.pdf
Normal file
@@ -0,0 +1,62 @@
|
||||
%PDF-1.3
|
||||
%<25><><EFBFBD><EFBFBD>
|
||||
1 0 obj
|
||||
<<
|
||||
/Producer (pypdf)
|
||||
>>
|
||||
endobj
|
||||
2 0 obj
|
||||
<<
|
||||
/Type /Pages
|
||||
/Count 1
|
||||
/Kids [ 4 0 R ]
|
||||
>>
|
||||
endobj
|
||||
3 0 obj
|
||||
<<
|
||||
/Type /Catalog
|
||||
/Pages 2 0 R
|
||||
>>
|
||||
endobj
|
||||
4 0 obj
|
||||
<<
|
||||
/Type /Page
|
||||
/Resources <<
|
||||
/Font <<
|
||||
/F1 <<
|
||||
/Type /Font
|
||||
/Subtype /Type1
|
||||
/BaseFont /Helvetica
|
||||
>>
|
||||
>>
|
||||
>>
|
||||
/MediaBox [ 0.0 0.0 200 200 ]
|
||||
/Contents 5 0 R
|
||||
/Parent 2 0 R
|
||||
>>
|
||||
endobj
|
||||
5 0 obj
|
||||
<<
|
||||
/Length 42
|
||||
>>
|
||||
stream
|
||||
BT /F1 12 Tf 50 150 Td (Hello World) Tj ET
|
||||
endstream
|
||||
endobj
|
||||
xref
|
||||
0 6
|
||||
0000000000 65535 f
|
||||
0000000015 00000 n
|
||||
0000000054 00000 n
|
||||
0000000113 00000 n
|
||||
0000000162 00000 n
|
||||
0000000343 00000 n
|
||||
trailer
|
||||
<<
|
||||
/Size 6
|
||||
/Root 3 0 R
|
||||
/Info 1 0 R
|
||||
>>
|
||||
startxref
|
||||
435
|
||||
%%EOF
|
||||
BIN
backend/tests/unit/onyx/file_processing/fixtures/with_image.pdf
Normal file
BIN
backend/tests/unit/onyx/file_processing/fixtures/with_image.pdf
Normal file
Binary file not shown.
@@ -0,0 +1,64 @@
|
||||
%PDF-1.3
|
||||
%<25><><EFBFBD><EFBFBD>
|
||||
1 0 obj
|
||||
<<
|
||||
/Producer (pypdf)
|
||||
/Title (My Title)
|
||||
/Author (Jane Doe)
|
||||
>>
|
||||
endobj
|
||||
2 0 obj
|
||||
<<
|
||||
/Type /Pages
|
||||
/Count 1
|
||||
/Kids [ 4 0 R ]
|
||||
>>
|
||||
endobj
|
||||
3 0 obj
|
||||
<<
|
||||
/Type /Catalog
|
||||
/Pages 2 0 R
|
||||
>>
|
||||
endobj
|
||||
4 0 obj
|
||||
<<
|
||||
/Type /Page
|
||||
/Resources <<
|
||||
/Font <<
|
||||
/F1 <<
|
||||
/Type /Font
|
||||
/Subtype /Type1
|
||||
/BaseFont /Helvetica
|
||||
>>
|
||||
>>
|
||||
>>
|
||||
/MediaBox [ 0.0 0.0 200 200 ]
|
||||
/Contents 5 0 R
|
||||
/Parent 2 0 R
|
||||
>>
|
||||
endobj
|
||||
5 0 obj
|
||||
<<
|
||||
/Length 35
|
||||
>>
|
||||
stream
|
||||
BT /F1 12 Tf 50 150 Td (test) Tj ET
|
||||
endstream
|
||||
endobj
|
||||
xref
|
||||
0 6
|
||||
0000000000 65535 f
|
||||
0000000015 00000 n
|
||||
0000000091 00000 n
|
||||
0000000150 00000 n
|
||||
0000000199 00000 n
|
||||
0000000380 00000 n
|
||||
trailer
|
||||
<<
|
||||
/Size 6
|
||||
/Root 3 0 R
|
||||
/Info 1 0 R
|
||||
>>
|
||||
startxref
|
||||
465
|
||||
%%EOF
|
||||
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Unit tests for image summarization error handling.
|
||||
|
||||
Verifies that:
|
||||
1. LLM errors produce actionable error messages (not base64 dumps)
|
||||
2. Unsupported MIME type logs include the magic bytes and size
|
||||
3. The ValueError raised on LLM failure preserves the original exception
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.file_processing.image_summarization import _summarize_image
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
|
||||
|
||||
class TestSummarizeImageErrorMessage:
|
||||
"""_summarize_image must not dump base64 image data into error messages."""
|
||||
|
||||
def test_error_message_contains_exception_type_not_base64(self) -> None:
|
||||
"""The ValueError should contain the original exception info, not message payloads."""
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = RuntimeError("Connection timeout")
|
||||
|
||||
# A fake base64-encoded image string (should NOT appear in the error)
|
||||
fake_encoded = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg..."
|
||||
|
||||
with pytest.raises(ValueError, match="RuntimeError: Connection timeout"):
|
||||
_summarize_image(fake_encoded, mock_llm, query="test")
|
||||
|
||||
def test_error_message_does_not_contain_base64(self) -> None:
|
||||
"""Ensure base64 data is never included in the error message."""
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = RuntimeError("API error")
|
||||
|
||||
fake_encoded = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUA"
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_summarize_image(fake_encoded, mock_llm)
|
||||
|
||||
error_str = str(exc_info.value)
|
||||
assert "base64" not in error_str
|
||||
assert "iVBOR" not in error_str
|
||||
|
||||
def test_original_exception_is_chained(self) -> None:
|
||||
"""The ValueError should chain the original exception via __cause__."""
|
||||
mock_llm = MagicMock()
|
||||
original = RuntimeError("upstream failure")
|
||||
mock_llm.invoke.side_effect = original
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
assert exc_info.value.__cause__ is original
|
||||
|
||||
|
||||
class TestUnsupportedMimeTypeLogging:
|
||||
"""summarize_image_with_error_handling should log useful info for unsupported formats."""
|
||||
|
||||
@patch(
|
||||
"onyx.file_processing.image_summarization.summarize_image_pipeline",
|
||||
side_effect=__import__(
|
||||
"onyx.file_processing.image_summarization",
|
||||
fromlist=["UnsupportedImageFormatError"],
|
||||
).UnsupportedImageFormatError("unsupported"),
|
||||
)
|
||||
def test_logs_magic_bytes_and_size(
|
||||
self, mock_pipeline: MagicMock # noqa: ARG002
|
||||
) -> None:
|
||||
"""The info log should include magic bytes hex and image size."""
|
||||
mock_llm = MagicMock()
|
||||
# TIFF magic bytes (not in the supported list)
|
||||
image_data = b"\x49\x49\x2a\x00" + b"\x00" * 100
|
||||
|
||||
with patch("onyx.file_processing.image_summarization.logger") as mock_logger:
|
||||
result = summarize_image_with_error_handling(
|
||||
llm=mock_llm,
|
||||
image_data=image_data,
|
||||
context_name="test_image.tiff",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
mock_logger.info.assert_called_once()
|
||||
log_args = mock_logger.info.call_args
|
||||
# Check the format string args contain magic bytes and size
|
||||
assert "49492a00" in str(log_args)
|
||||
assert "104" in str(log_args) # 4 + 100 bytes
|
||||
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Unit tests verifying that LiteLLM error details are extracted and surfaced
|
||||
in image summarization error messages.
|
||||
|
||||
When the LLM call fails, the error handler should include the status_code,
|
||||
llm_provider, and model from LiteLLM exceptions so operators can diagnose
|
||||
the root cause (rate limit, content filter, unsupported vision, etc.)
|
||||
without needing to dig through LiteLLM internals.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.file_processing.image_summarization import _summarize_image
|
||||
|
||||
|
||||
def _make_litellm_style_error(
|
||||
*,
|
||||
message: str = "API error",
|
||||
status_code: int | None = None,
|
||||
llm_provider: str | None = None,
|
||||
model: str | None = None,
|
||||
) -> RuntimeError:
|
||||
"""Create an exception with LiteLLM-style attributes."""
|
||||
exc = RuntimeError(message)
|
||||
if status_code is not None:
|
||||
exc.status_code = status_code # type: ignore[attr-defined]
|
||||
if llm_provider is not None:
|
||||
exc.llm_provider = llm_provider # type: ignore[attr-defined]
|
||||
if model is not None:
|
||||
exc.model = model # type: ignore[attr-defined]
|
||||
return exc
|
||||
|
||||
|
||||
class TestLiteLLMErrorExtraction:
|
||||
"""Verify that LiteLLM error attributes are included in the ValueError."""
|
||||
|
||||
def test_status_code_included(self) -> None:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = _make_litellm_style_error(
|
||||
message="Content filter triggered",
|
||||
status_code=400,
|
||||
llm_provider="azure",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="status_code=400"):
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
def test_llm_provider_included(self) -> None:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = _make_litellm_style_error(
|
||||
message="Bad request",
|
||||
status_code=400,
|
||||
llm_provider="azure",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="llm_provider=azure"):
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
def test_model_included(self) -> None:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = _make_litellm_style_error(
|
||||
message="Bad request",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="model=gpt-4o"):
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
def test_all_fields_in_single_message(self) -> None:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = _make_litellm_style_error(
|
||||
message="Rate limit exceeded",
|
||||
status_code=429,
|
||||
llm_provider="azure",
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
msg = str(exc_info.value)
|
||||
assert "status_code=429" in msg
|
||||
assert "llm_provider=azure" in msg
|
||||
assert "model=gpt-4o" in msg
|
||||
assert "Rate limit exceeded" in msg
|
||||
|
||||
def test_plain_exception_without_litellm_attrs(self) -> None:
|
||||
"""Non-LiteLLM exceptions should still produce a useful message."""
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke.side_effect = ConnectionError("Connection refused")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
msg = str(exc_info.value)
|
||||
assert "ConnectionError" in msg
|
||||
assert "Connection refused" in msg
|
||||
# Should not contain status_code/llm_provider/model
|
||||
assert "status_code" not in msg
|
||||
assert "llm_provider" not in msg
|
||||
|
||||
def test_no_base64_in_error(self) -> None:
|
||||
"""Error messages must not contain the full base64 image payload.
|
||||
|
||||
Some LiteLLM exceptions echo the request body (including base64 images)
|
||||
in their message. The truncation guard ensures the bulk of such a
|
||||
payload is stripped from the re-raised ValueError.
|
||||
"""
|
||||
mock_llm = MagicMock()
|
||||
# Build a long base64-like payload that exceeds the 512-char truncation
|
||||
fake_b64_payload = "iVBORw0KGgo" * 100 # ~1100 chars
|
||||
fake_b64 = f"data:image/png;base64,{fake_b64_payload}"
|
||||
|
||||
mock_llm.invoke.side_effect = RuntimeError(
|
||||
f"Request failed for payload: {fake_b64}"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_summarize_image(fake_b64, mock_llm)
|
||||
|
||||
msg = str(exc_info.value)
|
||||
# The full payload must not appear (truncation should have kicked in)
|
||||
assert fake_b64_payload not in msg
|
||||
assert "truncated" in msg
|
||||
|
||||
def test_long_error_message_truncated(self) -> None:
|
||||
"""Exception messages longer than 512 chars are truncated."""
|
||||
mock_llm = MagicMock()
|
||||
long_msg = "x" * 1000
|
||||
mock_llm.invoke.side_effect = RuntimeError(long_msg)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_summarize_image("data:image/png;base64,abc", mock_llm)
|
||||
|
||||
msg = str(exc_info.value)
|
||||
assert "truncated" in msg
|
||||
# The full 1000-char string should not appear
|
||||
assert long_msg not in msg
|
||||
124
backend/tests/unit/onyx/file_processing/test_pdf.py
Normal file
124
backend/tests/unit/onyx/file_processing/test_pdf.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Unit tests for pypdf-dependent PDF processing functions.
|
||||
|
||||
Tests cover:
|
||||
- read_pdf_file: text extraction, metadata, encrypted PDFs, image extraction
|
||||
- pdf_to_text: convenience wrapper
|
||||
- is_pdf_protected: password protection detection
|
||||
|
||||
Fixture PDFs live in ./fixtures/ and are pre-built so the test layer has no
|
||||
dependency on pypdf internals (pypdf.generic).
|
||||
"""
|
||||
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from onyx.file_processing.extract_file_text import pdf_to_text
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.password_validation import is_pdf_protected
|
||||
|
||||
FIXTURES = Path(__file__).parent / "fixtures"
|
||||
|
||||
|
||||
def _load(name: str) -> BytesIO:
|
||||
return BytesIO((FIXTURES / name).read_bytes())
|
||||
|
||||
|
||||
# ── read_pdf_file ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestReadPdfFile:
|
||||
def test_basic_text_extraction(self) -> None:
|
||||
text, _, images = read_pdf_file(_load("simple.pdf"))
|
||||
assert "Hello World" in text
|
||||
assert images == []
|
||||
|
||||
def test_multi_page_text_extraction(self) -> None:
|
||||
text, _, _ = read_pdf_file(_load("multipage.pdf"))
|
||||
assert "Page one content" in text
|
||||
assert "Page two content" in text
|
||||
|
||||
def test_metadata_extraction(self) -> None:
|
||||
_, pdf_metadata, _ = read_pdf_file(_load("with_metadata.pdf"))
|
||||
assert pdf_metadata.get("Title") == "My Title"
|
||||
assert pdf_metadata.get("Author") == "Jane Doe"
|
||||
|
||||
def test_encrypted_pdf_with_correct_password(self) -> None:
|
||||
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="pass123")
|
||||
assert "Secret Content" in text
|
||||
|
||||
def test_encrypted_pdf_without_password(self) -> None:
|
||||
text, _, _ = read_pdf_file(_load("encrypted.pdf"))
|
||||
assert text == ""
|
||||
|
||||
def test_encrypted_pdf_with_wrong_password(self) -> None:
|
||||
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="wrong")
|
||||
assert text == ""
|
||||
|
||||
def test_empty_pdf(self) -> None:
|
||||
text, _, _ = read_pdf_file(_load("empty.pdf"))
|
||||
assert text.strip() == ""
|
||||
|
||||
def test_invalid_pdf_returns_empty(self) -> None:
|
||||
text, _, images = read_pdf_file(BytesIO(b"this is not a pdf"))
|
||||
assert text == ""
|
||||
assert images == []
|
||||
|
||||
def test_image_extraction_disabled_by_default(self) -> None:
|
||||
_, _, images = read_pdf_file(_load("with_image.pdf"))
|
||||
assert images == []
|
||||
|
||||
def test_image_extraction_collects_images(self) -> None:
|
||||
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
|
||||
assert len(images) == 1
|
||||
img_bytes, img_name = images[0]
|
||||
assert len(img_bytes) > 0
|
||||
assert img_name # non-empty name
|
||||
|
||||
def test_image_callback_streams_instead_of_collecting(self) -> None:
|
||||
"""With image_callback, images are streamed via callback and not accumulated."""
|
||||
collected: list[tuple[bytes, str]] = []
|
||||
|
||||
def callback(data: bytes, name: str) -> None:
|
||||
collected.append((data, name))
|
||||
|
||||
_, _, images = read_pdf_file(
|
||||
_load("with_image.pdf"), extract_images=True, image_callback=callback
|
||||
)
|
||||
# Callback received the image
|
||||
assert len(collected) == 1
|
||||
assert len(collected[0][0]) > 0
|
||||
# Returned list is empty when callback is used
|
||||
assert images == []
|
||||
|
||||
|
||||
# ── pdf_to_text ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPdfToText:
|
||||
def test_returns_text(self) -> None:
|
||||
assert "Hello World" in pdf_to_text(_load("simple.pdf"))
|
||||
|
||||
def test_with_password(self) -> None:
|
||||
assert "Secret Content" in pdf_to_text(
|
||||
_load("encrypted.pdf"), pdf_pass="pass123"
|
||||
)
|
||||
|
||||
def test_encrypted_without_password_returns_empty(self) -> None:
|
||||
assert pdf_to_text(_load("encrypted.pdf")) == ""
|
||||
|
||||
|
||||
# ── is_pdf_protected ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIsPdfProtected:
|
||||
def test_unprotected_pdf(self) -> None:
|
||||
assert is_pdf_protected(_load("simple.pdf")) is False
|
||||
|
||||
def test_protected_pdf(self) -> None:
|
||||
assert is_pdf_protected(_load("encrypted.pdf")) is True
|
||||
|
||||
def test_preserves_file_position(self) -> None:
|
||||
pdf = _load("simple.pdf")
|
||||
pdf.seek(42)
|
||||
is_pdf_protected(pdf)
|
||||
assert pdf.tell() == 42
|
||||
22
backend/tests/unit/onyx/hooks/test_base_spec.py
Normal file
22
backend/tests/unit/onyx/hooks/test_base_spec.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
def test_init_subclass_raises_for_missing_attrs() -> None:
|
||||
with pytest.raises(TypeError, match="must define class attributes"):
|
||||
|
||||
class IncompleteSpec(HookPointSpec):
|
||||
hook_point = HookPoint.QUERY_PROCESSING
|
||||
# missing display_name, description, etc.
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {}
|
||||
86
backend/tests/unit/onyx/hooks/test_models.py
Normal file
86
backend/tests/unit/onyx/hooks/test_models.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.models import HookCreateRequest
|
||||
from onyx.hooks.models import HookUpdateRequest
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_empty() -> None:
|
||||
# No fields supplied at all
|
||||
with pytest.raises(ValidationError, match="At least one field must be provided"):
|
||||
HookUpdateRequest()
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_null_name_when_only_field() -> None:
|
||||
# Explicitly setting name=None is rejected as name cannot be cleared
|
||||
with pytest.raises(ValidationError, match="name cannot be cleared"):
|
||||
HookUpdateRequest(name=None)
|
||||
|
||||
|
||||
def test_hook_update_request_accepts_single_field() -> None:
|
||||
req = HookUpdateRequest(name="new name")
|
||||
assert req.name == "new name"
|
||||
|
||||
|
||||
def test_hook_update_request_accepts_partial_fields() -> None:
|
||||
req = HookUpdateRequest(fail_strategy=HookFailStrategy.SOFT, timeout_seconds=10.0)
|
||||
assert req.fail_strategy == HookFailStrategy.SOFT
|
||||
assert req.timeout_seconds == 10.0
|
||||
assert req.name is None
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_null_name() -> None:
|
||||
with pytest.raises(ValidationError, match="name cannot be cleared"):
|
||||
HookUpdateRequest(name=None, fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_empty_name() -> None:
|
||||
with pytest.raises(ValidationError, match="name cannot be cleared"):
|
||||
HookUpdateRequest(name="", fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_null_endpoint_url() -> None:
|
||||
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
|
||||
HookUpdateRequest(endpoint_url=None, fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_empty_endpoint_url() -> None:
|
||||
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
|
||||
HookUpdateRequest(endpoint_url="", fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
|
||||
def test_hook_update_request_allows_null_api_key() -> None:
|
||||
# api_key=null is valid — means "clear the api key"
|
||||
req = HookUpdateRequest(api_key=None)
|
||||
assert req.api_key is None
|
||||
assert "api_key" in req.model_fields_set
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_whitespace_name() -> None:
|
||||
with pytest.raises(ValidationError, match="name cannot be cleared"):
|
||||
HookUpdateRequest(name=" ", fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_whitespace_endpoint_url() -> None:
|
||||
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
|
||||
HookUpdateRequest(endpoint_url=" ", fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
|
||||
def test_hook_create_request_rejects_whitespace_name() -> None:
|
||||
with pytest.raises(ValidationError, match="whitespace-only"):
|
||||
HookCreateRequest(
|
||||
name=" ",
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
endpoint_url="https://example.com/hook",
|
||||
)
|
||||
|
||||
|
||||
def test_hook_create_request_rejects_whitespace_endpoint_url() -> None:
|
||||
with pytest.raises(ValidationError, match="whitespace-only"):
|
||||
HookCreateRequest(
|
||||
name="my hook",
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
endpoint_url=" ",
|
||||
)
|
||||
60
backend/tests/unit/onyx/hooks/test_query_processing_spec.py
Normal file
60
backend/tests/unit/onyx/hooks/test_query_processing_spec.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.query_processing import QueryProcessingSpec
|
||||
|
||||
|
||||
def test_hook_point_is_query_processing() -> None:
|
||||
assert QueryProcessingSpec().hook_point == HookPoint.QUERY_PROCESSING
|
||||
|
||||
|
||||
def test_default_fail_strategy_is_hard() -> None:
|
||||
assert QueryProcessingSpec().default_fail_strategy == HookFailStrategy.HARD
|
||||
|
||||
|
||||
def test_default_timeout_seconds() -> None:
|
||||
# User is actively waiting — 5s is the documented contract for this hook point
|
||||
assert QueryProcessingSpec().default_timeout_seconds == 5.0
|
||||
|
||||
|
||||
def test_input_schema_required_fields() -> None:
|
||||
schema = QueryProcessingSpec().input_schema
|
||||
assert schema["type"] == "object"
|
||||
required = schema["required"]
|
||||
assert "query" in required
|
||||
assert "user_email" in required
|
||||
assert "chat_session_id" in required
|
||||
|
||||
|
||||
def test_input_schema_chat_session_id_is_string() -> None:
|
||||
props = QueryProcessingSpec().input_schema["properties"]
|
||||
assert props["chat_session_id"]["type"] == "string"
|
||||
|
||||
|
||||
def test_input_schema_query_is_string() -> None:
|
||||
props = QueryProcessingSpec().input_schema["properties"]
|
||||
assert props["query"]["type"] == "string"
|
||||
|
||||
|
||||
def test_input_schema_user_email_is_nullable() -> None:
|
||||
props = QueryProcessingSpec().input_schema["properties"]
|
||||
assert "null" in props["user_email"]["type"]
|
||||
|
||||
|
||||
def test_output_schema_query_is_required() -> None:
|
||||
schema = QueryProcessingSpec().output_schema
|
||||
assert "query" in schema["required"]
|
||||
|
||||
|
||||
def test_output_schema_query_is_nullable() -> None:
|
||||
# null means "reject the query"
|
||||
props = QueryProcessingSpec().output_schema["properties"]
|
||||
assert "null" in props["query"]["type"]
|
||||
|
||||
|
||||
def test_output_schema_rejection_message_is_optional() -> None:
|
||||
schema = QueryProcessingSpec().output_schema
|
||||
assert "rejection_message" not in schema.get("required", [])
|
||||
|
||||
|
||||
def test_input_schema_no_additional_properties() -> None:
|
||||
assert QueryProcessingSpec().input_schema.get("additionalProperties") is False
|
||||
47
backend/tests/unit/onyx/hooks/test_registry.py
Normal file
47
backend/tests/unit/onyx/hooks/test_registry.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks import registry as registry_module
|
||||
from onyx.hooks.registry import get_all_specs
|
||||
from onyx.hooks.registry import get_hook_point_spec
|
||||
from onyx.hooks.registry import validate_registry
|
||||
|
||||
|
||||
def test_registry_covers_all_hook_points() -> None:
|
||||
"""Every HookPoint enum member must have a registered spec."""
|
||||
assert {s.hook_point for s in get_all_specs()} == set(
|
||||
HookPoint
|
||||
), f"Missing specs for: {set(HookPoint) - {s.hook_point for s in get_all_specs()}}"
|
||||
|
||||
|
||||
def test_get_hook_point_spec_returns_correct_spec() -> None:
|
||||
for hook_point in HookPoint:
|
||||
spec = get_hook_point_spec(hook_point)
|
||||
assert spec.hook_point == hook_point
|
||||
|
||||
|
||||
def test_get_all_specs_returns_all() -> None:
|
||||
specs = get_all_specs()
|
||||
assert len(specs) == len(HookPoint)
|
||||
assert {s.hook_point for s in specs} == set(HookPoint)
|
||||
|
||||
|
||||
def test_get_hook_point_spec_raises_for_unregistered(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""get_hook_point_spec raises ValueError when a hook point has no spec."""
|
||||
monkeypatch.setattr(registry_module, "_REGISTRY", {})
|
||||
with pytest.raises(ValueError, match="No spec registered for hook point"):
|
||||
get_hook_point_spec(HookPoint.QUERY_PROCESSING)
|
||||
|
||||
|
||||
def test_validate_registry_passes() -> None:
|
||||
validate_registry() # should not raise with the real registry
|
||||
|
||||
|
||||
def test_validate_registry_raises_for_incomplete(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(registry_module, "_REGISTRY", {})
|
||||
with pytest.raises(RuntimeError, match="Hook point\\(s\\) have no registered spec"):
|
||||
validate_registry()
|
||||
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
Unit tests for vision model selection logging in get_default_llm_with_vision.
|
||||
|
||||
Verifies that operators get clear feedback about:
|
||||
1. Which vision model was selected and why
|
||||
2. When the default vision model doesn't support image input
|
||||
3. When no vision-capable model exists at all
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
|
||||
|
||||
_FACTORY = "onyx.llm.factory"
|
||||
|
||||
|
||||
def _make_mock_model(
|
||||
*,
|
||||
name: str = "gpt-4o",
|
||||
provider: str = "openai",
|
||||
provider_id: int = 1,
|
||||
flow_types: list[str] | None = None,
|
||||
) -> MagicMock:
|
||||
model = MagicMock()
|
||||
model.name = name
|
||||
model.llm_provider_id = provider_id
|
||||
model.llm_provider.provider = provider
|
||||
model.llm_model_flow_types = flow_types or []
|
||||
return model
|
||||
|
||||
|
||||
@patch(f"{_FACTORY}.get_session_with_current_tenant")
|
||||
@patch(f"{_FACTORY}.fetch_default_vision_model")
|
||||
@patch(f"{_FACTORY}.model_supports_image_input", return_value=True)
|
||||
@patch(f"{_FACTORY}.llm_from_provider")
|
||||
@patch(f"{_FACTORY}.LLMProviderView")
|
||||
@patch(f"{_FACTORY}.logger")
|
||||
def test_logs_when_using_default_vision_model(
|
||||
mock_logger: MagicMock,
|
||||
mock_provider_view: MagicMock, # noqa: ARG001
|
||||
mock_llm_from: MagicMock, # noqa: ARG001
|
||||
mock_supports: MagicMock, # noqa: ARG001
|
||||
mock_fetch_default: MagicMock,
|
||||
mock_session: MagicMock, # noqa: ARG001
|
||||
) -> None:
|
||||
mock_fetch_default.return_value = _make_mock_model(name="gpt-4o", provider="azure")
|
||||
|
||||
get_default_llm_with_vision()
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
log_msg = mock_logger.info.call_args[0][0]
|
||||
assert "default vision model" in log_msg.lower()
|
||||
|
||||
|
||||
@patch(f"{_FACTORY}.get_session_with_current_tenant")
|
||||
@patch(f"{_FACTORY}.fetch_default_vision_model")
|
||||
@patch(f"{_FACTORY}.model_supports_image_input", return_value=False)
|
||||
@patch(f"{_FACTORY}.fetch_existing_models", return_value=[])
|
||||
@patch(f"{_FACTORY}.logger")
|
||||
def test_warns_when_default_model_lacks_vision(
|
||||
mock_logger: MagicMock,
|
||||
mock_fetch_models: MagicMock, # noqa: ARG001
|
||||
mock_supports: MagicMock, # noqa: ARG001
|
||||
mock_fetch_default: MagicMock,
|
||||
mock_session: MagicMock, # noqa: ARG001
|
||||
) -> None:
|
||||
mock_fetch_default.return_value = _make_mock_model(
|
||||
name="text-only-model", provider="azure"
|
||||
)
|
||||
|
||||
result = get_default_llm_with_vision()
|
||||
|
||||
assert result is None
|
||||
# Should have warned about the default model not supporting vision
|
||||
warning_calls = [
|
||||
call
|
||||
for call in mock_logger.warning.call_args_list
|
||||
if "does not support" in str(call)
|
||||
]
|
||||
assert len(warning_calls) >= 1
|
||||
|
||||
|
||||
@patch(f"{_FACTORY}.get_session_with_current_tenant")
|
||||
@patch(f"{_FACTORY}.fetch_default_vision_model", return_value=None)
|
||||
@patch(f"{_FACTORY}.fetch_existing_models", return_value=[])
|
||||
@patch(f"{_FACTORY}.logger")
|
||||
def test_warns_when_no_models_exist(
|
||||
mock_logger: MagicMock,
|
||||
mock_fetch_models: MagicMock, # noqa: ARG001
|
||||
mock_fetch_default: MagicMock, # noqa: ARG001
|
||||
mock_session: MagicMock, # noqa: ARG001
|
||||
) -> None:
|
||||
result = get_default_llm_with_vision()
|
||||
|
||||
assert result is None
|
||||
mock_logger.warning.assert_called_once()
|
||||
log_msg = mock_logger.warning.call_args[0][0]
|
||||
assert "no llm models" in log_msg.lower()
|
||||
|
||||
|
||||
@patch(f"{_FACTORY}.get_session_with_current_tenant")
|
||||
@patch(f"{_FACTORY}.fetch_default_vision_model", return_value=None)
|
||||
@patch(f"{_FACTORY}.fetch_existing_models")
|
||||
@patch(f"{_FACTORY}.model_supports_image_input", return_value=False)
|
||||
@patch(f"{_FACTORY}.LLMProviderView")
|
||||
@patch(f"{_FACTORY}.logger")
|
||||
def test_warns_when_no_model_supports_vision(
|
||||
mock_logger: MagicMock,
|
||||
mock_provider_view: MagicMock, # noqa: ARG001
|
||||
mock_supports: MagicMock, # noqa: ARG001
|
||||
mock_fetch_models: MagicMock,
|
||||
mock_fetch_default: MagicMock, # noqa: ARG001
|
||||
mock_session: MagicMock, # noqa: ARG001
|
||||
) -> None:
|
||||
mock_fetch_models.return_value = [
|
||||
_make_mock_model(name="text-model-1", provider="openai"),
|
||||
_make_mock_model(name="text-model-2", provider="azure", provider_id=2),
|
||||
]
|
||||
|
||||
result = get_default_llm_with_vision()
|
||||
|
||||
assert result is None
|
||||
warning_calls = [
|
||||
call
|
||||
for call in mock_logger.warning.call_args_list
|
||||
if "no vision-capable model" in str(call).lower()
|
||||
]
|
||||
assert len(warning_calls) == 1
|
||||
109
backend/tests/unit/onyx/server/test_upload_files.py
Normal file
109
backend/tests/unit/onyx/server/test_upload_files.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import io
|
||||
import zipfile
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from zipfile import BadZipFile
|
||||
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.server.documents.connector import upload_files
|
||||
|
||||
|
||||
def _create_test_zip() -> bytes:
|
||||
"""Create a simple in-memory zip file containing two text files."""
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w") as zf:
|
||||
zf.writestr("file1.txt", "hello")
|
||||
zf.writestr("file2.txt", "world")
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def _make_upload_file(content: bytes, filename: str, content_type: str) -> UploadFile:
|
||||
return UploadFile(
|
||||
file=io.BytesIO(content),
|
||||
filename=filename,
|
||||
headers=Headers({"content-type": content_type}),
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.server.documents.connector.get_default_file_store")
|
||||
def test_upload_zip_with_unzip_true_extracts_files(
|
||||
mock_get_store: MagicMock,
|
||||
) -> None:
|
||||
"""When unzip=True (default), a zip upload is extracted into individual files."""
|
||||
mock_store = MagicMock()
|
||||
mock_store.save_file.side_effect = lambda **kwargs: f"id-{kwargs['display_name']}"
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
zip_bytes = _create_test_zip()
|
||||
upload = _make_upload_file(zip_bytes, "test.zip", "application/zip")
|
||||
|
||||
result = upload_files([upload], FileOrigin.CONNECTOR)
|
||||
|
||||
# Should have extracted the two individual files, not stored the zip itself
|
||||
assert len(result.file_paths) == 2
|
||||
assert "id-file1.txt" in result.file_paths
|
||||
assert "id-file2.txt" in result.file_paths
|
||||
assert "file1.txt" in result.file_names
|
||||
assert "file2.txt" in result.file_names
|
||||
|
||||
|
||||
@patch("onyx.server.documents.connector.get_default_file_store")
|
||||
def test_upload_zip_with_unzip_false_stores_zip_as_is(
|
||||
mock_get_store: MagicMock,
|
||||
) -> None:
|
||||
"""When unzip=False, the zip file is stored as-is without extraction."""
|
||||
mock_store = MagicMock()
|
||||
mock_store.save_file.return_value = "zip-file-id"
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
zip_bytes = _create_test_zip()
|
||||
upload = _make_upload_file(zip_bytes, "site_export.zip", "application/zip")
|
||||
|
||||
result = upload_files([upload], FileOrigin.CONNECTOR, unzip=False)
|
||||
|
||||
# Should store exactly one file (the zip itself)
|
||||
assert len(result.file_paths) == 1
|
||||
assert result.file_paths[0] == "zip-file-id"
|
||||
assert result.file_names == ["site_export.zip"]
|
||||
# No zip metadata should be created
|
||||
assert result.zip_metadata_file_id is None
|
||||
|
||||
# Verify the stored content is a valid zip
|
||||
saved_content: io.BytesIO = mock_store.save_file.call_args[1]["content"]
|
||||
saved_content.seek(0)
|
||||
with zipfile.ZipFile(saved_content, "r") as zf:
|
||||
assert set(zf.namelist()) == {"file1.txt", "file2.txt"}
|
||||
|
||||
|
||||
@patch("onyx.server.documents.connector.get_default_file_store")
|
||||
def test_upload_invalid_zip_with_unzip_false_raises(
|
||||
mock_get_store: MagicMock,
|
||||
) -> None:
|
||||
"""An invalid zip is rejected even when unzip=False (validation still runs)."""
|
||||
mock_get_store.return_value = MagicMock()
|
||||
|
||||
bad_zip = _make_upload_file(b"not a zip", "bad.zip", "application/zip")
|
||||
|
||||
with pytest.raises(BadZipFile):
|
||||
upload_files([bad_zip], FileOrigin.CONNECTOR, unzip=False)
|
||||
|
||||
|
||||
@patch("onyx.server.documents.connector.get_default_file_store")
|
||||
def test_upload_multiple_zips_rejected_when_unzip_false(
|
||||
mock_get_store: MagicMock,
|
||||
) -> None:
|
||||
"""The seen_zip guard rejects a second zip even when unzip=False."""
|
||||
mock_store = MagicMock()
|
||||
mock_store.save_file.return_value = "zip-id"
|
||||
mock_get_store.return_value = mock_store
|
||||
|
||||
zip_bytes = _create_test_zip()
|
||||
zip1 = _make_upload_file(zip_bytes, "a.zip", "application/zip")
|
||||
zip2 = _make_upload_file(zip_bytes, "b.zip", "application/zip")
|
||||
|
||||
with pytest.raises(Exception, match="Only one zip file"):
|
||||
upload_files([zip1, zip2], FileOrigin.CONNECTOR, unzip=False)
|
||||
93
examples/widget/package-lock.json
generated
93
examples/widget/package-lock.json
generated
@@ -8,7 +8,7 @@
|
||||
"name": "widget",
|
||||
"version": "0.1.0",
|
||||
"dependencies": {
|
||||
"next": "^16.1.5",
|
||||
"next": "^16.1.7",
|
||||
"react": "^19",
|
||||
"react-dom": "^19",
|
||||
"react-markdown": "^10.1.0"
|
||||
@@ -1023,9 +1023,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/env": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.5.tgz",
|
||||
"integrity": "sha512-CRSCPJiSZoi4Pn69RYBDI9R7YK2g59vLexPQFXY0eyw+ILevIenCywzg+DqmlBik9zszEnw2HLFOUlLAcJbL7g==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.7.tgz",
|
||||
"integrity": "sha512-rJJbIdJB/RQr2F1nylZr/PJzamvNNhfr3brdKP6s/GW850jbtR70QlSfFselvIBbcPUOlQwBakexjFzqLzF6pg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@next/eslint-plugin-next": {
|
||||
@@ -1039,9 +1039,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-arm64": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.5.tgz",
|
||||
"integrity": "sha512-eK7Wdm3Hjy/SCL7TevlH0C9chrpeOYWx2iR7guJDaz4zEQKWcS1IMVfMb9UKBFMg1XgzcPTYPIp1Vcpukkjg6Q==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.7.tgz",
|
||||
"integrity": "sha512-b2wWIE8sABdyafc4IM8r5Y/dS6kD80JRtOGrUiKTsACFQfWWgUQ2NwoUX1yjFMXVsAwcQeNpnucF2ZrujsBBPg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1055,9 +1055,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-x64": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.5.tgz",
|
||||
"integrity": "sha512-foQscSHD1dCuxBmGkbIr6ScAUF6pRoDZP6czajyvmXPAOFNnQUJu2Os1SGELODjKp/ULa4fulnBWoHV3XdPLfA==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.7.tgz",
|
||||
"integrity": "sha512-zcnVaaZulS1WL0Ss38R5Q6D2gz7MtBu8GZLPfK+73D/hp4GFMrC2sudLky1QibfV7h6RJBJs/gOFvYP0X7UVlQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1071,9 +1071,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-gnu": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.5.tgz",
|
||||
"integrity": "sha512-qNIb42o3C02ccIeSeKjacF3HXotGsxh/FMk/rSRmCzOVMtoWH88odn2uZqF8RLsSUWHcAqTgYmPD3pZ03L9ZAA==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.7.tgz",
|
||||
"integrity": "sha512-2ant89Lux/Q3VyC8vNVg7uBaFVP9SwoK2jJOOR0L8TQnX8CAYnh4uctAScy2Hwj2dgjVHqHLORQZJ2wH6VxhSQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1087,9 +1087,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-musl": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.5.tgz",
|
||||
"integrity": "sha512-U+kBxGUY1xMAzDTXmuVMfhaWUZQAwzRaHJ/I6ihtR5SbTVUEaDRiEU9YMjy1obBWpdOBuk1bcm+tsmifYSygfw==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.7.tgz",
|
||||
"integrity": "sha512-uufcze7LYv0FQg9GnNeZ3/whYfo+1Q3HnQpm16o6Uyi0OVzLlk2ZWoY7j07KADZFY8qwDbsmFnMQP3p3+Ftprw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1103,9 +1103,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-gnu": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.5.tgz",
|
||||
"integrity": "sha512-gq2UtoCpN7Ke/7tKaU7i/1L7eFLfhMbXjNghSv0MVGF1dmuoaPeEVDvkDuO/9LVa44h5gqpWeJ4mRRznjDv7LA==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.7.tgz",
|
||||
"integrity": "sha512-KWVf2gxYvHtvuT+c4MBOGxuse5TD7DsMFYSxVxRBnOzok/xryNeQSjXgxSv9QpIVlaGzEn/pIuI6Koosx8CGWA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1119,9 +1119,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-musl": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.5.tgz",
|
||||
"integrity": "sha512-bQWSE729PbXT6mMklWLf8dotislPle2L70E9q6iwETYEOt092GDn0c+TTNj26AjmeceSsC4ndyGsK5nKqHYXjQ==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.7.tgz",
|
||||
"integrity": "sha512-HguhaGwsGr1YAGs68uRKc4aGWxLET+NevJskOcCAwXbwj0fYX0RgZW2gsOCzr9S11CSQPIkxmoSbuVaBp4Z3dA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1135,9 +1135,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-arm64-msvc": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.5.tgz",
|
||||
"integrity": "sha512-LZli0anutkIllMtTAWZlDqdfvjWX/ch8AFK5WgkNTvaqwlouiD1oHM+WW8RXMiL0+vAkAJyAGEzPPjO+hnrSNQ==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.7.tgz",
|
||||
"integrity": "sha512-S0n3KrDJokKTeFyM/vGGGR8+pCmXYrjNTk2ZozOL1C/JFdfUIL9O1ATaJOl5r2POe56iRChbsszrjMAdWSv7kQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1151,9 +1151,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-x64-msvc": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.5.tgz",
|
||||
"integrity": "sha512-7is37HJTNQGhjPpQbkKjKEboHYQnCgpVt/4rBrrln0D9nderNxZ8ZWs8w1fAtzUx7wEyYjQ+/13myFgFj6K2Ng==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.7.tgz",
|
||||
"integrity": "sha512-mwgtg8CNZGYm06LeEd+bNnOUfwOyNem/rOiP14Lsz+AnUY92Zq/LXwtebtUiaeVkhbroRCQ0c8GlR4UT1U+0yg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -2564,12 +2564,15 @@
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/baseline-browser-mapping": {
|
||||
"version": "2.9.14",
|
||||
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.14.tgz",
|
||||
"integrity": "sha512-B0xUquLkiGLgHhpPBqvl7GWegWBUNuujQ6kXd/r1U38ElPT6Ok8KZ8e+FpUGEc2ZoRQUzq/aUnaKFc/svWUGSg==",
|
||||
"version": "2.10.8",
|
||||
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.8.tgz",
|
||||
"integrity": "sha512-PCLz/LXGBsNTErbtB6i5u4eLpHeMfi93aUv5duMmj6caNu6IphS4q6UevDnL36sZQv9lrP11dbPKGMaXPwMKfQ==",
|
||||
"license": "Apache-2.0",
|
||||
"bin": {
|
||||
"baseline-browser-mapping": "dist/cli.js"
|
||||
"baseline-browser-mapping": "dist/cli.cjs"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=6.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/brace-expansion": {
|
||||
@@ -5926,14 +5929,14 @@
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/next": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-16.1.5.tgz",
|
||||
"integrity": "sha512-f+wE+NSbiQgh3DSAlTaw2FwY5yGdVViAtp8TotNQj4kk4Q8Bh1sC/aL9aH+Rg1YAVn18OYXsRDT7U/079jgP7w==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-16.1.7.tgz",
|
||||
"integrity": "sha512-WM0L7WrSvKwoLegLYr6V+mz+RIofqQgVAfHhMp9a88ms0cFX8iX9ew+snpWlSBwpkURJOUdvCEt3uLl3NNzvWg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@next/env": "16.1.5",
|
||||
"@next/env": "16.1.7",
|
||||
"@swc/helpers": "0.5.15",
|
||||
"baseline-browser-mapping": "^2.8.3",
|
||||
"baseline-browser-mapping": "^2.9.19",
|
||||
"caniuse-lite": "^1.0.30001579",
|
||||
"postcss": "8.4.31",
|
||||
"styled-jsx": "5.1.6"
|
||||
@@ -5945,14 +5948,14 @@
|
||||
"node": ">=20.9.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@next/swc-darwin-arm64": "16.1.5",
|
||||
"@next/swc-darwin-x64": "16.1.5",
|
||||
"@next/swc-linux-arm64-gnu": "16.1.5",
|
||||
"@next/swc-linux-arm64-musl": "16.1.5",
|
||||
"@next/swc-linux-x64-gnu": "16.1.5",
|
||||
"@next/swc-linux-x64-musl": "16.1.5",
|
||||
"@next/swc-win32-arm64-msvc": "16.1.5",
|
||||
"@next/swc-win32-x64-msvc": "16.1.5",
|
||||
"@next/swc-darwin-arm64": "16.1.7",
|
||||
"@next/swc-darwin-x64": "16.1.7",
|
||||
"@next/swc-linux-arm64-gnu": "16.1.7",
|
||||
"@next/swc-linux-arm64-musl": "16.1.7",
|
||||
"@next/swc-linux-x64-gnu": "16.1.7",
|
||||
"@next/swc-linux-x64-musl": "16.1.7",
|
||||
"@next/swc-win32-arm64-msvc": "16.1.7",
|
||||
"@next/swc-win32-x64-msvc": "16.1.7",
|
||||
"sharp": "^0.34.4"
|
||||
},
|
||||
"peerDependencies": {
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
"lint": "next lint"
|
||||
},
|
||||
"dependencies": {
|
||||
"next": "^16.1.5",
|
||||
"next": "^16.1.7",
|
||||
"react": "^19",
|
||||
"react-dom": "^19",
|
||||
"react-markdown": "^10.1.0"
|
||||
|
||||
@@ -92,7 +92,7 @@ backend = [
|
||||
"python-gitlab==5.6.0",
|
||||
"python-pptx==0.6.23",
|
||||
"pypandoc_binary==1.16.2",
|
||||
"pypdf==6.8.0",
|
||||
"pypdf==6.9.1",
|
||||
"pytest-mock==3.12.0",
|
||||
"pytest-playwright==0.7.0",
|
||||
"python-docx==1.1.2",
|
||||
@@ -245,6 +245,7 @@ select = [
|
||||
"ARG",
|
||||
"E",
|
||||
"F",
|
||||
"S324",
|
||||
"W",
|
||||
]
|
||||
|
||||
|
||||
8
uv.lock
generated
8
uv.lock
generated
@@ -4481,7 +4481,7 @@ requires-dist = [
|
||||
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
|
||||
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
|
||||
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.8.0" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.9.1" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
|
||||
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
|
||||
@@ -5727,11 +5727,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.8.0"
|
||||
version = "6.9.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b4/a3/e705b0805212b663a4c27b861c8a603dba0f8b4bb281f96f8e746576a50d/pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b", size = 5307831, upload-time = "2026-03-09T13:37:40.591Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/ec/4ccf3bb86b1afe5d7176e1c8abcdbf22b53dd682ec2eda50e1caadcf6846/pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7", size = 332177, upload-time = "2026-03-09T13:37:38.774Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import "@opal/components/buttons/button/styles.css";
|
||||
import "@opal/components/tooltip.css";
|
||||
import { Interactive, type InteractiveStatelessProps } from "@opal/core";
|
||||
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
|
||||
@@ -67,7 +66,7 @@ function Button({
|
||||
const labelEl = children ? (
|
||||
<span
|
||||
className={cn(
|
||||
"opal-button-label",
|
||||
"whitespace-nowrap",
|
||||
isLarge ? "font-main-ui-body " : "font-secondary-body",
|
||||
responsiveHideText && "hidden md:inline"
|
||||
)}
|
||||
@@ -87,7 +86,7 @@ function Button({
|
||||
isLarge ? "default" : size === "2xs" ? "mini" : "compact"
|
||||
}
|
||||
>
|
||||
<div className={cn("opal-button interactive-foreground")}>
|
||||
<div className="flex flex-row items-center gap-1 interactive-foreground">
|
||||
{iconWrapper(Icon, size, !!children)}
|
||||
|
||||
{labelEl}
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
/* Button — layout only; colors handled by Interactive.Stateless */
|
||||
|
||||
.opal-button {
|
||||
@apply flex flex-row items-center gap-1;
|
||||
}
|
||||
|
||||
.opal-button-label {
|
||||
@apply whitespace-nowrap;
|
||||
}
|
||||
@@ -118,7 +118,7 @@ function OpenButton({
|
||||
const labelEl = children ? (
|
||||
<span
|
||||
className={cn(
|
||||
"opal-button-label whitespace-nowrap",
|
||||
"whitespace-nowrap",
|
||||
isLarge ? "font-main-ui-body" : "font-secondary-body"
|
||||
)}
|
||||
>
|
||||
@@ -143,7 +143,7 @@ function OpenButton({
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"opal-button interactive-foreground flex flex-row items-center",
|
||||
"interactive-foreground flex flex-row items-center",
|
||||
justifyContent === "between" ? "w-full justify-between" : "gap-1",
|
||||
foldable &&
|
||||
justifyContent !== "between" &&
|
||||
|
||||
@@ -174,6 +174,7 @@ function ContentLg({
|
||||
)}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
style={{ height: config.lineHeight }}
|
||||
title={title}
|
||||
>
|
||||
{title}
|
||||
</span>
|
||||
|
||||
@@ -218,6 +218,7 @@ function ContentMd({
|
||||
"text-text-04",
|
||||
editable && "cursor-pointer"
|
||||
)}
|
||||
title={title}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
style={{ height: config.lineHeight }}
|
||||
>
|
||||
|
||||
@@ -118,6 +118,7 @@ function ContentSm({
|
||||
<span
|
||||
className={cn("opal-content-sm-title", config.titleFont)}
|
||||
style={{ height: config.lineHeight }}
|
||||
title={title}
|
||||
>
|
||||
{title}
|
||||
</span>
|
||||
|
||||
@@ -231,6 +231,7 @@ function ContentXl({
|
||||
)}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
style={{ height: config.lineHeight }}
|
||||
title={title}
|
||||
>
|
||||
{title}
|
||||
</span>
|
||||
|
||||
82
web/package-lock.json
generated
82
web/package-lock.json
generated
@@ -61,7 +61,7 @@
|
||||
"mdast-util-find-and-replace": "^3.0.1",
|
||||
"mime": "^4.1.0",
|
||||
"motion": "^12.29.0",
|
||||
"next": "16.1.6",
|
||||
"next": "16.1.7",
|
||||
"next-themes": "^0.4.4",
|
||||
"postcss": "^8.5.6",
|
||||
"posthog-js": "^1.176.0",
|
||||
@@ -2896,9 +2896,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/env": {
|
||||
"version": "16.1.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.6.tgz",
|
||||
"integrity": "sha512-N1ySLuZjnAtN3kFnwhAwPvZah8RJxKasD7x1f8shFqhncnWZn4JMfg37diLNuoHsLAlrDfM3g4mawVdtAG8XLQ==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.7.tgz",
|
||||
"integrity": "sha512-rJJbIdJB/RQr2F1nylZr/PJzamvNNhfr3brdKP6s/GW850jbtR70QlSfFselvIBbcPUOlQwBakexjFzqLzF6pg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@next/eslint-plugin-next": {
|
||||
@@ -2942,9 +2942,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-arm64": {
|
||||
"version": "16.1.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.6.tgz",
|
||||
"integrity": "sha512-wTzYulosJr/6nFnqGW7FrG3jfUUlEf8UjGA0/pyypJl42ExdVgC6xJgcXQ+V8QFn6niSG2Pb8+MIG1mZr2vczw==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.7.tgz",
|
||||
"integrity": "sha512-b2wWIE8sABdyafc4IM8r5Y/dS6kD80JRtOGrUiKTsACFQfWWgUQ2NwoUX1yjFMXVsAwcQeNpnucF2ZrujsBBPg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -2958,9 +2958,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-x64": {
|
||||
"version": "16.1.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.6.tgz",
|
||||
"integrity": "sha512-BLFPYPDO+MNJsiDWbeVzqvYd4NyuRrEYVB5k2N3JfWncuHAy2IVwMAOlVQDFjj+krkWzhY2apvmekMkfQR0CUQ==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.7.tgz",
|
||||
"integrity": "sha512-zcnVaaZulS1WL0Ss38R5Q6D2gz7MtBu8GZLPfK+73D/hp4GFMrC2sudLky1QibfV7h6RJBJs/gOFvYP0X7UVlQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -2974,9 +2974,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-gnu": {
|
||||
"version": "16.1.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.6.tgz",
|
||||
"integrity": "sha512-OJYkCd5pj/QloBvoEcJ2XiMnlJkRv9idWA/j0ugSuA34gMT6f5b7vOiCQHVRpvStoZUknhl6/UxOXL4OwtdaBw==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.7.tgz",
|
||||
"integrity": "sha512-2ant89Lux/Q3VyC8vNVg7uBaFVP9SwoK2jJOOR0L8TQnX8CAYnh4uctAScy2Hwj2dgjVHqHLORQZJ2wH6VxhSQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -2990,9 +2990,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-musl": {
|
||||
"version": "16.1.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.6.tgz",
|
||||
"integrity": "sha512-S4J2v+8tT3NIO9u2q+S0G5KdvNDjXfAv06OhfOzNDaBn5rw84DGXWndOEB7d5/x852A20sW1M56vhC/tRVbccQ==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.7.tgz",
|
||||
"integrity": "sha512-uufcze7LYv0FQg9GnNeZ3/whYfo+1Q3HnQpm16o6Uyi0OVzLlk2ZWoY7j07KADZFY8qwDbsmFnMQP3p3+Ftprw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -3006,9 +3006,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-gnu": {
|
||||
"version": "16.1.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.6.tgz",
|
||||
"integrity": "sha512-2eEBDkFlMMNQnkTyPBhQOAyn2qMxyG2eE7GPH2WIDGEpEILcBPI/jdSv4t6xupSP+ot/jkfrCShLAa7+ZUPcJQ==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.7.tgz",
|
||||
"integrity": "sha512-KWVf2gxYvHtvuT+c4MBOGxuse5TD7DsMFYSxVxRBnOzok/xryNeQSjXgxSv9QpIVlaGzEn/pIuI6Koosx8CGWA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -3022,9 +3022,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-musl": {
|
||||
"version": "16.1.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.6.tgz",
|
||||
"integrity": "sha512-oicJwRlyOoZXVlxmIMaTq7f8pN9QNbdes0q2FXfRsPhfCi8n8JmOZJm5oo1pwDaFbnnD421rVU409M3evFbIqg==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.7.tgz",
|
||||
"integrity": "sha512-HguhaGwsGr1YAGs68uRKc4aGWxLET+NevJskOcCAwXbwj0fYX0RgZW2gsOCzr9S11CSQPIkxmoSbuVaBp4Z3dA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -3038,9 +3038,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-arm64-msvc": {
|
||||
"version": "16.1.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.6.tgz",
|
||||
"integrity": "sha512-gQmm8izDTPgs+DCWH22kcDmuUp7NyiJgEl18bcr8irXA5N2m2O+JQIr6f3ct42GOs9c0h8QF3L5SzIxcYAAXXw==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.7.tgz",
|
||||
"integrity": "sha512-S0n3KrDJokKTeFyM/vGGGR8+pCmXYrjNTk2ZozOL1C/JFdfUIL9O1ATaJOl5r2POe56iRChbsszrjMAdWSv7kQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -3054,9 +3054,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-x64-msvc": {
|
||||
"version": "16.1.6",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.6.tgz",
|
||||
"integrity": "sha512-NRfO39AIrzBnixKbjuo2YiYhB6o9d8v/ymU9m/Xk8cyVk+k7XylniXkHwjs4s70wedVffc6bQNbufk5v0xEm0A==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.7.tgz",
|
||||
"integrity": "sha512-mwgtg8CNZGYm06LeEd+bNnOUfwOyNem/rOiP14Lsz+AnUY92Zq/LXwtebtUiaeVkhbroRCQ0c8GlR4UT1U+0yg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -14068,14 +14068,14 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/next": {
|
||||
"version": "16.1.6",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-16.1.6.tgz",
|
||||
"integrity": "sha512-hkyRkcu5x/41KoqnROkfTm2pZVbKxvbZRuNvKXLRXxs3VfyO0WhY50TQS40EuKO9SW3rBj/sF3WbVwDACeMZyw==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-16.1.7.tgz",
|
||||
"integrity": "sha512-WM0L7WrSvKwoLegLYr6V+mz+RIofqQgVAfHhMp9a88ms0cFX8iX9ew+snpWlSBwpkURJOUdvCEt3uLl3NNzvWg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@next/env": "16.1.6",
|
||||
"@next/env": "16.1.7",
|
||||
"@swc/helpers": "0.5.15",
|
||||
"baseline-browser-mapping": "^2.8.3",
|
||||
"baseline-browser-mapping": "^2.9.19",
|
||||
"caniuse-lite": "^1.0.30001579",
|
||||
"postcss": "8.4.31",
|
||||
"styled-jsx": "5.1.6"
|
||||
@@ -14087,14 +14087,14 @@
|
||||
"node": ">=20.9.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@next/swc-darwin-arm64": "16.1.6",
|
||||
"@next/swc-darwin-x64": "16.1.6",
|
||||
"@next/swc-linux-arm64-gnu": "16.1.6",
|
||||
"@next/swc-linux-arm64-musl": "16.1.6",
|
||||
"@next/swc-linux-x64-gnu": "16.1.6",
|
||||
"@next/swc-linux-x64-musl": "16.1.6",
|
||||
"@next/swc-win32-arm64-msvc": "16.1.6",
|
||||
"@next/swc-win32-x64-msvc": "16.1.6",
|
||||
"@next/swc-darwin-arm64": "16.1.7",
|
||||
"@next/swc-darwin-x64": "16.1.7",
|
||||
"@next/swc-linux-arm64-gnu": "16.1.7",
|
||||
"@next/swc-linux-arm64-musl": "16.1.7",
|
||||
"@next/swc-linux-x64-gnu": "16.1.7",
|
||||
"@next/swc-linux-x64-musl": "16.1.7",
|
||||
"@next/swc-win32-arm64-msvc": "16.1.7",
|
||||
"@next/swc-win32-x64-msvc": "16.1.7",
|
||||
"sharp": "^0.34.4"
|
||||
},
|
||||
"peerDependencies": {
|
||||
|
||||
@@ -79,7 +79,7 @@
|
||||
"mdast-util-find-and-replace": "^3.0.1",
|
||||
"mime": "^4.1.0",
|
||||
"motion": "^12.29.0",
|
||||
"next": "16.1.6",
|
||||
"next": "16.1.7",
|
||||
"next-themes": "^0.4.4",
|
||||
"postcss": "^8.5.6",
|
||||
"posthog-js": "^1.176.0",
|
||||
|
||||
@@ -12,7 +12,7 @@ import { localizeAndPrettify } from "@/lib/time";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { PageSelector } from "@/components/PageSelector";
|
||||
import { useEffect, useState, useMemo } from "react";
|
||||
import { useCallback, useEffect, useRef, useState, useMemo } from "react";
|
||||
import { SvgAlertTriangle } from "@opal/icons";
|
||||
export interface IndexAttemptErrorsModalProps {
|
||||
errors: {
|
||||
@@ -22,93 +22,66 @@ export interface IndexAttemptErrorsModalProps {
|
||||
onClose: () => void;
|
||||
onResolveAll: () => void;
|
||||
isResolvingErrors?: boolean;
|
||||
onPageChange?: (page: number) => void;
|
||||
currentPage?: number;
|
||||
pageSize?: number;
|
||||
}
|
||||
|
||||
const ROW_HEIGHT = 65; // 4rem + 1px for border
|
||||
|
||||
export default function IndexAttemptErrorsModal({
|
||||
errors,
|
||||
onClose,
|
||||
onResolveAll,
|
||||
isResolvingErrors = false,
|
||||
pageSize: propPageSize,
|
||||
}: IndexAttemptErrorsModalProps) {
|
||||
const [calculatedPageSize, setCalculatedPageSize] = useState(10);
|
||||
const observerRef = useRef<ResizeObserver | null>(null);
|
||||
const [pageSize, setPageSize] = useState(10);
|
||||
const [currentPage, setCurrentPage] = useState(1);
|
||||
|
||||
// Reset to page 1 when the error list actually changes
|
||||
useEffect(() => {
|
||||
setCurrentPage(1);
|
||||
}, [errors.items.length, errors.total_items]);
|
||||
const tableContainerRef = useCallback((container: HTMLDivElement | null) => {
|
||||
if (observerRef.current) {
|
||||
observerRef.current.disconnect();
|
||||
observerRef.current = null;
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
const calculatePageSize = () => {
|
||||
// Modal height is 75% of viewport height
|
||||
const modalHeight = window.innerHeight * 0.6;
|
||||
if (!container) return;
|
||||
|
||||
// Estimate heights (in pixels):
|
||||
// - Modal header (title + description): ~120px
|
||||
// - Table header: ~40px
|
||||
// - Pagination section: ~80px
|
||||
// - Modal padding: ~64px (32px top + 32px bottom)
|
||||
const fixedHeight = 120 + 40 + 80 + 64;
|
||||
const observer = new ResizeObserver(() => {
|
||||
const thead = container.querySelector("thead");
|
||||
const theadHeight = thead?.getBoundingClientRect().height ?? 0;
|
||||
const availableHeight = container.clientHeight - theadHeight;
|
||||
const newPageSize = Math.max(3, Math.floor(availableHeight / ROW_HEIGHT));
|
||||
setPageSize(newPageSize);
|
||||
});
|
||||
|
||||
// Available height for table rows
|
||||
const availableHeight = modalHeight - fixedHeight;
|
||||
|
||||
// Each table row is approximately 60px (including borders and padding)
|
||||
const rowHeight = 60;
|
||||
|
||||
// Calculate how many rows can fit, with a minimum of 3
|
||||
const rowsPerPage = Math.max(3, Math.floor(availableHeight / rowHeight));
|
||||
|
||||
setCalculatedPageSize((prev) => {
|
||||
// Only update if the new size is significantly different to prevent flickering
|
||||
if (Math.abs(prev - rowsPerPage) > 0) {
|
||||
return rowsPerPage;
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
};
|
||||
|
||||
// Initial calculation
|
||||
calculatePageSize();
|
||||
|
||||
// Debounced resize handler to prevent excessive recalculation
|
||||
let resizeTimeout: NodeJS.Timeout;
|
||||
const debouncedCalculatePageSize = () => {
|
||||
clearTimeout(resizeTimeout);
|
||||
resizeTimeout = setTimeout(calculatePageSize, 100);
|
||||
};
|
||||
|
||||
window.addEventListener("resize", debouncedCalculatePageSize);
|
||||
return () => {
|
||||
window.removeEventListener("resize", debouncedCalculatePageSize);
|
||||
clearTimeout(resizeTimeout);
|
||||
};
|
||||
observer.observe(container);
|
||||
observerRef.current = observer;
|
||||
}, []);
|
||||
|
||||
// Separate effect to reset current page when page size changes
|
||||
// When data changes, reset to page 1.
|
||||
// When page size changes (resize), preserve the user's position by
|
||||
// finding which new page contains the first item they were looking at.
|
||||
const prevPageSizeRef = useRef(pageSize);
|
||||
useEffect(() => {
|
||||
setCurrentPage(1);
|
||||
}, [calculatedPageSize]);
|
||||
if (pageSize !== prevPageSizeRef.current) {
|
||||
setCurrentPage((prev) => {
|
||||
const firstVisibleIndex = (prev - 1) * prevPageSizeRef.current;
|
||||
const newPage = Math.floor(firstVisibleIndex / pageSize) + 1;
|
||||
const totalPages = Math.ceil(errors.items.length / pageSize);
|
||||
return Math.min(newPage, totalPages);
|
||||
});
|
||||
prevPageSizeRef.current = pageSize;
|
||||
} else {
|
||||
setCurrentPage(1);
|
||||
}
|
||||
}, [errors.items.length, pageSize]);
|
||||
|
||||
const pageSize = propPageSize || calculatedPageSize;
|
||||
|
||||
// Memoize pagination calculations to prevent unnecessary recalculations
|
||||
const paginationData = useMemo(() => {
|
||||
const totalPages = Math.ceil(errors.items.length / pageSize);
|
||||
const startIndex = (currentPage - 1) * pageSize;
|
||||
const endIndex = startIndex + pageSize;
|
||||
const currentPageItems = errors.items.slice(startIndex, endIndex);
|
||||
|
||||
return {
|
||||
totalPages,
|
||||
currentPageItems,
|
||||
const currentPageItems = errors.items.slice(
|
||||
startIndex,
|
||||
endIndex,
|
||||
};
|
||||
startIndex + pageSize
|
||||
);
|
||||
return { totalPages, currentPageItems };
|
||||
}, [errors.items, pageSize, currentPage]);
|
||||
|
||||
const hasUnresolvedErrors = useMemo(
|
||||
@@ -137,7 +110,7 @@ export default function IndexAttemptErrorsModal({
|
||||
onClose={onClose}
|
||||
height="fit"
|
||||
/>
|
||||
<Modal.Body>
|
||||
<Modal.Body height="full">
|
||||
{!isResolvingErrors && (
|
||||
<div className="flex flex-col gap-2 flex-shrink-0">
|
||||
<Text as="p">
|
||||
@@ -152,7 +125,10 @@ export default function IndexAttemptErrorsModal({
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex-1 overflow-hidden min-h-0">
|
||||
<div
|
||||
ref={tableContainerRef}
|
||||
className="flex-1 w-full overflow-hidden min-h-0"
|
||||
>
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
@@ -165,11 +141,11 @@ export default function IndexAttemptErrorsModal({
|
||||
<TableBody>
|
||||
{paginationData.currentPageItems.length > 0 ? (
|
||||
paginationData.currentPageItems.map((error) => (
|
||||
<TableRow key={error.id} className="h-[60px] max-h-[60px]">
|
||||
<TableCell className="h-[60px] align-top">
|
||||
<TableRow key={error.id} className="h-[4rem]">
|
||||
<TableCell>
|
||||
{localizeAndPrettify(error.time_created)}
|
||||
</TableCell>
|
||||
<TableCell className="h-[60px] align-top">
|
||||
<TableCell>
|
||||
{error.document_link ? (
|
||||
<a
|
||||
href={error.document_link}
|
||||
@@ -183,12 +159,12 @@ export default function IndexAttemptErrorsModal({
|
||||
error.document_id || error.entity_id || "Unknown"
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="h-[60px] align-top p-0">
|
||||
<div className="h-[60px] overflow-y-auto p-4 whitespace-normal">
|
||||
<TableCell>
|
||||
<div className="flex items-center h-[2rem] overflow-y-auto whitespace-normal">
|
||||
{error.failure_message}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="h-[60px] align-top">
|
||||
<TableCell>
|
||||
<span
|
||||
className={`px-2 py-1 rounded text-xs ${
|
||||
error.is_resolved
|
||||
@@ -202,7 +178,7 @@ export default function IndexAttemptErrorsModal({
|
||||
</TableRow>
|
||||
))
|
||||
) : (
|
||||
<TableRow>
|
||||
<TableRow className="h-[4rem]">
|
||||
<TableCell
|
||||
colSpan={4}
|
||||
className="text-center py-8 text-gray-500"
|
||||
@@ -215,32 +191,24 @@ export default function IndexAttemptErrorsModal({
|
||||
</Table>
|
||||
</div>
|
||||
|
||||
<div className="flex-shrink-0">
|
||||
{paginationData.totalPages > 1 && (
|
||||
<div className="flex-1 flex justify-center mb-2">
|
||||
<PageSelector
|
||||
totalPages={paginationData.totalPages}
|
||||
currentPage={currentPage}
|
||||
onPageChange={handlePageChange}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex w-full">
|
||||
<div className="flex gap-2 ml-auto">
|
||||
{hasUnresolvedErrors && !isResolvingErrors && (
|
||||
// TODO(@raunakab): migrate to opal Button once className/iconClassName is resolved
|
||||
<Button
|
||||
onClick={onResolveAll}
|
||||
className="ml-4 whitespace-nowrap"
|
||||
>
|
||||
Resolve All
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
{paginationData.totalPages > 1 && (
|
||||
<div className="flex w-full justify-center">
|
||||
<PageSelector
|
||||
totalPages={paginationData.totalPages}
|
||||
currentPage={currentPage}
|
||||
onPageChange={handlePageChange}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</Modal.Body>
|
||||
<Modal.Footer>
|
||||
{hasUnresolvedErrors && !isResolvingErrors && (
|
||||
// TODO(@raunakab): migrate to opal Button once className/iconClassName is resolved
|
||||
<Button onClick={onResolveAll} className="ml-4 whitespace-nowrap">
|
||||
Resolve All
|
||||
</Button>
|
||||
)}
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
|
||||
@@ -18,7 +18,7 @@ import { PageSelector } from "@/components/PageSelector";
|
||||
import { localizeAndPrettify } from "@/lib/time";
|
||||
import { getDocsProcessedPerMinute } from "@/lib/indexAttempt";
|
||||
import { InfoIcon } from "@/components/icons/icons";
|
||||
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
import ExceptionTraceModal from "@/sections/modals/PreviewModal/ExceptionTraceModal";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
import { SvgClock } from "@opal/icons";
|
||||
export interface IndexingAttemptsTableProps {
|
||||
@@ -98,7 +98,14 @@ export function IndexAttemptsTable({
|
||||
isReindexInProgress ? "are being" : "were"
|
||||
} synced into the system.`;
|
||||
return (
|
||||
<TableRow key={indexAttempt.id}>
|
||||
<TableRow
|
||||
key={indexAttempt.id}
|
||||
className={
|
||||
indexAttempt.full_exception_trace
|
||||
? "hover:bg-accent-background cursor-pointer relative select-none"
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
<TableCell>
|
||||
{indexAttempt.time_started
|
||||
? localizeAndPrettify(indexAttempt.time_started)
|
||||
@@ -146,46 +153,43 @@ export function IndexAttemptsTable({
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div>
|
||||
{indexAttempt.status === "success" && (
|
||||
{indexAttempt.status === "success" && (
|
||||
<Text className="flex flex-wrap whitespace-normal">
|
||||
{"-"}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
{indexAttempt.status === "failed" &&
|
||||
indexAttempt.error_msg && (
|
||||
<Text className="flex flex-wrap whitespace-normal">
|
||||
{"-"}
|
||||
{indexAttempt.error_msg}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
{indexAttempt.status === "failed" &&
|
||||
indexAttempt.error_msg && (
|
||||
<Text className="flex flex-wrap whitespace-normal">
|
||||
{indexAttempt.error_msg}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
{indexAttempt.full_exception_trace && (
|
||||
<div
|
||||
onClick={() => {
|
||||
setIndexAttemptTracePopupId(indexAttempt.id);
|
||||
}}
|
||||
className="mt-2 text-link cursor-pointer select-none"
|
||||
>
|
||||
View Full Trace
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
<td className="w-0 p-0">
|
||||
{indexAttempt.full_exception_trace && (
|
||||
<button
|
||||
type="button"
|
||||
aria-label="View full trace"
|
||||
onClick={() =>
|
||||
setIndexAttemptTracePopupId(indexAttempt.id)
|
||||
}
|
||||
className="absolute w-full h-full left-0 top-0"
|
||||
/>
|
||||
)}
|
||||
</td>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
{totalPages > 1 && (
|
||||
<div className="mt-3 flex">
|
||||
<div className="mx-auto">
|
||||
<PageSelector
|
||||
totalPages={totalPages}
|
||||
currentPage={currentPage}
|
||||
onPageChange={onPageChange}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-1 justify-center pt-3">
|
||||
<PageSelector
|
||||
totalPages={totalPages}
|
||||
currentPage={currentPage}
|
||||
onPageChange={onPageChange}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
|
||||
@@ -21,10 +21,13 @@ export const submitGoogleSite = async (
|
||||
formData.append("files", file);
|
||||
});
|
||||
|
||||
const response = await fetch("/api/manage/admin/connector/file/upload", {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
});
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/file/upload?unzip=false",
|
||||
{
|
||||
method: "POST",
|
||||
body: formData,
|
||||
}
|
||||
);
|
||||
const responseJson = await response.json();
|
||||
if (!response.ok) {
|
||||
toast.error(`Unable to upload files - ${responseJson.detail}`);
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
} from "@/lib/types";
|
||||
import type { Route } from "next";
|
||||
import { useRouter } from "next/navigation";
|
||||
import Truncated from "@/refresh-components/texts/Truncated";
|
||||
import {
|
||||
FiChevronDown,
|
||||
FiChevronRight,
|
||||
@@ -165,9 +166,7 @@ function ConnectorRow({
|
||||
onClick={handleRowClick}
|
||||
>
|
||||
<TableCell className="">
|
||||
<p className="max-w-[200px] xl:max-w-[400px] inline-block ellipsis truncate">
|
||||
{ccPairsIndexingStatus.name}
|
||||
</p>
|
||||
<Truncated>{ccPairsIndexingStatus.name}</Truncated>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{timeAgo(ccPairsIndexingStatus?.last_success) || "-"}
|
||||
@@ -246,9 +245,7 @@ function FederatedConnectorRow({
|
||||
onClick={handleRowClick}
|
||||
>
|
||||
<TableCell className="">
|
||||
<p className="max-w-[200px] xl:max-w-[400px] inline-block ellipsis truncate">
|
||||
{federatedConnector.name}
|
||||
</p>
|
||||
<Truncated>{federatedConnector.name}</Truncated>
|
||||
</TableCell>
|
||||
<TableCell>N/A</TableCell>
|
||||
<TableCell>
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
import { useState } from "react";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgAlertTriangle, SvgCheck, SvgCopy } from "@opal/icons";
|
||||
|
||||
interface ExceptionTraceModalProps {
|
||||
onOutsideClick: () => void;
|
||||
exceptionTrace: string;
|
||||
}
|
||||
|
||||
export default function ExceptionTraceModal({
|
||||
onOutsideClick,
|
||||
exceptionTrace,
|
||||
}: ExceptionTraceModalProps) {
|
||||
const [copyClicked, setCopyClicked] = useState(false);
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={onOutsideClick}>
|
||||
<Modal.Content width="lg" height="full">
|
||||
<Modal.Header
|
||||
icon={SvgAlertTriangle}
|
||||
title="Full Exception Trace"
|
||||
onClose={onOutsideClick}
|
||||
height="fit"
|
||||
/>
|
||||
<Modal.Body>
|
||||
<div className="mb-6">
|
||||
{!copyClicked ? (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
navigator.clipboard.writeText(exceptionTrace!);
|
||||
setCopyClicked(true);
|
||||
setTimeout(() => setCopyClicked(false), 2000);
|
||||
}}
|
||||
className="flex w-fit items-center hover:bg-accent-background p-2 border-border border rounded"
|
||||
>
|
||||
<Text>Copy full trace</Text>
|
||||
<SvgCopy className="stroke-text-04 ml-2 h-4 w-4 flex flex-shrink-0" />
|
||||
</button>
|
||||
) : (
|
||||
<div className="flex w-fit items-center hover:bg-accent-background p-2 border-border border rounded cursor-default">
|
||||
<Text>Copied to clipboard</Text>
|
||||
<SvgCheck className="stroke-text-04 my-auto ml-2 h-4 w-4 flex flex-shrink-0" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="whitespace-pre-wrap">{exceptionTrace}</div>
|
||||
</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
@@ -661,7 +661,7 @@ export default function AgentEditorPage({
|
||||
// Sharing
|
||||
shared_user_ids: existingAgent?.users?.map((user) => user.id) ?? [],
|
||||
shared_group_ids: existingAgent?.groups ?? [],
|
||||
is_public: existingAgent?.is_public ?? true,
|
||||
is_public: existingAgent?.is_public ?? false,
|
||||
label_ids: existingAgent?.labels?.map((l) => l.id) ?? [],
|
||||
featured: existingAgent?.featured ?? false,
|
||||
};
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
SvgKey,
|
||||
} from "@opal/icons";
|
||||
import { Disabled } from "@opal/core";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
@@ -78,18 +79,17 @@ export default function UserRowActions({
|
||||
return (
|
||||
<>
|
||||
{user.id && (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
<LineItem
|
||||
icon={SvgUsers}
|
||||
onClick={() => openModal(Modal.EDIT_GROUPS)}
|
||||
>
|
||||
Groups & Roles
|
||||
</Button>
|
||||
</LineItem>
|
||||
)}
|
||||
<Disabled disabled>
|
||||
<Button prominence="tertiary" variant="danger" icon={SvgUserX}>
|
||||
<LineItem danger icon={SvgUserX}>
|
||||
Deactivate User
|
||||
</Button>
|
||||
</LineItem>
|
||||
</Disabled>
|
||||
<Separator paddingXRem={0.5} />
|
||||
<Text as="p" secondaryBody text03 className="px-3 py-1">
|
||||
@@ -102,20 +102,18 @@ export default function UserRowActions({
|
||||
switch (user.status) {
|
||||
case UserStatus.INVITED:
|
||||
return (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
variant="danger"
|
||||
<LineItem
|
||||
danger
|
||||
icon={SvgXCircle}
|
||||
onClick={() => openModal(Modal.CANCEL_INVITE)}
|
||||
>
|
||||
Cancel Invite
|
||||
</Button>
|
||||
</LineItem>
|
||||
);
|
||||
|
||||
case UserStatus.REQUESTED:
|
||||
return (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
<LineItem
|
||||
icon={SvgUserCheck}
|
||||
onClick={() => {
|
||||
setPopoverOpen(false);
|
||||
@@ -133,37 +131,34 @@ export default function UserRowActions({
|
||||
}}
|
||||
>
|
||||
Approve
|
||||
</Button>
|
||||
</LineItem>
|
||||
);
|
||||
|
||||
case UserStatus.ACTIVE:
|
||||
return (
|
||||
<>
|
||||
{user.id && (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
<LineItem
|
||||
icon={SvgUsers}
|
||||
onClick={() => openModal(Modal.EDIT_GROUPS)}
|
||||
>
|
||||
Groups & Roles
|
||||
</Button>
|
||||
</LineItem>
|
||||
)}
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
<LineItem
|
||||
icon={SvgKey}
|
||||
onClick={() => openModal(Modal.RESET_PASSWORD)}
|
||||
>
|
||||
Reset Password
|
||||
</Button>
|
||||
</LineItem>
|
||||
<Separator paddingXRem={0.5} />
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
variant="danger"
|
||||
<LineItem
|
||||
danger
|
||||
icon={SvgUserX}
|
||||
onClick={() => openModal(Modal.DEACTIVATE)}
|
||||
>
|
||||
Deactivate User
|
||||
</Button>
|
||||
</LineItem>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -171,38 +166,34 @@ export default function UserRowActions({
|
||||
return (
|
||||
<>
|
||||
{user.id && (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
<LineItem
|
||||
icon={SvgUsers}
|
||||
onClick={() => openModal(Modal.EDIT_GROUPS)}
|
||||
>
|
||||
Groups & Roles
|
||||
</Button>
|
||||
</LineItem>
|
||||
)}
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
<LineItem
|
||||
icon={SvgKey}
|
||||
onClick={() => openModal(Modal.RESET_PASSWORD)}
|
||||
>
|
||||
Reset Password
|
||||
</Button>
|
||||
</LineItem>
|
||||
<Separator paddingXRem={0.5} />
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
<LineItem
|
||||
icon={SvgUserPlus}
|
||||
onClick={() => openModal(Modal.ACTIVATE)}
|
||||
>
|
||||
Activate User
|
||||
</Button>
|
||||
</LineItem>
|
||||
<Separator paddingXRem={0.5} />
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
variant="danger"
|
||||
<LineItem
|
||||
danger
|
||||
icon={SvgUserX}
|
||||
onClick={() => openModal(Modal.DELETE)}
|
||||
>
|
||||
Delete User
|
||||
</Button>
|
||||
</LineItem>
|
||||
</>
|
||||
);
|
||||
|
||||
|
||||
39
web/src/sections/modals/PreviewModal/ExceptionTraceModal.tsx
Normal file
39
web/src/sections/modals/PreviewModal/ExceptionTraceModal.tsx
Normal file
@@ -0,0 +1,39 @@
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { SvgAlertTriangle } from "@opal/icons";
|
||||
import { CodePreview } from "@/sections/modals/PreviewModal/variants/CodePreview";
|
||||
import { CopyButton } from "@/sections/modals/PreviewModal/variants/shared";
|
||||
import FloatingFooter from "@/sections/modals/PreviewModal/FloatingFooter";
|
||||
|
||||
interface ExceptionTraceModalProps {
|
||||
onOutsideClick: () => void;
|
||||
exceptionTrace: string;
|
||||
language?: string;
|
||||
}
|
||||
|
||||
export default function ExceptionTraceModal({
|
||||
onOutsideClick,
|
||||
exceptionTrace,
|
||||
language = "python",
|
||||
}: ExceptionTraceModalProps) {
|
||||
return (
|
||||
<Modal open onOpenChange={onOutsideClick}>
|
||||
<Modal.Content width="lg" height="full">
|
||||
<Modal.Header
|
||||
icon={SvgAlertTriangle}
|
||||
title="Full Exception Trace"
|
||||
onClose={onOutsideClick}
|
||||
height="fit"
|
||||
/>
|
||||
|
||||
<div className="flex flex-col flex-1 min-h-0 overflow-hidden w-full bg-background-tint-01">
|
||||
<CodePreview content={exceptionTrace} language={language} normalize />
|
||||
</div>
|
||||
|
||||
<FloatingFooter
|
||||
right={<CopyButton getText={() => exceptionTrace} />}
|
||||
codeBackground
|
||||
/>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
39
web/src/sections/modals/PreviewModal/FloatingFooter.tsx
Normal file
39
web/src/sections/modals/PreviewModal/FloatingFooter.tsx
Normal file
@@ -0,0 +1,39 @@
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ReactNode } from "react";
|
||||
|
||||
interface FloatingFooterProps {
|
||||
left?: ReactNode;
|
||||
right?: ReactNode;
|
||||
codeBackground?: boolean;
|
||||
}
|
||||
|
||||
export default function FloatingFooter({
|
||||
left,
|
||||
right,
|
||||
codeBackground,
|
||||
}: FloatingFooterProps) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"absolute bottom-0 left-0 right-0",
|
||||
"flex items-center justify-between",
|
||||
"p-4 pointer-events-none w-full"
|
||||
)}
|
||||
style={{
|
||||
background: `linear-gradient(to top, var(--background-${
|
||||
codeBackground ? "code-01" : "tint-01"
|
||||
}) 40%, transparent)`,
|
||||
}}
|
||||
>
|
||||
{/* Left slot */}
|
||||
<div className="pointer-events-auto">{left}</div>
|
||||
|
||||
{/* Right slot */}
|
||||
{right ? (
|
||||
<div className="pointer-events-auto rounded-12 bg-background-tint-00 p-1 shadow-lg">
|
||||
{right}
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -5,8 +5,8 @@ import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import FloatingFooter from "@/sections/modals/PreviewModal/FloatingFooter";
|
||||
import mime from "mime";
|
||||
import {
|
||||
getCodeLanguage,
|
||||
@@ -189,30 +189,12 @@ export default function PreviewModal({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Floating footer */}
|
||||
{!isLoading && !loadError && (
|
||||
<div
|
||||
className={cn(
|
||||
"absolute bottom-0 left-0 right-0",
|
||||
"flex items-center justify-between",
|
||||
"p-4 pointer-events-none w-full"
|
||||
)}
|
||||
style={{
|
||||
background: `linear-gradient(to top, var(--background-${
|
||||
variant.codeBackground ? "code-01" : "tint-01"
|
||||
}) 40%, transparent)`,
|
||||
}}
|
||||
>
|
||||
{/* Left slot */}
|
||||
<div className="pointer-events-auto">
|
||||
{variant.renderFooterLeft(ctx)}
|
||||
</div>
|
||||
|
||||
{/* Right slot */}
|
||||
<div className="pointer-events-auto rounded-12 bg-background-tint-00 p-1 shadow-lg">
|
||||
{variant.renderFooterRight(ctx)}
|
||||
</div>
|
||||
</div>
|
||||
<FloatingFooter
|
||||
left={variant.renderFooterLeft(ctx)}
|
||||
right={variant.renderFooterRight(ctx)}
|
||||
codeBackground={variant.codeBackground}
|
||||
/>
|
||||
)}
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
|
||||
83
web/src/sections/modals/ShareAgentModal.test.tsx
Normal file
83
web/src/sections/modals/ShareAgentModal.test.tsx
Normal file
@@ -0,0 +1,83 @@
|
||||
import React, { useEffect } from "react";
|
||||
import { render, screen, waitFor } from "@tests/setup/test-utils";
|
||||
import ShareAgentModal, { ShareAgentModalProps } from "./ShareAgentModal";
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
|
||||
jest.mock("@/hooks/useShareableUsers", () => ({
|
||||
__esModule: true,
|
||||
default: jest.fn(() => ({ data: [] })),
|
||||
}));
|
||||
|
||||
jest.mock("@/hooks/useShareableGroups", () => ({
|
||||
__esModule: true,
|
||||
default: jest.fn(() => ({ data: [] })),
|
||||
}));
|
||||
|
||||
jest.mock("@/hooks/useAgents", () => ({
|
||||
useAgent: jest.fn(() => ({ agent: null })),
|
||||
}));
|
||||
|
||||
jest.mock("@/lib/hooks", () => ({
|
||||
useLabels: jest.fn(() => ({
|
||||
labels: [],
|
||||
createLabel: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
function ModalHarness(props: ShareAgentModalProps) {
|
||||
const modal = useCreateModal();
|
||||
|
||||
useEffect(() => {
|
||||
modal.toggle(true);
|
||||
}, [modal]);
|
||||
|
||||
return (
|
||||
<modal.Provider>
|
||||
<ShareAgentModal {...props} />
|
||||
</modal.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
function renderShareAgentModal(overrides: Partial<ShareAgentModalProps> = {}) {
|
||||
const props: ShareAgentModalProps = {
|
||||
userIds: [],
|
||||
groupIds: [],
|
||||
isPublic: false,
|
||||
isFeatured: false,
|
||||
labelIds: [],
|
||||
...overrides,
|
||||
};
|
||||
|
||||
return render(<ModalHarness {...props} />);
|
||||
}
|
||||
|
||||
describe("ShareAgentModal", () => {
|
||||
it("defaults to Users & Groups when the agent is private", async () => {
|
||||
renderShareAgentModal({ isPublic: false });
|
||||
|
||||
await waitFor(() =>
|
||||
expect(
|
||||
screen.getByRole("tab", { name: "Users & Groups" })
|
||||
).toHaveAttribute("data-state", "active")
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.getByRole("tab", { name: "Your Organization" })
|
||||
).toHaveAttribute("data-state", "inactive");
|
||||
});
|
||||
|
||||
it("defaults to Your Organization when the agent is public", async () => {
|
||||
renderShareAgentModal({ isPublic: true });
|
||||
|
||||
await waitFor(() =>
|
||||
expect(
|
||||
screen.getByRole("tab", { name: "Your Organization" })
|
||||
).toHaveAttribute("data-state", "active")
|
||||
);
|
||||
|
||||
expect(screen.getByRole("tab", { name: "Users & Groups" })).toHaveAttribute(
|
||||
"data-state",
|
||||
"inactive"
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -12,6 +12,7 @@ import { cn, noProp } from "@/lib/utils";
|
||||
import { DRAG_TYPES } from "./constants";
|
||||
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import Truncated from "@/refresh-components/texts/Truncated";
|
||||
import { Button } from "@opal/components";
|
||||
import ButtonRenaming from "@/refresh-components/buttons/ButtonRenaming";
|
||||
import type { IconProps } from "@opal/types";
|
||||
@@ -181,7 +182,7 @@ const ProjectFolderButton = memo(({ project }: ProjectFolderButtonProps) => {
|
||||
onClose={() => setIsEditing(false)}
|
||||
/>
|
||||
) : (
|
||||
project.name
|
||||
<Truncated>{project.name}</Truncated>
|
||||
)}
|
||||
</SidebarTab>
|
||||
</Popover.Anchor>
|
||||
|
||||
@@ -58,7 +58,7 @@ async function mockCodeInterpreterApi(
|
||||
*/
|
||||
function getDisconnectIconButton(page: Page) {
|
||||
return page
|
||||
.locator("button:has(.opal-button):not(:has(.opal-button-label))")
|
||||
.locator("button:has(.interactive-foreground-icon):not(:has(span))")
|
||||
.first();
|
||||
}
|
||||
|
||||
|
||||
@@ -169,7 +169,9 @@ test.describe("Project Files visual regression", () => {
|
||||
.first();
|
||||
await expect(iconWrapper).toBeVisible();
|
||||
|
||||
await expectElementScreenshot(filesSection, {
|
||||
const container = page.locator("[data-main-container]");
|
||||
await expect(container).toBeVisible();
|
||||
await expectElementScreenshot(container, {
|
||||
name: "project-files-long-underscore-filename",
|
||||
});
|
||||
|
||||
|
||||
@@ -1860,6 +1860,9 @@ test.describe("MCP OAuth flows", () => {
|
||||
toolName: TOOL_NAMES.admin,
|
||||
logStep,
|
||||
});
|
||||
const createdAgent = await adminApiClient.getAssistant(agentId);
|
||||
expect(createdAgent.is_public).toBe(false);
|
||||
logStep("Verified newly created agent is private by default");
|
||||
const adminToolId = await fetchMcpToolIdByName(
|
||||
page,
|
||||
serverId,
|
||||
@@ -1899,6 +1902,13 @@ test.describe("MCP OAuth flows", () => {
|
||||
).toBeVisible({ timeout: 15000 });
|
||||
logStep("Verified MCP server card is still visible on actions page");
|
||||
|
||||
await adminApiClient.updateAgentSharing(agentId, {
|
||||
isPublic: true,
|
||||
userIds: createdAgent.users.map((user) => user.id),
|
||||
groupIds: createdAgent.groups,
|
||||
});
|
||||
logStep("Published agent explicitly for end-user MCP flow");
|
||||
|
||||
adminArtifacts = {
|
||||
serverId,
|
||||
serverName,
|
||||
|
||||
@@ -709,6 +709,9 @@ export class OnyxApiClient {
|
||||
|
||||
async getAssistant(agentId: number): Promise<{
|
||||
id: number;
|
||||
is_public: boolean;
|
||||
users: Array<{ id: string }>;
|
||||
groups: number[];
|
||||
tools: Array<{ id: number; mcp_server_id?: number | null }>;
|
||||
}> {
|
||||
const response = await this.get(`/persona/${agentId}`);
|
||||
@@ -718,6 +721,37 @@ export class OnyxApiClient {
|
||||
);
|
||||
}
|
||||
|
||||
async updateAgentSharing(
|
||||
agentId: number,
|
||||
options: {
|
||||
userIds?: string[];
|
||||
groupIds?: number[];
|
||||
isPublic?: boolean;
|
||||
labelIds?: number[];
|
||||
}
|
||||
): Promise<void> {
|
||||
const response = await this.request.patch(
|
||||
`${this.baseUrl}/persona/${agentId}/share`,
|
||||
{
|
||||
data: {
|
||||
user_ids: options.userIds,
|
||||
group_ids: options.groupIds,
|
||||
is_public: options.isPublic,
|
||||
label_ids: options.labelIds,
|
||||
},
|
||||
}
|
||||
);
|
||||
await this.handleResponse(
|
||||
response,
|
||||
`Failed to update sharing for assistant ${agentId}`
|
||||
);
|
||||
this.log(
|
||||
`Updated assistant sharing: ${agentId} (is_public=${String(
|
||||
options.isPublic
|
||||
)})`
|
||||
);
|
||||
}
|
||||
|
||||
async listMcpServers(): Promise<any[]> {
|
||||
const response = await this.get(`/admin/mcp/servers`);
|
||||
const data = await this.handleResponse<{ mcp_servers: any[] }>(
|
||||
|
||||
Reference in New Issue
Block a user