mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-17 23:46:47 +00:00
Compare commits
9 Commits
jamison/ti
...
v3.2.0-clo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8ffd7fbb56 | ||
|
|
f9e88e3c72 | ||
|
|
97efdbbbc3 | ||
|
|
b91a3aed53 | ||
|
|
51480e1099 | ||
|
|
70efbef95e | ||
|
|
f3936e2669 | ||
|
|
c933c71b59 | ||
|
|
e0d9e109b5 |
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -39,6 +39,8 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@cbc2f23eb5539cf20d82d1aabd0d0ecbcc56f4e3
|
||||
env:
|
||||
SKIP: ty
|
||||
with:
|
||||
prek-version: '0.3.4'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
|
||||
@@ -68,6 +68,7 @@ repos:
|
||||
pass_filenames: true
|
||||
files: ^backend/(?!\.venv/|scripts/).*\.py$
|
||||
- id: uv-run
|
||||
alias: ty
|
||||
name: ty
|
||||
args: ["ty", "check"]
|
||||
pass_filenames: true
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.client import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
@@ -16,9 +18,56 @@ from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_tenant_work_gating import cleanup_expired
|
||||
from onyx.redis.redis_tenant_work_gating import get_active_tenants
|
||||
from onyx.redis.redis_tenant_work_gating import observe_active_set_size
|
||||
from onyx.redis.redis_tenant_work_gating import record_full_fanout_cycle
|
||||
from onyx.redis.redis_tenant_work_gating import record_gate_decision
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
|
||||
|
||||
_FULL_FANOUT_TIMESTAMP_KEY_PREFIX = "tenant_work_gating_last_full_fanout_ms"
|
||||
|
||||
|
||||
def _should_bypass_gate_for_full_fanout(
|
||||
redis_client: Redis, task_name: str, interval_seconds: int
|
||||
) -> bool:
|
||||
"""True if at least `interval_seconds` have elapsed since the last
|
||||
full-fanout bypass for this task. On True, updates the stored timestamp
|
||||
atomically-enough (it's a best-effort counter, not a lock)."""
|
||||
key = f"{_FULL_FANOUT_TIMESTAMP_KEY_PREFIX}:{task_name}"
|
||||
now_ms = int(time.time() * 1000)
|
||||
threshold_ms = now_ms - (interval_seconds * 1000)
|
||||
|
||||
try:
|
||||
raw = cast(bytes | None, redis_client.get(key))
|
||||
except Exception:
|
||||
task_logger.exception(f"full-fanout timestamp read failed: task={task_name}")
|
||||
# Fail open: treat as "interval elapsed" so we don't skip every
|
||||
# tenant during a Redis hiccup.
|
||||
return True
|
||||
|
||||
if raw is None:
|
||||
# First invocation — bypass so the set seeds cleanly.
|
||||
elapsed = True
|
||||
else:
|
||||
try:
|
||||
last_ms = int(raw.decode())
|
||||
elapsed = last_ms <= threshold_ms
|
||||
except ValueError:
|
||||
elapsed = True
|
||||
|
||||
if elapsed:
|
||||
try:
|
||||
redis_client.set(key, str(now_ms))
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"full-fanout timestamp write failed: task={task_name}"
|
||||
)
|
||||
return elapsed
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
ignore_result=True,
|
||||
@@ -32,6 +81,7 @@ def cloud_beat_task_generator(
|
||||
priority: int = OnyxCeleryPriority.MEDIUM,
|
||||
expires: int = BEAT_EXPIRES_DEFAULT,
|
||||
skip_gated: bool = True,
|
||||
work_gated: bool = False,
|
||||
) -> bool | None:
|
||||
"""a lightweight task used to kick off individual beat tasks per tenant."""
|
||||
time_start = time.monotonic()
|
||||
@@ -51,8 +101,56 @@ def cloud_beat_task_generator(
|
||||
tenant_ids: list[str] = []
|
||||
num_processed_tenants = 0
|
||||
num_skipped_gated = 0
|
||||
num_would_skip_work_gate = 0
|
||||
num_skipped_work_gate = 0
|
||||
|
||||
# Tenant-work-gating read path. Resolve once per invocation.
|
||||
gate_enabled = False
|
||||
gate_enforce = False
|
||||
full_fanout_cycle = False
|
||||
active_tenants: set[str] | None = None
|
||||
|
||||
try:
|
||||
# Gating setup is inside the try block so any exception still
|
||||
# reaches the finally that releases the beat lock.
|
||||
if work_gated:
|
||||
try:
|
||||
gate_enabled = OnyxRuntime.get_tenant_work_gating_enabled()
|
||||
gate_enforce = OnyxRuntime.get_tenant_work_gating_enforce()
|
||||
except Exception:
|
||||
task_logger.exception("tenant work gating: runtime flag read failed")
|
||||
gate_enabled = False
|
||||
|
||||
if gate_enabled:
|
||||
redis_failed = False
|
||||
interval_s = (
|
||||
OnyxRuntime.get_tenant_work_gating_full_fanout_interval_seconds()
|
||||
)
|
||||
full_fanout_cycle = _should_bypass_gate_for_full_fanout(
|
||||
redis_client, task_name, interval_s
|
||||
)
|
||||
if full_fanout_cycle:
|
||||
record_full_fanout_cycle(task_name)
|
||||
try:
|
||||
ttl_s = OnyxRuntime.get_tenant_work_gating_ttl_seconds()
|
||||
cleanup_expired(ttl_s)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"tenant work gating: cleanup_expired failed"
|
||||
)
|
||||
else:
|
||||
ttl_s = OnyxRuntime.get_tenant_work_gating_ttl_seconds()
|
||||
active_tenants = get_active_tenants(ttl_s)
|
||||
if active_tenants is None:
|
||||
full_fanout_cycle = True
|
||||
record_full_fanout_cycle(task_name)
|
||||
redis_failed = True
|
||||
|
||||
# Only refresh the gauge when Redis is known-reachable —
|
||||
# skip the ZCARD if we just failed open due to a Redis error.
|
||||
if not redis_failed:
|
||||
observe_active_set_size()
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
# Per-task control over whether gated tenants are included. Most periodic tasks
|
||||
@@ -76,6 +174,21 @@ def cloud_beat_task_generator(
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
continue
|
||||
|
||||
# Tenant work gate: if the feature is on, check membership. Skip
|
||||
# unmarked tenants when enforce=True AND we're not in a full-
|
||||
# fanout cycle. Always log/emit the shadow counter.
|
||||
if work_gated and gate_enabled and not full_fanout_cycle:
|
||||
would_skip = (
|
||||
active_tenants is not None and tenant_id not in active_tenants
|
||||
)
|
||||
if would_skip:
|
||||
num_would_skip_work_gate += 1
|
||||
if gate_enforce:
|
||||
num_skipped_work_gate += 1
|
||||
record_gate_decision(task_name, skipped=True)
|
||||
continue
|
||||
record_gate_decision(task_name, skipped=False)
|
||||
|
||||
self.app.send_task(
|
||||
task_name,
|
||||
kwargs=dict(
|
||||
@@ -109,6 +222,12 @@ def cloud_beat_task_generator(
|
||||
f"task={task_name} "
|
||||
f"num_processed_tenants={num_processed_tenants} "
|
||||
f"num_skipped_gated={num_skipped_gated} "
|
||||
f"num_would_skip_work_gate={num_would_skip_work_gate} "
|
||||
f"num_skipped_work_gate={num_skipped_work_gate} "
|
||||
f"full_fanout_cycle={full_fanout_cycle} "
|
||||
f"work_gated={work_gated} "
|
||||
f"gate_enabled={gate_enabled} "
|
||||
f"gate_enforce={gate_enforce} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
|
||||
@@ -30,6 +30,7 @@ from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFI
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
@@ -531,23 +532,26 @@ def reset_tenant_id(
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
|
||||
def wait_for_vespa_or_shutdown(
|
||||
sender: Any, # noqa: ARG001
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None: # noqa: ARG001
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
def wait_for_document_index_or_shutdown() -> None:
|
||||
"""
|
||||
Waits for all configured document indices to become ready subject to a
|
||||
timeout.
|
||||
|
||||
Raises WorkerShutdown if the timeout is reached.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
logger.info(
|
||||
"DISABLE_VECTOR_DB is set — skipping Vespa/OpenSearch readiness check."
|
||||
)
|
||||
return
|
||||
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
if not ONYX_DISABLE_VESPA:
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = (
|
||||
"[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
if not wait_for_opensearch_with_timeout():
|
||||
|
||||
@@ -105,7 +105,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -111,7 +111,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -97,7 +97,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -118,7 +118,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -124,7 +124,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
|
||||
logger.info(f"Running as the primary celery worker: pid={os.getpid()}")
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.configs.app_configs import DISABLE_OPENSEARCH_MIGRATION_TASK
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
@@ -67,6 +68,7 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -100,6 +102,7 @@ beat_task_templates: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
# Gated tenants may still have connectors awaiting deletion.
|
||||
"skip_gated": False,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -109,6 +112,7 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -118,6 +122,7 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -155,6 +160,7 @@ beat_task_templates: list[dict] = [
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.SANDBOX,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -179,6 +185,7 @@ if ENTERPRISE_EDITION_ENABLED:
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -188,6 +195,7 @@ if ENTERPRISE_EDITION_ENABLED:
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
]
|
||||
@@ -227,7 +235,11 @@ if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
)
|
||||
|
||||
# Add OpenSearch migration task if enabled.
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX and not DISABLE_OPENSEARCH_MIGRATION_TASK:
|
||||
if (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and not DISABLE_OPENSEARCH_MIGRATION_TASK
|
||||
and not ONYX_DISABLE_VESPA
|
||||
):
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "migrate-chunks-from-vespa-to-opensearch",
|
||||
@@ -280,7 +292,7 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
cloud_task["kwargs"] = {}
|
||||
cloud_task["kwargs"]["task_name"] = task["task"]
|
||||
|
||||
optional_fields = ["queue", "priority", "expires", "skip_gated"]
|
||||
optional_fields = ["queue", "priority", "expires", "skip_gated", "work_gated"]
|
||||
for field in optional_fields:
|
||||
if field in task["options"]:
|
||||
cloud_task["kwargs"][field] = task["options"][field]
|
||||
@@ -373,12 +385,14 @@ if not MULTI_TENANT:
|
||||
]
|
||||
)
|
||||
|
||||
# `skip_gated` is a cloud-only hint consumed by `cloud_beat_task_generator`. Strip
|
||||
# it before extending the self-hosted schedule so it doesn't leak into apply_async
|
||||
# as an unrecognised option on every fired task message.
|
||||
# `skip_gated` and `work_gated` are cloud-only hints consumed by
|
||||
# `cloud_beat_task_generator`. Strip them before extending the self-hosted
|
||||
# schedule so they don't leak into apply_async as unrecognised options on
|
||||
# every fired task message.
|
||||
for _template in beat_task_templates:
|
||||
_self_hosted_template = copy.deepcopy(_template)
|
||||
_self_hosted_template["options"].pop("skip_gated", None)
|
||||
_self_hosted_template["options"].pop("work_gated", None)
|
||||
tasks_to_schedule.append(_self_hosted_template)
|
||||
|
||||
|
||||
|
||||
@@ -327,6 +327,7 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
DISABLE_OPENSEARCH_MIGRATION_TASK = (
|
||||
os.environ.get("DISABLE_OPENSEARCH_MIGRATION_TASK", "").lower() == "true"
|
||||
)
|
||||
ONYX_DISABLE_VESPA = os.environ.get("ONYX_DISABLE_VESPA", "").lower() == "true"
|
||||
# Whether we should check for and create an index if necessary every time we
|
||||
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
|
||||
@@ -379,10 +379,20 @@ def _download_and_extract_sections_basic(
|
||||
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
or is_tabular_file(file_name)
|
||||
):
|
||||
# Google Drive doesn't enforce file extensions, so the filename may not
|
||||
# end in .xlsx even when the mime type says it's one. Synthesize the
|
||||
# extension so tabular_file_to_sections dispatches correctly.
|
||||
tabular_file_name = file_name
|
||||
if (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
and not is_tabular_file(file_name)
|
||||
):
|
||||
tabular_file_name = f"{file_name}.xlsx"
|
||||
return list(
|
||||
tabular_file_to_sections(
|
||||
io.BytesIO(response_call()),
|
||||
file_name=file_name,
|
||||
file_name=tabular_file_name,
|
||||
link=link,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1958,8 +1958,7 @@ class SharepointConnector(
|
||||
self._graph_client = GraphClient(
|
||||
_acquire_token_for_graph, environment=self._azure_environment
|
||||
)
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
self.sp_tenant_domain = self._resolve_tenant_domain()
|
||||
self.sp_tenant_domain = self._resolve_tenant_domain()
|
||||
return None
|
||||
|
||||
def _get_drive_names_for_site(self, site_url: str) -> list[str]:
|
||||
|
||||
@@ -244,13 +244,21 @@ def fetch_latest_index_attempts_by_status(
|
||||
return query.all()
|
||||
|
||||
|
||||
_INTERNAL_ONLY_SOURCES = {
|
||||
# Used by the ingestion API, not a user-created connector.
|
||||
DocumentSource.INGESTION_API,
|
||||
# Backs the user library / build feature, not a connector users filter by.
|
||||
DocumentSource.CRAFT_FILE,
|
||||
}
|
||||
|
||||
|
||||
def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
|
||||
distinct_sources = db_session.query(Connector.source).distinct().all()
|
||||
|
||||
sources = [
|
||||
source[0]
|
||||
for source in distinct_sources
|
||||
if source[0] != DocumentSource.INGESTION_API
|
||||
if source[0] not in _INTERNAL_ONLY_SOURCES
|
||||
]
|
||||
|
||||
return sources
|
||||
|
||||
@@ -20,6 +20,7 @@ from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
|
||||
)
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.db.enums import OpenSearchDocumentMigrationStatus
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import OpenSearchDocumentMigrationRecord
|
||||
@@ -412,7 +413,11 @@ def get_opensearch_retrieval_state(
|
||||
|
||||
If the tenant migration record is not found, defaults to
|
||||
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX.
|
||||
|
||||
If ONYX_DISABLE_VESPA is True, always returns True.
|
||||
"""
|
||||
if ONYX_DISABLE_VESPA:
|
||||
return True
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
return ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.opensearch_migration import get_opensearch_retrieval_state
|
||||
from onyx.document_index.disabled import DisabledDocumentIndex
|
||||
@@ -48,6 +49,11 @@ def get_default_document_index(
|
||||
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
|
||||
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if ONYX_DISABLE_VESPA:
|
||||
if not opensearch_retrieval_enabled:
|
||||
raise ValueError(
|
||||
"BUG: ONYX_DISABLE_VESPA is set but opensearch_retrieval_enabled is not set."
|
||||
)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
@@ -119,21 +125,32 @@ def get_all_document_indices(
|
||||
)
|
||||
]
|
||||
|
||||
vespa_document_index = VespaIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name if secondary_search_settings else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=(
|
||||
secondary_search_settings.large_chunks_enabled
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
result: list[DocumentIndex] = []
|
||||
|
||||
if ONYX_DISABLE_VESPA:
|
||||
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
raise ValueError(
|
||||
"ONYX_DISABLE_VESPA is set but ENABLE_OPENSEARCH_INDEXING_FOR_ONYX is not set."
|
||||
)
|
||||
else:
|
||||
vespa_document_index = VespaIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=(
|
||||
secondary_search_settings.large_chunks_enabled
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
result.append(vespa_document_index)
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
@@ -169,7 +186,6 @@ def get_all_document_indices(
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
result: list[DocumentIndex] = [vespa_document_index]
|
||||
if opensearch_document_index:
|
||||
result.append(opensearch_document_index)
|
||||
|
||||
return result
|
||||
|
||||
@@ -48,6 +48,7 @@ KNOWN_OPENPYXL_BUGS = [
|
||||
"File contains no valid workbook part",
|
||||
"Unable to read workbook: could not read stylesheet from None",
|
||||
"Colors must be aRGB hex values",
|
||||
"Max value is",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -19,9 +19,14 @@ from onyx.configs.app_configs import MCP_SERVER_CORS_ORIGINS
|
||||
from onyx.mcp_server.auth import OnyxTokenVerifier
|
||||
from onyx.mcp_server.utils import shutdown_http_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Initialize EE flag at module import so it's set regardless of the entry point
|
||||
# (python -m onyx.mcp_server_main, uvicorn onyx.mcp_server.api:mcp_app, etc.).
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
logger.info("Creating Onyx MCP Server...")
|
||||
|
||||
mcp_server = FastMCP(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Resource registrations for the Onyx MCP server."""
|
||||
|
||||
# Import resource modules so decorators execute when the package loads.
|
||||
from onyx.mcp_server.resources import document_sets # noqa: F401
|
||||
from onyx.mcp_server.resources import indexed_sources # noqa: F401
|
||||
|
||||
41
backend/onyx/mcp_server/resources/document_sets.py
Normal file
41
backend/onyx/mcp_server/resources/document_sets.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Resource exposing document sets available to the current user."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from onyx.mcp_server.api import mcp_server
|
||||
from onyx.mcp_server.utils import get_accessible_document_sets
|
||||
from onyx.mcp_server.utils import require_access_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@mcp_server.resource(
|
||||
"resource://document_sets",
|
||||
name="document_sets",
|
||||
description=(
|
||||
"Enumerate the Document Sets accessible to the current user. Use the "
|
||||
"returned `name` values with the `document_set_names` filter of the "
|
||||
"`search_indexed_documents` tool to scope searches to a specific set."
|
||||
),
|
||||
mime_type="application/json",
|
||||
)
|
||||
async def document_sets_resource() -> str:
|
||||
"""Return the list of document sets the user can filter searches by."""
|
||||
|
||||
access_token = require_access_token()
|
||||
|
||||
document_sets = sorted(
|
||||
await get_accessible_document_sets(access_token), key=lambda entry: entry.name
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Onyx MCP Server: document_sets resource returning %s entries",
|
||||
len(document_sets),
|
||||
)
|
||||
|
||||
# FastMCP 3.2+ requires str/bytes/list[ResourceContent] — it no longer
|
||||
# auto-serializes; serialize to JSON ourselves.
|
||||
return json.dumps([entry.model_dump(mode="json") for entry in document_sets])
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import json
|
||||
|
||||
from onyx.mcp_server.api import mcp_server
|
||||
from onyx.mcp_server.utils import get_indexed_sources
|
||||
@@ -21,7 +21,7 @@ logger = setup_logger()
|
||||
),
|
||||
mime_type="application/json",
|
||||
)
|
||||
async def indexed_sources_resource() -> dict[str, Any]:
|
||||
async def indexed_sources_resource() -> str:
|
||||
"""Return the list of indexed source types for search filtering."""
|
||||
|
||||
access_token = require_access_token()
|
||||
@@ -33,6 +33,6 @@ async def indexed_sources_resource() -> dict[str, Any]:
|
||||
len(sources),
|
||||
)
|
||||
|
||||
return {
|
||||
"indexed_sources": sorted(sources),
|
||||
}
|
||||
# FastMCP 3.2+ requires str/bytes/list[ResourceContent] — it no longer
|
||||
# auto-serializes; serialize to JSON ourselves.
|
||||
return json.dumps(sorted(sources))
|
||||
|
||||
@@ -4,12 +4,23 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastmcp.server.auth.auth import AccessToken
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.mcp_server.api import mcp_server
|
||||
from onyx.mcp_server.utils import get_http_client
|
||||
from onyx.mcp_server.utils import get_indexed_sources
|
||||
from onyx.mcp_server.utils import require_access_token
|
||||
from onyx.server.features.web_search.models import OpenUrlsToolRequest
|
||||
from onyx.server.features.web_search.models import OpenUrlsToolResponse
|
||||
from onyx.server.features.web_search.models import WebSearchToolRequest
|
||||
from onyx.server.features.web_search.models import WebSearchToolResponse
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
@@ -17,6 +28,43 @@ from onyx.utils.variable_functionality import global_version
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# CE search falls through to the chat endpoint, which invokes an LLM — the
|
||||
# default 60s client timeout is not enough for a real RAG-backed response.
|
||||
_CE_SEARCH_TIMEOUT_SECONDS = 300.0
|
||||
|
||||
|
||||
async def _post_model(
|
||||
url: str,
|
||||
body: BaseModel,
|
||||
access_token: AccessToken,
|
||||
timeout: float | None = None,
|
||||
) -> httpx.Response:
|
||||
"""POST a Pydantic model as JSON to the Onyx backend."""
|
||||
return await get_http_client().post(
|
||||
url,
|
||||
content=body.model_dump_json(exclude_unset=True),
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token.token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
timeout=timeout if timeout is not None else httpx.USE_CLIENT_DEFAULT,
|
||||
)
|
||||
|
||||
|
||||
def _project_doc(doc: SearchDoc, content: str | None) -> dict[str, Any]:
|
||||
"""Project a backend search doc into the MCP wire shape.
|
||||
|
||||
Accepts SearchDocWithContent (EE) too since it extends SearchDoc.
|
||||
"""
|
||||
return {
|
||||
"semantic_identifier": doc.semantic_identifier,
|
||||
"content": content,
|
||||
"source_type": doc.source_type.value,
|
||||
"link": doc.link,
|
||||
"score": doc.score,
|
||||
}
|
||||
|
||||
|
||||
def _extract_error_detail(response: httpx.Response) -> str:
|
||||
"""Extract a human-readable error message from a failed backend response.
|
||||
|
||||
@@ -36,6 +84,7 @@ def _extract_error_detail(response: httpx.Response) -> str:
|
||||
async def search_indexed_documents(
|
||||
query: str,
|
||||
source_types: list[str] | None = None,
|
||||
document_set_names: list[str] | None = None,
|
||||
time_cutoff: str | None = None,
|
||||
limit: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
@@ -53,6 +102,10 @@ async def search_indexed_documents(
|
||||
In EE mode, the dedicated search endpoint is used instead.
|
||||
|
||||
To find a list of available sources, use the `indexed_sources` resource.
|
||||
`document_set_names` restricts results to documents belonging to the named
|
||||
Document Sets — useful for scoping queries to a curated subset of the
|
||||
knowledge base (e.g. to isolate knowledge between agents). Use the
|
||||
`document_sets` resource to discover accessible set names.
|
||||
Returns chunks of text as search results with snippets, scores, and metadata.
|
||||
|
||||
Example usage:
|
||||
@@ -60,15 +113,23 @@ async def search_indexed_documents(
|
||||
{
|
||||
"query": "What is the latest status of PROJ-1234 and what is the next development item?",
|
||||
"source_types": ["jira", "google_drive", "github"],
|
||||
"document_set_names": ["Engineering Wiki"],
|
||||
"time_cutoff": "2025-11-24T00:00:00Z",
|
||||
"limit": 10,
|
||||
}
|
||||
```
|
||||
"""
|
||||
logger.info(
|
||||
f"Onyx MCP Server: document search: query='{query}', sources={source_types}, limit={limit}"
|
||||
f"Onyx MCP Server: document search: query='{query}', sources={source_types}, "
|
||||
f"document_sets={document_set_names}, limit={limit}"
|
||||
)
|
||||
|
||||
# Normalize empty list inputs to None so downstream filter construction is
|
||||
# consistent — BaseFilters treats [] as "match zero" which differs from
|
||||
# "no filter" (None).
|
||||
source_types = source_types or None
|
||||
document_set_names = document_set_names or None
|
||||
|
||||
# Parse time_cutoff string to datetime if provided
|
||||
time_cutoff_dt: datetime | None = None
|
||||
if time_cutoff:
|
||||
@@ -81,9 +142,6 @@ async def search_indexed_documents(
|
||||
# Continue with no time_cutoff instead of returning an error
|
||||
time_cutoff_dt = None
|
||||
|
||||
# Initialize source_type_enums early to avoid UnboundLocalError
|
||||
source_type_enums: list[DocumentSource] | None = None
|
||||
|
||||
# Get authenticated user from FastMCP's access token
|
||||
access_token = require_access_token()
|
||||
|
||||
@@ -117,6 +175,7 @@ async def search_indexed_documents(
|
||||
|
||||
# Convert source_types strings to DocumentSource enums if provided
|
||||
# Invalid values will be handled by the API server
|
||||
source_type_enums: list[DocumentSource] | None = None
|
||||
if source_types is not None:
|
||||
source_type_enums = []
|
||||
for src in source_types:
|
||||
@@ -127,83 +186,83 @@ async def search_indexed_documents(
|
||||
f"Onyx MCP Server: Invalid source type '{src}' - will be ignored by server"
|
||||
)
|
||||
|
||||
# Build filters dict only with non-None values
|
||||
filters: dict[str, Any] | None = None
|
||||
if source_type_enums or time_cutoff_dt:
|
||||
filters = {}
|
||||
if source_type_enums:
|
||||
filters["source_type"] = [src.value for src in source_type_enums]
|
||||
if time_cutoff_dt:
|
||||
filters["time_cutoff"] = time_cutoff_dt.isoformat()
|
||||
filters: BaseFilters | None = None
|
||||
if source_type_enums or document_set_names or time_cutoff_dt:
|
||||
filters = BaseFilters(
|
||||
source_type=source_type_enums,
|
||||
document_set=document_set_names,
|
||||
time_cutoff=time_cutoff_dt,
|
||||
)
|
||||
|
||||
is_ee = global_version.is_ee_version()
|
||||
base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True)
|
||||
auth_headers = {"Authorization": f"Bearer {access_token.token}"}
|
||||
is_ee = global_version.is_ee_version()
|
||||
|
||||
search_request: dict[str, Any]
|
||||
request: BaseModel
|
||||
if is_ee:
|
||||
# EE: use the dedicated search endpoint (no LLM invocation)
|
||||
search_request = {
|
||||
"search_query": query,
|
||||
"filters": filters,
|
||||
"num_docs_fed_to_llm_selection": limit,
|
||||
"run_query_expansion": False,
|
||||
"include_content": True,
|
||||
"stream": False,
|
||||
}
|
||||
# EE: use the dedicated search endpoint (no LLM invocation).
|
||||
# Lazy import so CE deployments that strip ee/ never load this module.
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
|
||||
request = SendSearchQueryRequest(
|
||||
search_query=query,
|
||||
filters=filters,
|
||||
num_docs_fed_to_llm_selection=limit,
|
||||
run_query_expansion=False,
|
||||
include_content=True,
|
||||
stream=False,
|
||||
)
|
||||
endpoint = f"{base_url}/search/send-search-message"
|
||||
error_key = "error"
|
||||
docs_key = "search_docs"
|
||||
content_field = "content"
|
||||
else:
|
||||
# CE: fall back to the chat endpoint (invokes LLM, consumes tokens)
|
||||
search_request = {
|
||||
"message": query,
|
||||
"stream": False,
|
||||
"chat_session_info": {},
|
||||
}
|
||||
if filters:
|
||||
search_request["internal_search_filters"] = filters
|
||||
request = SendMessageRequest(
|
||||
message=query,
|
||||
stream=False,
|
||||
chat_session_info=ChatSessionCreationRequest(),
|
||||
internal_search_filters=filters,
|
||||
)
|
||||
endpoint = f"{base_url}/chat/send-chat-message"
|
||||
error_key = "error_msg"
|
||||
docs_key = "top_documents"
|
||||
content_field = "blurb"
|
||||
|
||||
try:
|
||||
response = await get_http_client().post(
|
||||
response = await _post_model(
|
||||
endpoint,
|
||||
json=search_request,
|
||||
headers=auth_headers,
|
||||
request,
|
||||
access_token,
|
||||
timeout=None if is_ee else _CE_SEARCH_TIMEOUT_SECONDS,
|
||||
)
|
||||
if not response.is_success:
|
||||
error_detail = _extract_error_detail(response)
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": error_detail,
|
||||
}
|
||||
result = response.json()
|
||||
|
||||
# Check for error in response
|
||||
if result.get(error_key):
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": result.get(error_key),
|
||||
"error": _extract_error_detail(response),
|
||||
}
|
||||
|
||||
documents = [
|
||||
{
|
||||
"semantic_identifier": doc.get("semantic_identifier"),
|
||||
"content": doc.get(content_field),
|
||||
"source_type": doc.get("source_type"),
|
||||
"link": doc.get("link"),
|
||||
"score": doc.get("score"),
|
||||
}
|
||||
for doc in result.get(docs_key, [])
|
||||
]
|
||||
if is_ee:
|
||||
from ee.onyx.server.query_and_chat.models import SearchFullResponse
|
||||
|
||||
ee_payload = SearchFullResponse.model_validate_json(response.content)
|
||||
if ee_payload.error:
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": ee_payload.error,
|
||||
}
|
||||
documents = [
|
||||
_project_doc(doc, doc.content) for doc in ee_payload.search_docs
|
||||
]
|
||||
else:
|
||||
ce_payload = ChatFullResponse.model_validate_json(response.content)
|
||||
if ce_payload.error_msg:
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": ce_payload.error_msg,
|
||||
}
|
||||
documents = [
|
||||
_project_doc(doc, doc.blurb) for doc in ce_payload.top_documents
|
||||
]
|
||||
|
||||
# NOTE: search depth is controlled by the backend persona defaults, not `limit`.
|
||||
# `limit` only caps the returned list; fewer results may be returned if the
|
||||
@@ -252,23 +311,20 @@ async def search_web(
|
||||
access_token = require_access_token()
|
||||
|
||||
try:
|
||||
request_payload = {"queries": [query], "max_results": limit}
|
||||
response = await get_http_client().post(
|
||||
response = await _post_model(
|
||||
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/web-search/search-lite",
|
||||
json=request_payload,
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
WebSearchToolRequest(queries=[query], max_results=limit),
|
||||
access_token,
|
||||
)
|
||||
if not response.is_success:
|
||||
error_detail = _extract_error_detail(response)
|
||||
return {
|
||||
"error": error_detail,
|
||||
"error": _extract_error_detail(response),
|
||||
"results": [],
|
||||
"query": query,
|
||||
}
|
||||
response_payload = response.json()
|
||||
results = response_payload.get("results", [])
|
||||
payload = WebSearchToolResponse.model_validate_json(response.content)
|
||||
return {
|
||||
"results": results,
|
||||
"results": [result.model_dump(mode="json") for result in payload.results],
|
||||
"query": query,
|
||||
}
|
||||
except Exception as e:
|
||||
@@ -305,21 +361,19 @@ async def open_urls(
|
||||
access_token = require_access_token()
|
||||
|
||||
try:
|
||||
response = await get_http_client().post(
|
||||
response = await _post_model(
|
||||
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/web-search/open-urls",
|
||||
json={"urls": urls},
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
OpenUrlsToolRequest(urls=urls),
|
||||
access_token,
|
||||
)
|
||||
if not response.is_success:
|
||||
error_detail = _extract_error_detail(response)
|
||||
return {
|
||||
"error": error_detail,
|
||||
"error": _extract_error_detail(response),
|
||||
"results": [],
|
||||
}
|
||||
response_payload = response.json()
|
||||
results = response_payload.get("results", [])
|
||||
payload = OpenUrlsToolResponse.model_validate_json(response.content)
|
||||
return {
|
||||
"results": results,
|
||||
"results": [result.model_dump(mode="json") for result in payload.results],
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Onyx MCP Server: URL fetch error: {e}", exc_info=True)
|
||||
|
||||
@@ -5,10 +5,24 @@ from __future__ import annotations
|
||||
import httpx
|
||||
from fastmcp.server.auth.auth import AccessToken
|
||||
from fastmcp.server.dependencies import get_access_token
|
||||
from pydantic import BaseModel
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
|
||||
|
||||
|
||||
class DocumentSetEntry(BaseModel):
|
||||
"""Minimal document-set shape surfaced to MCP clients.
|
||||
|
||||
Projected from the backend's DocumentSetSummary to avoid coupling MCP to
|
||||
admin-only fields (cc-pair summaries, federated connectors, etc.).
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Shared HTTP client reused across requests
|
||||
@@ -84,3 +98,32 @@ async def get_indexed_sources(
|
||||
exc_info=True,
|
||||
)
|
||||
raise RuntimeError(f"Failed to fetch indexed sources: {exc}") from exc
|
||||
|
||||
|
||||
_DOCUMENT_SET_ENTRIES_ADAPTER = TypeAdapter(list[DocumentSetEntry])
|
||||
|
||||
|
||||
async def get_accessible_document_sets(
|
||||
access_token: AccessToken,
|
||||
) -> list[DocumentSetEntry]:
|
||||
"""Fetch document sets accessible to the current user."""
|
||||
headers = {"Authorization": f"Bearer {access_token.token}"}
|
||||
try:
|
||||
response = await get_http_client().get(
|
||||
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/manage/document-set",
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return _DOCUMENT_SET_ENTRIES_ADAPTER.validate_json(response.content)
|
||||
except (httpx.HTTPStatusError, httpx.RequestError, ValueError):
|
||||
logger.error(
|
||||
"Onyx MCP Server: Failed to fetch document sets",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Onyx MCP Server: Unexpected error fetching document sets",
|
||||
exc_info=True,
|
||||
)
|
||||
raise RuntimeError(f"Failed to fetch document sets: {exc}") from exc
|
||||
|
||||
@@ -11,6 +11,8 @@ All public functions no-op in single-tenant mode (`MULTI_TENANT=False`).
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
from redis.client import Redis
|
||||
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
@@ -26,6 +28,40 @@ logger = setup_logger()
|
||||
_SET_KEY = "active_tenants"
|
||||
|
||||
|
||||
# --- Prometheus metrics ---
|
||||
|
||||
_active_set_size = Gauge(
|
||||
"onyx_tenant_work_gating_active_set_size",
|
||||
"Current cardinality of the active_tenants sorted set (updated once per "
|
||||
"generator invocation when the gate reads it).",
|
||||
)
|
||||
|
||||
_marked_total = Counter(
|
||||
"onyx_tenant_work_gating_marked_total",
|
||||
"Writes into active_tenants, labelled by caller.",
|
||||
["caller"],
|
||||
)
|
||||
|
||||
_skipped_total = Counter(
|
||||
"onyx_tenant_work_gating_skipped_total",
|
||||
"Per-tenant fanouts skipped by the gate (enforce mode only), by task.",
|
||||
["task"],
|
||||
)
|
||||
|
||||
_would_skip_total = Counter(
|
||||
"onyx_tenant_work_gating_would_skip_total",
|
||||
"Per-tenant fanouts that would have been skipped if enforce were on "
|
||||
"(shadow counter), by task.",
|
||||
["task"],
|
||||
)
|
||||
|
||||
_full_fanout_total = Counter(
|
||||
"onyx_tenant_work_gating_full_fanout_total",
|
||||
"Generator invocations that bypassed the gate for a full fanout cycle, by task.",
|
||||
["task"],
|
||||
)
|
||||
|
||||
|
||||
def _now_ms() -> int:
|
||||
return int(time.time() * 1000)
|
||||
|
||||
@@ -54,10 +90,14 @@ def mark_tenant_active(tenant_id: str) -> None:
|
||||
logger.exception(f"mark_tenant_active failed: tenant_id={tenant_id}")
|
||||
|
||||
|
||||
def maybe_mark_tenant_active(tenant_id: str) -> None:
|
||||
def maybe_mark_tenant_active(tenant_id: str, caller: str = "unknown") -> None:
|
||||
"""Convenience wrapper for writer call sites: records the tenant only
|
||||
when the feature flag is on. Fully defensive — never raises, so a Redis
|
||||
outage or flag-read failure can't abort the calling task."""
|
||||
outage or flag-read failure can't abort the calling task.
|
||||
|
||||
`caller` labels the Prometheus counter so a dashboard can show which
|
||||
consumer is firing the hook most.
|
||||
"""
|
||||
try:
|
||||
# Local import to avoid a module-load cycle: OnyxRuntime imports
|
||||
# onyx.redis.redis_pool, so a top-level import here would wedge on
|
||||
@@ -67,10 +107,44 @@ def maybe_mark_tenant_active(tenant_id: str) -> None:
|
||||
if not OnyxRuntime.get_tenant_work_gating_enabled():
|
||||
return
|
||||
mark_tenant_active(tenant_id)
|
||||
_marked_total.labels(caller=caller).inc()
|
||||
except Exception:
|
||||
logger.exception(f"maybe_mark_tenant_active failed: tenant_id={tenant_id}")
|
||||
|
||||
|
||||
def observe_active_set_size() -> int | None:
|
||||
"""Return `ZCARD active_tenants` and update the Prometheus gauge. Call
|
||||
from the gate generator once per invocation so the dashboard has a
|
||||
live reading.
|
||||
|
||||
Returns `None` on Redis error or in single-tenant mode; callers can
|
||||
tolerate that (gauge simply doesn't update)."""
|
||||
if not MULTI_TENANT:
|
||||
return None
|
||||
try:
|
||||
size = cast(int, _client().zcard(_SET_KEY))
|
||||
_active_set_size.set(size)
|
||||
return size
|
||||
except Exception:
|
||||
logger.exception("observe_active_set_size failed")
|
||||
return None
|
||||
|
||||
|
||||
def record_gate_decision(task_name: str, skipped: bool) -> None:
|
||||
"""Increment skip counters from the gate generator. Called once per
|
||||
tenant that the gate would skip. Always increments the shadow counter;
|
||||
increments the enforced counter only when `skipped=True`."""
|
||||
_would_skip_total.labels(task=task_name).inc()
|
||||
if skipped:
|
||||
_skipped_total.labels(task=task_name).inc()
|
||||
|
||||
|
||||
def record_full_fanout_cycle(task_name: str) -> None:
|
||||
"""Increment the full-fanout counter. Called once per generator
|
||||
invocation where the gate is bypassed (interval elapsed OR fail-open)."""
|
||||
_full_fanout_total.labels(task=task_name).inc()
|
||||
|
||||
|
||||
def get_active_tenants(ttl_seconds: int) -> set[str] | None:
|
||||
"""Return tenants whose last-seen timestamp is within `ttl_seconds` of
|
||||
now.
|
||||
|
||||
@@ -3,6 +3,7 @@ from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
@@ -49,6 +50,7 @@ def get_opensearch_retrieval_status(
|
||||
enable_opensearch_retrieval = get_opensearch_retrieval_state(db_session)
|
||||
return OpenSearchRetrievalStatusResponse(
|
||||
enable_opensearch_retrieval=enable_opensearch_retrieval,
|
||||
toggling_retrieval_is_disabled=ONYX_DISABLE_VESPA,
|
||||
)
|
||||
|
||||
|
||||
@@ -63,4 +65,5 @@ def set_opensearch_retrieval_status(
|
||||
)
|
||||
return OpenSearchRetrievalStatusResponse(
|
||||
enable_opensearch_retrieval=request.enable_opensearch_retrieval,
|
||||
toggling_retrieval_is_disabled=ONYX_DISABLE_VESPA,
|
||||
)
|
||||
|
||||
@@ -19,3 +19,4 @@ class OpenSearchRetrievalStatusRequest(BaseModel):
|
||||
class OpenSearchRetrievalStatusResponse(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
enable_opensearch_retrieval: bool
|
||||
toggling_retrieval_is_disabled: bool = False
|
||||
|
||||
@@ -395,6 +395,15 @@ class WorkerHealthCollector(_CachedCollector):
|
||||
|
||||
Reads worker status from ``WorkerHeartbeatMonitor`` which listens
|
||||
to the Celery event stream via a single persistent connection.
|
||||
|
||||
TODO: every monitoring pod subscribes to the cluster-wide Celery event
|
||||
stream, so each replica reports health for *all* workers in the cluster,
|
||||
not just itself. Prometheus distinguishes the replicas via the ``instance``
|
||||
label, so this doesn't break scraping, but it means N monitoring replicas
|
||||
do N× the work and may emit slightly inconsistent snapshots of the same
|
||||
cluster. The proper fix is to have each worker expose its own health (or
|
||||
to elect a single monitoring replica as the reporter) rather than
|
||||
broadcasting the full cluster view from every monitoring pod.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = 30.0) -> None:
|
||||
@@ -413,10 +422,16 @@ class WorkerHealthCollector(_CachedCollector):
|
||||
"onyx_celery_active_worker_count",
|
||||
"Number of active Celery workers with recent heartbeats",
|
||||
)
|
||||
# Celery hostnames are ``{worker_type}@{nodename}`` (see supervisord.conf).
|
||||
# Emitting only the worker_type as a label causes N replicas of the same
|
||||
# type to collapse into identical timeseries within a single scrape,
|
||||
# which Prometheus rejects as "duplicate sample for timestamp". Split
|
||||
# the pieces into separate labels so each replica is distinct; callers
|
||||
# can still ``sum by (worker_type)`` to recover the old aggregated view.
|
||||
worker_up = GaugeMetricFamily(
|
||||
"onyx_celery_worker_up",
|
||||
"Whether a specific Celery worker is alive (1=up, 0=down)",
|
||||
labels=["worker"],
|
||||
labels=["worker_type", "hostname"],
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -424,11 +439,15 @@ class WorkerHealthCollector(_CachedCollector):
|
||||
alive_count = sum(1 for alive in status.values() if alive)
|
||||
active_workers.add_metric([], alive_count)
|
||||
|
||||
for hostname in sorted(status):
|
||||
# Use short name (before @) for single-host deployments,
|
||||
# full hostname when multiple hosts share a worker type.
|
||||
label = hostname.split("@")[0]
|
||||
worker_up.add_metric([label], 1 if status[hostname] else 0)
|
||||
for full_hostname in sorted(status):
|
||||
worker_type, sep, host = full_hostname.partition("@")
|
||||
if not sep:
|
||||
# Hostname didn't contain "@" — fall back to using the
|
||||
# whole string as the hostname with an empty type.
|
||||
worker_type, host = "", full_hostname
|
||||
worker_up.add_metric(
|
||||
[worker_type, host], 1 if status[full_hostname] else 0
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to collect worker health metrics", exc_info=True)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.configs.embedding_configs import SUPPORTED_EMBEDDING_MODELS
|
||||
@@ -126,10 +127,11 @@ def setup_onyx(
|
||||
"DISABLE_VECTOR_DB is set — skipping document index setup and embedding model warm-up."
|
||||
)
|
||||
else:
|
||||
# Ensure Vespa is setup correctly, this step is relatively near the end
|
||||
# because Vespa takes a bit of time to start up
|
||||
# Ensure the document indices are setup correctly. This step is
|
||||
# relatively near the end because Vespa takes a bit of time to start up.
|
||||
logger.notice("Verifying Document Index(s) is/are available.")
|
||||
# This flow is for setting up the document index so we get all indices here.
|
||||
# This flow is for setting up the document index so we get all indices
|
||||
# here.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings,
|
||||
secondary_search_settings,
|
||||
@@ -335,7 +337,7 @@ def setup_multitenant_onyx() -> None:
|
||||
|
||||
# For Managed Vespa, the schema is sent over via the Vespa Console manually.
|
||||
# NOTE: Pretty sure this code is never hit in any production environment.
|
||||
if not MANAGED_VESPA:
|
||||
if not MANAGED_VESPA and not ONYX_DISABLE_VESPA:
|
||||
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)
|
||||
|
||||
|
||||
|
||||
@@ -9,8 +9,10 @@ import pytest
|
||||
|
||||
from onyx.configs.constants import BlobType
|
||||
from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
@@ -111,15 +113,18 @@ def test_blob_s3_connector(
|
||||
|
||||
for doc in all_docs:
|
||||
section = doc.sections[0]
|
||||
assert isinstance(section, TextSection)
|
||||
|
||||
file_extension = get_file_ext(doc.semantic_identifier)
|
||||
if file_extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS:
|
||||
if is_tabular_file(doc.semantic_identifier):
|
||||
assert isinstance(section, TabularSection)
|
||||
assert len(section.text) > 0
|
||||
continue
|
||||
|
||||
# unknown extension
|
||||
assert len(section.text) == 0
|
||||
assert isinstance(section, TextSection)
|
||||
file_extension = get_file_ext(doc.semantic_identifier)
|
||||
if file_extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS:
|
||||
assert len(section.text) > 0
|
||||
else:
|
||||
assert len(section.text) == 0
|
||||
|
||||
|
||||
@patch(
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
"""Tests for `cloud_beat_task_generator`'s tenant work-gating logic.
|
||||
|
||||
Exercises the gate-read path end-to-end against real Redis. The Celery
|
||||
`.app.send_task` is mocked so we can count dispatches without actually
|
||||
sending messages.
|
||||
|
||||
Requires a running Redis instance. Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest \
|
||||
backend/tests/external_dependency_unit/tenant_work_gating/test_gate_generator.py
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.background.celery.tasks.cloud import tasks as cloud_tasks
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.redis import redis_tenant_work_gating as twg
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_tenant_work_gating import _SET_KEY
|
||||
from onyx.redis.redis_tenant_work_gating import mark_tenant_active
|
||||
|
||||
|
||||
_TENANT_A = "tenant_aaaa0000-0000-0000-0000-000000000001"
|
||||
_TENANT_B = "tenant_bbbb0000-0000-0000-0000-000000000002"
|
||||
_TENANT_C = "tenant_cccc0000-0000-0000-0000-000000000003"
|
||||
_ALL_TEST_TENANTS = [_TENANT_A, _TENANT_B, _TENANT_C]
|
||||
_FANOUT_KEY_PREFIX = cloud_tasks._FULL_FANOUT_TIMESTAMP_KEY_PREFIX
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _multi_tenant_true() -> Generator[None, None, None]:
|
||||
with patch.object(twg, "MULTI_TENANT", True):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_redis() -> Generator[None, None, None]:
|
||||
"""Clear the active set AND the per-task full-fanout timestamp so each
|
||||
test starts fresh."""
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
r.delete(_SET_KEY)
|
||||
r.delete(f"{_FANOUT_KEY_PREFIX}:test_task")
|
||||
r.delete("runtime:tenant_work_gating:enabled")
|
||||
r.delete("runtime:tenant_work_gating:enforce")
|
||||
yield
|
||||
r.delete(_SET_KEY)
|
||||
r.delete(f"{_FANOUT_KEY_PREFIX}:test_task")
|
||||
r.delete("runtime:tenant_work_gating:enabled")
|
||||
r.delete("runtime:tenant_work_gating:enforce")
|
||||
|
||||
|
||||
def _invoke_generator(
|
||||
*,
|
||||
work_gated: bool,
|
||||
enabled: bool,
|
||||
enforce: bool,
|
||||
tenant_ids: list[str],
|
||||
full_fanout_interval_seconds: int = 1200,
|
||||
ttl_seconds: int = 1800,
|
||||
) -> MagicMock:
|
||||
"""Helper: call the generator with runtime flags fixed and the Celery
|
||||
app mocked. Returns the mock so callers can assert on send_task calls."""
|
||||
mock_app = MagicMock()
|
||||
# The task binds `self` = the task itself when invoked via `.run()`;
|
||||
# patch its `.app` so `self.app.send_task` routes to our mock.
|
||||
with (
|
||||
patch.object(cloud_tasks.cloud_beat_task_generator, "app", mock_app),
|
||||
patch.object(cloud_tasks, "get_all_tenant_ids", return_value=list(tenant_ids)),
|
||||
patch.object(cloud_tasks, "get_gated_tenants", return_value=set()),
|
||||
patch(
|
||||
"onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_enabled",
|
||||
return_value=enabled,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_enforce",
|
||||
return_value=enforce,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_full_fanout_interval_seconds",
|
||||
return_value=full_fanout_interval_seconds,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_ttl_seconds",
|
||||
return_value=ttl_seconds,
|
||||
),
|
||||
):
|
||||
cloud_tasks.cloud_beat_task_generator.run(
|
||||
task_name="test_task",
|
||||
work_gated=work_gated,
|
||||
)
|
||||
return mock_app
|
||||
|
||||
|
||||
def _dispatched_tenants(mock_app: MagicMock) -> list[str]:
|
||||
"""Pull tenant_ids out of each send_task call for assertion."""
|
||||
return [c.kwargs["kwargs"]["tenant_id"] for c in mock_app.send_task.call_args_list]
|
||||
|
||||
|
||||
def _seed_recent_full_fanout_timestamp() -> None:
|
||||
"""Pre-seed the per-task timestamp so the interval-elapsed branch
|
||||
reports False, i.e. the gate enforces normally instead of going into
|
||||
full-fanout on first invocation."""
|
||||
import time as _t
|
||||
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
r.set(f"{_FANOUT_KEY_PREFIX}:test_task", str(int(_t.time() * 1000)))
|
||||
|
||||
|
||||
def test_enforce_skips_unmarked_tenants() -> None:
|
||||
"""With enable+enforce on (interval NOT elapsed), only tenants in the
|
||||
active set get dispatched."""
|
||||
mark_tenant_active(_TENANT_A)
|
||||
_seed_recent_full_fanout_timestamp()
|
||||
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=True,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
full_fanout_interval_seconds=3600,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert dispatched == [_TENANT_A]
|
||||
|
||||
|
||||
def test_shadow_mode_dispatches_all_tenants() -> None:
|
||||
"""enabled=True, enforce=False: gate computes skip but still dispatches."""
|
||||
mark_tenant_active(_TENANT_A)
|
||||
_seed_recent_full_fanout_timestamp()
|
||||
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=True,
|
||||
enforce=False,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
full_fanout_interval_seconds=3600,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
|
||||
|
||||
def test_full_fanout_cycle_dispatches_all_tenants() -> None:
|
||||
"""First invocation (no prior timestamp → interval considered elapsed)
|
||||
counts as full-fanout; every tenant gets dispatched even under enforce."""
|
||||
mark_tenant_active(_TENANT_A)
|
||||
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=True,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
|
||||
|
||||
def test_redis_unavailable_fails_open() -> None:
|
||||
"""When `get_active_tenants` returns None (simulated Redis outage) the
|
||||
gate treats the invocation as full-fanout and dispatches everyone —
|
||||
even when the interval hasn't elapsed and enforce is on."""
|
||||
mark_tenant_active(_TENANT_A)
|
||||
_seed_recent_full_fanout_timestamp()
|
||||
|
||||
with patch.object(cloud_tasks, "get_active_tenants", return_value=None):
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=True,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
full_fanout_interval_seconds=3600,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
|
||||
|
||||
def test_work_gated_false_bypasses_gate_entirely() -> None:
|
||||
"""Beat templates that don't opt in (`work_gated=False`) never consult
|
||||
the set — no matter the flag state."""
|
||||
# Even with enforce on and nothing in the set, all tenants dispatch.
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=False,
|
||||
enabled=True,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
|
||||
|
||||
def test_gate_disabled_dispatches_everyone_regardless_of_enforce() -> None:
|
||||
"""enabled=False means the gate isn't computed — dispatch is unchanged."""
|
||||
# Intentionally don't add anyone to the set.
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=False,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
@@ -16,12 +16,14 @@ from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.types import CallToolResult
|
||||
from mcp.types import TextContent
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from onyx.db.enums import AccessType
|
||||
from tests.integration.common_utils.constants import MCP_SERVER_URL
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.document_set import DocumentSetManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.pat import PATManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
@@ -34,6 +36,7 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
# Constants
|
||||
MCP_SEARCH_TOOL = "search_indexed_documents"
|
||||
INDEXED_SOURCES_RESOURCE_URI = "resource://indexed_sources"
|
||||
DOCUMENT_SETS_RESOURCE_URI = "resource://document_sets"
|
||||
DEFAULT_SEARCH_LIMIT = 5
|
||||
STREAMABLE_HTTP_URL = f"{MCP_SERVER_URL.rstrip('/')}/?transportType=streamable-http"
|
||||
|
||||
@@ -73,19 +76,22 @@ def _extract_tool_payload(result: CallToolResult) -> dict[str, Any]:
|
||||
|
||||
|
||||
def _call_search_tool(
|
||||
headers: dict[str, str], query: str, limit: int = DEFAULT_SEARCH_LIMIT
|
||||
headers: dict[str, str],
|
||||
query: str,
|
||||
limit: int = DEFAULT_SEARCH_LIMIT,
|
||||
document_set_names: list[str] | None = None,
|
||||
) -> CallToolResult:
|
||||
"""Call the search_indexed_documents tool via MCP."""
|
||||
|
||||
async def _action(session: ClientSession) -> CallToolResult:
|
||||
await session.initialize()
|
||||
return await session.call_tool(
|
||||
MCP_SEARCH_TOOL,
|
||||
{
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
arguments: dict[str, Any] = {
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
}
|
||||
if document_set_names is not None:
|
||||
arguments["document_set_names"] = document_set_names
|
||||
return await session.call_tool(MCP_SEARCH_TOOL, arguments)
|
||||
|
||||
return _run_with_mcp_session(headers, _action)
|
||||
|
||||
@@ -238,3 +244,106 @@ def test_mcp_search_respects_acl_filters(
|
||||
blocked_payload = _extract_tool_payload(blocked_result)
|
||||
assert blocked_payload["total_results"] == 0
|
||||
assert blocked_payload["documents"] == []
|
||||
|
||||
|
||||
def test_mcp_search_filters_by_document_set(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Passing document_set_names should scope results to the named set."""
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
api_key = APIKeyManager.create(user_performing_action=admin_user)
|
||||
cc_pair_in_set = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
cc_pair_out_of_set = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
shared_phrase = "document-set-filter-shared-phrase"
|
||||
in_set_content = f"{shared_phrase} inside curated set"
|
||||
out_of_set_content = f"{shared_phrase} outside curated set"
|
||||
|
||||
_seed_document_and_wait_for_indexing(
|
||||
cc_pair=cc_pair_in_set,
|
||||
content=in_set_content,
|
||||
api_key=api_key,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
_seed_document_and_wait_for_indexing(
|
||||
cc_pair=cc_pair_out_of_set,
|
||||
content=out_of_set_content,
|
||||
api_key=api_key,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
doc_set = DocumentSetManager.create(
|
||||
cc_pair_ids=[cc_pair_in_set.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
DocumentSetManager.wait_for_sync(
|
||||
user_performing_action=admin_user,
|
||||
document_sets_to_check=[doc_set],
|
||||
)
|
||||
|
||||
headers = _auth_headers(admin_user, name="mcp-doc-set-filter")
|
||||
|
||||
# The document_sets resource should surface the newly created set so MCP
|
||||
# clients can discover which values to pass to document_set_names.
|
||||
async def _list_resources(session: ClientSession) -> Any:
|
||||
await session.initialize()
|
||||
resources = await session.list_resources()
|
||||
contents = await session.read_resource(AnyUrl(DOCUMENT_SETS_RESOURCE_URI))
|
||||
return resources, contents
|
||||
|
||||
resources_result, doc_sets_contents = _run_with_mcp_session(
|
||||
headers, _list_resources
|
||||
)
|
||||
resource_uris = {str(resource.uri) for resource in resources_result.resources}
|
||||
assert DOCUMENT_SETS_RESOURCE_URI in resource_uris
|
||||
doc_sets_payload = json.loads(doc_sets_contents.contents[0].text)
|
||||
exposed_names = {entry["name"] for entry in doc_sets_payload}
|
||||
assert doc_set.name in exposed_names
|
||||
|
||||
# Without the filter both documents are visible.
|
||||
unfiltered_payload = _extract_tool_payload(
|
||||
_call_search_tool(headers, shared_phrase, limit=10)
|
||||
)
|
||||
unfiltered_contents = [
|
||||
doc.get("content") or "" for doc in unfiltered_payload["documents"]
|
||||
]
|
||||
assert any(in_set_content in content for content in unfiltered_contents)
|
||||
assert any(out_of_set_content in content for content in unfiltered_contents)
|
||||
|
||||
# With the document set filter only the in-set document is returned.
|
||||
filtered_payload = _extract_tool_payload(
|
||||
_call_search_tool(
|
||||
headers,
|
||||
shared_phrase,
|
||||
limit=10,
|
||||
document_set_names=[doc_set.name],
|
||||
)
|
||||
)
|
||||
filtered_contents = [
|
||||
doc.get("content") or "" for doc in filtered_payload["documents"]
|
||||
]
|
||||
assert filtered_payload["total_results"] >= 1
|
||||
assert any(in_set_content in content for content in filtered_contents)
|
||||
assert all(out_of_set_content not in content for content in filtered_contents)
|
||||
|
||||
# An empty document_set_names should behave like "no filter" (normalized
|
||||
# to None), not "match zero sets".
|
||||
empty_list_payload = _extract_tool_payload(
|
||||
_call_search_tool(
|
||||
headers,
|
||||
shared_phrase,
|
||||
limit=10,
|
||||
document_set_names=[],
|
||||
)
|
||||
)
|
||||
empty_list_contents = [
|
||||
doc.get("content") or "" for doc in empty_list_payload["documents"]
|
||||
]
|
||||
assert any(in_set_content in content for content in empty_list_contents)
|
||||
assert any(out_of_set_content in content for content in empty_list_contents)
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Unit tests for SharepointConnector.load_credentials sp_tenant_domain resolution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
|
||||
SITE_URL = "https://mytenant.sharepoint.com/sites/MySite"
|
||||
EXPECTED_TENANT_DOMAIN = "mytenant"
|
||||
|
||||
CLIENT_SECRET_CREDS = {
|
||||
"authentication_method": "client_secret",
|
||||
"sp_client_id": "fake-client-id",
|
||||
"sp_client_secret": "fake-client-secret",
|
||||
"sp_directory_id": "fake-directory-id",
|
||||
}
|
||||
|
||||
CERTIFICATE_CREDS = {
|
||||
"authentication_method": "certificate",
|
||||
"sp_client_id": "fake-client-id",
|
||||
"sp_directory_id": "fake-directory-id",
|
||||
"sp_private_key": base64.b64encode(b"fake-pfx-data").decode(),
|
||||
"sp_certificate_password": "fake-password",
|
||||
}
|
||||
|
||||
|
||||
def _make_mock_msal() -> MagicMock:
|
||||
mock_app = MagicMock()
|
||||
mock_app.acquire_token_for_client.return_value = {"access_token": "fake-token"}
|
||||
return mock_app
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
|
||||
@patch("onyx.connectors.sharepoint.connector.GraphClient")
|
||||
def test_client_secret_with_site_pages_sets_tenant_domain(
|
||||
_mock_graph_client: MagicMock,
|
||||
mock_msal_cls: MagicMock,
|
||||
) -> None:
|
||||
"""client_secret auth + include_site_pages=True must resolve sp_tenant_domain."""
|
||||
mock_msal_cls.return_value = _make_mock_msal()
|
||||
connector = SharepointConnector(sites=[SITE_URL], include_site_pages=True)
|
||||
|
||||
connector.load_credentials(CLIENT_SECRET_CREDS)
|
||||
|
||||
assert connector.sp_tenant_domain == EXPECTED_TENANT_DOMAIN
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
|
||||
@patch("onyx.connectors.sharepoint.connector.GraphClient")
|
||||
def test_client_secret_without_site_pages_still_sets_tenant_domain(
|
||||
_mock_graph_client: MagicMock,
|
||||
mock_msal_cls: MagicMock,
|
||||
) -> None:
|
||||
"""client_secret auth + include_site_pages=False must still resolve sp_tenant_domain
|
||||
because _create_rest_client_context is also called for drive items."""
|
||||
mock_msal_cls.return_value = _make_mock_msal()
|
||||
connector = SharepointConnector(sites=[SITE_URL], include_site_pages=False)
|
||||
|
||||
connector.load_credentials(CLIENT_SECRET_CREDS)
|
||||
|
||||
assert connector.sp_tenant_domain == EXPECTED_TENANT_DOMAIN
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.load_certificate_from_pfx")
|
||||
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
|
||||
@patch("onyx.connectors.sharepoint.connector.GraphClient")
|
||||
def test_certificate_with_site_pages_sets_tenant_domain(
|
||||
_mock_graph_client: MagicMock,
|
||||
mock_msal_cls: MagicMock,
|
||||
mock_load_cert: MagicMock,
|
||||
) -> None:
|
||||
"""certificate auth + include_site_pages=True must resolve sp_tenant_domain."""
|
||||
mock_msal_cls.return_value = _make_mock_msal()
|
||||
mock_load_cert.return_value = MagicMock()
|
||||
connector = SharepointConnector(sites=[SITE_URL], include_site_pages=True)
|
||||
|
||||
connector.load_credentials(CERTIFICATE_CREDS)
|
||||
|
||||
assert connector.sp_tenant_domain == EXPECTED_TENANT_DOMAIN
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.load_certificate_from_pfx")
|
||||
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
|
||||
@patch("onyx.connectors.sharepoint.connector.GraphClient")
|
||||
def test_certificate_without_site_pages_sets_tenant_domain(
|
||||
_mock_graph_client: MagicMock,
|
||||
mock_msal_cls: MagicMock,
|
||||
mock_load_cert: MagicMock,
|
||||
) -> None:
|
||||
"""certificate auth + include_site_pages=False must still resolve sp_tenant_domain
|
||||
because _create_rest_client_context is also called for drive items."""
|
||||
mock_msal_cls.return_value = _make_mock_msal()
|
||||
mock_load_cert.return_value = MagicMock()
|
||||
connector = SharepointConnector(sites=[SITE_URL], include_site_pages=False)
|
||||
|
||||
connector.load_credentials(CERTIFICATE_CREDS)
|
||||
|
||||
assert connector.sp_tenant_domain == EXPECTED_TENANT_DOMAIN
|
||||
@@ -1,6 +1,7 @@
|
||||
import io
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
@@ -321,6 +322,17 @@ class TestXlsxSheetExtraction:
|
||||
sheets = xlsx_sheet_extraction(bad_file, file_name="~$temp.xlsx")
|
||||
assert sheets == []
|
||||
|
||||
def test_known_openpyxl_bug_max_value_returns_empty(self) -> None:
|
||||
"""openpyxl's strict descriptor validation rejects font family
|
||||
values >14 with 'Max value is 14'. Treat as a known openpyxl bug
|
||||
and skip the file rather than fail the whole connector batch."""
|
||||
with patch(
|
||||
"onyx.file_processing.extract_file_text.openpyxl.load_workbook",
|
||||
side_effect=ValueError("Max value is 14"),
|
||||
):
|
||||
sheets = xlsx_sheet_extraction(io.BytesIO(b""), file_name="bad_font.xlsx")
|
||||
assert sheets == []
|
||||
|
||||
def test_csv_content_matches_xlsx_to_text_per_sheet(self) -> None:
|
||||
"""For a single-sheet workbook, xlsx_to_text output should equal
|
||||
the csv_text from xlsx_sheet_extraction — they share the same
|
||||
|
||||
@@ -129,12 +129,36 @@ class TestWorkerHealthCollector:
|
||||
up = families[1]
|
||||
assert up.name == "onyx_celery_worker_up"
|
||||
assert len(up.samples) == 3
|
||||
# Labels use short names (before @)
|
||||
labels = {s.labels["worker"] for s in up.samples}
|
||||
assert labels == {"primary", "docfetching", "monitoring"}
|
||||
label_pairs = {
|
||||
(s.labels["worker_type"], s.labels["hostname"]) for s in up.samples
|
||||
}
|
||||
assert label_pairs == {
|
||||
("primary", "host1"),
|
||||
("docfetching", "host1"),
|
||||
("monitoring", "host1"),
|
||||
}
|
||||
for sample in up.samples:
|
||||
assert sample.value == 1
|
||||
|
||||
def test_replicas_of_same_worker_type_are_distinct(self) -> None:
|
||||
"""Regression: ``docprocessing@pod-1`` and ``docprocessing@pod-2`` must
|
||||
produce separate samples, not collapse into one duplicate-timestamp
|
||||
series.
|
||||
"""
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "docprocessing@pod-1"})
|
||||
monitor._on_heartbeat({"hostname": "docprocessing@pod-2"})
|
||||
monitor._on_heartbeat({"hostname": "docprocessing@pod-3"})
|
||||
|
||||
collector = WorkerHealthCollector(cache_ttl=0)
|
||||
collector.set_monitor(monitor)
|
||||
|
||||
up = collector.collect()[1]
|
||||
assert len(up.samples) == 3
|
||||
hostnames = {s.labels["hostname"] for s in up.samples}
|
||||
assert hostnames == {"pod-1", "pod-2", "pod-3"}
|
||||
assert all(s.labels["worker_type"] == "docprocessing" for s in up.samples)
|
||||
|
||||
def test_reports_dead_worker(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
@@ -151,9 +175,9 @@ class TestWorkerHealthCollector:
|
||||
assert active.samples[0].value == 1
|
||||
|
||||
up = families[1]
|
||||
samples_by_name = {s.labels["worker"]: s.value for s in up.samples}
|
||||
assert samples_by_name["primary"] == 1
|
||||
assert samples_by_name["monitoring"] == 0
|
||||
samples_by_type = {s.labels["worker_type"]: s.value for s in up.samples}
|
||||
assert samples_by_type["primary"] == 1
|
||||
assert samples_by_type["monitoring"] == 0
|
||||
|
||||
def test_empty_monitor_returns_zero(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
|
||||
@@ -152,6 +152,11 @@ const nextConfig = {
|
||||
destination: "/ee/agents/:path*",
|
||||
permanent: true,
|
||||
},
|
||||
{
|
||||
source: "/admin/configuration/llm",
|
||||
destination: "/admin/configuration/language-models",
|
||||
permanent: true,
|
||||
},
|
||||
];
|
||||
},
|
||||
};
|
||||
|
||||
@@ -15,6 +15,9 @@ import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
const TOGGLE_DISABLED_MESSAGE =
|
||||
"Changing the retrieval source is not currently possible for this instance of Onyx.";
|
||||
|
||||
interface MigrationStatus {
|
||||
total_chunks_migrated: number;
|
||||
created_at: string | null;
|
||||
@@ -24,6 +27,7 @@ interface MigrationStatus {
|
||||
|
||||
interface RetrievalStatus {
|
||||
enable_opensearch_retrieval: boolean;
|
||||
toggling_retrieval_is_disabled?: boolean;
|
||||
}
|
||||
|
||||
function formatTimestamp(iso: string): string {
|
||||
@@ -133,6 +137,7 @@ function RetrievalSourceSection() {
|
||||
: "vespa";
|
||||
const currentValue = selectedSource ?? serverValue;
|
||||
const hasChanges = selectedSource !== null && selectedSource !== serverValue;
|
||||
const togglingDisabled = data?.toggling_retrieval_is_disabled ?? false;
|
||||
|
||||
async function handleUpdate() {
|
||||
setUpdating(true);
|
||||
@@ -188,7 +193,7 @@ function RetrievalSourceSection() {
|
||||
<InputSelect
|
||||
value={currentValue}
|
||||
onValueChange={setSelectedSource}
|
||||
disabled={updating}
|
||||
disabled={updating || togglingDisabled}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select retrieval source" />
|
||||
<InputSelect.Content>
|
||||
@@ -197,6 +202,12 @@ function RetrievalSourceSection() {
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
|
||||
{togglingDisabled && (
|
||||
<Text mainUiBody text03>
|
||||
{TOGGLE_DISABLED_MESSAGE}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
{hasChanges && (
|
||||
// TODO(@raunakab): migrate to opal Button once className/iconClassName is resolved
|
||||
<Button
|
||||
|
||||
@@ -671,7 +671,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
width="full"
|
||||
prominence="secondary"
|
||||
onClick={() => {
|
||||
window.location.href = "/admin/configuration/llm";
|
||||
window.location.href = "/admin/configuration/language-models";
|
||||
}}
|
||||
>
|
||||
Set up an LLM.
|
||||
|
||||
@@ -126,7 +126,7 @@ export const ADMIN_ROUTES = {
|
||||
sidebarLabel: "Chat Preferences",
|
||||
},
|
||||
LLM_MODELS: {
|
||||
path: "/admin/configuration/llm",
|
||||
path: "/admin/configuration/language-models",
|
||||
icon: SvgCpu,
|
||||
title: "Language Models",
|
||||
sidebarLabel: "Language Models",
|
||||
|
||||
@@ -18,7 +18,7 @@ import {
|
||||
unsetDefaultImageGenerationConfig,
|
||||
deleteImageGenerationConfig,
|
||||
} from "@/refresh-pages/admin/ImageGenerationPage/svc";
|
||||
import ModelIcon from "@/app/admin/configuration/llm/ModelIcon";
|
||||
import ModelIcon from "@/app/admin/configuration/language-models/ModelIcon";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { Button, MessageCard, SelectCard, Text } from "@opal/components";
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import React, { useState, useMemo, useEffect } from "react";
|
||||
import { Form, Formik, FormikProps } from "formik";
|
||||
import ProviderModal from "@/components/modals/ProviderModal";
|
||||
import ModelIcon from "@/app/admin/configuration/llm/ModelIcon";
|
||||
import ModelIcon from "@/app/admin/configuration/language-models/ModelIcon";
|
||||
import ConnectionProviderIcon from "@/refresh-components/ConnectionProviderIcon";
|
||||
import {
|
||||
testImageGenerationApiKey,
|
||||
|
||||
@@ -12,7 +12,7 @@ import {
|
||||
SvgServer,
|
||||
SvgSettings,
|
||||
} from "@opal/icons";
|
||||
import ModelIcon from "@/app/admin/configuration/llm/ModelIcon";
|
||||
import ModelIcon from "@/app/admin/configuration/language-models/ModelIcon";
|
||||
|
||||
export interface LLMProviderCardProps {
|
||||
title: string;
|
||||
@@ -40,7 +40,7 @@ function LLMProviderCardInner({
|
||||
|
||||
if (isConnected) {
|
||||
// If connected, redirect to admin page
|
||||
window.location.href = "/admin/configuration/llm";
|
||||
window.location.href = "/admin/configuration/language-models";
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -49,7 +49,9 @@ function LLMProviderCardInner({
|
||||
}, [disabled, isConnected, onClick]);
|
||||
|
||||
const handleSettingsClick = useCallback(
|
||||
noProp(() => (window.location.href = "/admin/configuration/llm")),
|
||||
noProp(
|
||||
() => (window.location.href = "/admin/configuration/language-models")
|
||||
),
|
||||
[]
|
||||
);
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import {
|
||||
} from "@/interfaces/llm";
|
||||
import { getProvider } from "@/lib/llmConfig";
|
||||
import { Disabled } from "@opal/core";
|
||||
import ModelIcon from "@/app/admin/configuration/llm/ModelIcon";
|
||||
import ModelIcon from "@/app/admin/configuration/language-models/ModelIcon";
|
||||
import { SvgCheckCircle, SvgCpu, SvgExternalLink } from "@opal/icons";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { useLLMProviderOptions } from "@/lib/hooks/useLLMProviderOptions";
|
||||
@@ -162,7 +162,7 @@ const LLMStep = memo(
|
||||
disabled={disabled}
|
||||
prominence="tertiary"
|
||||
rightIcon={SvgExternalLink}
|
||||
href="/admin/configuration/llm"
|
||||
href="/admin/configuration/language-models"
|
||||
>
|
||||
View in Admin Panel
|
||||
</Button>
|
||||
|
||||
@@ -586,7 +586,11 @@ const MemoizedAppSidebarInner = memo(function AppSidebarInner() {
|
||||
<div>
|
||||
{(isAdmin || isCurator) && (
|
||||
<SidebarTab
|
||||
href={isCurator ? "/admin/agents" : "/admin/configuration/llm"}
|
||||
href={
|
||||
isCurator
|
||||
? "/admin/agents"
|
||||
: "/admin/configuration/language-models"
|
||||
}
|
||||
icon={SvgSettings}
|
||||
folded={folded}
|
||||
>
|
||||
|
||||
@@ -13,7 +13,7 @@ test.describe.configure({ mode: "parallel" });
|
||||
* user / feature-flag configuration.
|
||||
*/
|
||||
async function discoverAdminPages(page: Page): Promise<string[]> {
|
||||
await page.goto("/admin/configuration/llm");
|
||||
await page.goto("/admin/configuration/language-models");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
return page.evaluate(() => {
|
||||
|
||||
@@ -3,7 +3,7 @@ import type { Locator, Page } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
|
||||
|
||||
const LLM_SETUP_URL = "/admin/configuration/llm";
|
||||
const LLM_SETUP_URL = "/admin/configuration/language-models";
|
||||
const BASE_URL = process.env.BASE_URL || "http://localhost:3000";
|
||||
const PROVIDER_API_KEY =
|
||||
process.env.E2E_LLM_PROVIDER_API_KEY ||
|
||||
@@ -120,7 +120,7 @@ async function createPublicProviderWithModels(
|
||||
|
||||
async function navigateToAdminLlmPageFromChat(page: Page): Promise<void> {
|
||||
await page.goto(LLM_SETUP_URL);
|
||||
await page.waitForURL("**/admin/configuration/llm**");
|
||||
await page.waitForURL("**/admin/configuration/language-models**");
|
||||
await expect(page.getByLabel("admin-page-title")).toHaveText(
|
||||
/^Language Models/
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user