mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-21 01:16:45 +00:00
Compare commits
1 Commits
nikg/captc
...
jamison/ti
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1e69f66705 |
@@ -16,13 +16,9 @@
|
||||
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
|
||||
],
|
||||
"containerEnv": {
|
||||
"MODEL_SERVER_HOST": "inference_model_server",
|
||||
"OPENSEARCH_HOST": "opensearch",
|
||||
"POSTGRES_HOST": "relational_db",
|
||||
"REDIS_HOST": "cache",
|
||||
"S3_ENDPOINT_URL": "http://minio:9000",
|
||||
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock",
|
||||
"VESPA_HOST": "index"
|
||||
"POSTGRES_HOST": "relational_db",
|
||||
"REDIS_HOST": "cache"
|
||||
},
|
||||
"remoteUser": "${localEnv:DEVCONTAINER_REMOTE_USER:dev}",
|
||||
"updateRemoteUserUID": false,
|
||||
|
||||
@@ -45,7 +45,7 @@ if [ "$ACTIVE_HOME" != "$MOUNT_HOME" ]; then
|
||||
[ -d "$MOUNT_HOME/$item" ] || continue
|
||||
if [ -e "$ACTIVE_HOME/$item" ] && [ ! -L "$ACTIVE_HOME/$item" ]; then
|
||||
echo "warning: replacing $ACTIVE_HOME/$item with symlink to $MOUNT_HOME/$item" >&2
|
||||
rm -rf "${ACTIVE_HOME:?}/$item"
|
||||
rm -rf "$ACTIVE_HOME/$item"
|
||||
fi
|
||||
ln -sfn "$MOUNT_HOME/$item" "$ACTIVE_HOME/$item"
|
||||
done
|
||||
|
||||
@@ -40,7 +40,6 @@ ALLOWED_DOMAINS=(
|
||||
"api.anthropic.com"
|
||||
"api-staging.anthropic.com"
|
||||
"files.anthropic.com"
|
||||
"huggingface.co"
|
||||
"sentry.io"
|
||||
"update.code.visualstudio.com"
|
||||
"pypi.org"
|
||||
|
||||
2
.github/workflows/deployment.yml
vendored
2
.github/workflows/deployment.yml
vendored
@@ -403,7 +403,7 @@ jobs:
|
||||
echo "CERT_ID=$CERT_ID" >> $GITHUB_ENV
|
||||
echo "Certificate imported."
|
||||
|
||||
- uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5 # ratchet:tauri-apps/tauri-action@action-v0.6.2
|
||||
- uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
APPLE_ID: ${{ env.APPLE_ID }}
|
||||
|
||||
2
.github/workflows/pr-golang-tests.yml
vendored
2
.github/workflows/pr-golang-tests.yml
vendored
@@ -42,7 +42,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # zizmor: ignore[cache-poisoning]
|
||||
- uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning]
|
||||
with:
|
||||
go-version: ${{ env.GO_VERSION }}
|
||||
cache-dependency-path: "**/go.sum"
|
||||
|
||||
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -39,8 +39,6 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@cbc2f23eb5539cf20d82d1aabd0d0ecbcc56f4e3
|
||||
env:
|
||||
SKIP: ty
|
||||
with:
|
||||
prek-version: '0.3.4'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
|
||||
@@ -68,7 +68,6 @@ repos:
|
||||
pass_filenames: true
|
||||
files: ^backend/(?!\.venv/|scripts/).*\.py$
|
||||
- id: uv-run
|
||||
alias: ty
|
||||
name: ty
|
||||
args: ["ty", "check"]
|
||||
pass_filenames: true
|
||||
@@ -86,17 +85,6 @@ repos:
|
||||
hooks:
|
||||
- id: actionlint
|
||||
|
||||
- repo: https://github.com/shellcheck-py/shellcheck-py
|
||||
rev: 745eface02aef23e168a8afb6b5737818efbea95 # frozen: v0.11.0.1
|
||||
hooks:
|
||||
- id: shellcheck
|
||||
exclude: >-
|
||||
(?x)^(
|
||||
backend/scripts/setup_craft_templates\.sh|
|
||||
deployment/docker_compose/init-letsencrypt\.sh|
|
||||
deployment/docker_compose/install\.sh
|
||||
)$
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0
|
||||
hooks:
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Add file_id to documents
|
||||
|
||||
Revision ID: 91d150c361f6
|
||||
Revises: a6fcd3d631f9
|
||||
Create Date: 2026-04-16 15:43:30.314823
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "91d150c361f6"
|
||||
down_revision = "a6fcd3d631f9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("file_id", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("document", "file_id")
|
||||
@@ -1,48 +0,0 @@
|
||||
"""replace document sync index with partial index
|
||||
|
||||
Replaces the composite index ix_document_sync_status (last_modified, last_synced)
|
||||
with a partial index ix_document_needs_sync that only indexes rows where
|
||||
last_modified > last_synced OR last_synced IS NULL.
|
||||
|
||||
The old index was never used by the query planner (0 scans in pg_stat_user_indexes)
|
||||
because Postgres cannot use a B-tree composite index to evaluate a comparison
|
||||
between two columns in the same row combined with an OR/IS NULL condition.
|
||||
|
||||
The partial index makes count_documents_by_needs_sync ~4000x faster for tenants
|
||||
with no stale documents (161ms -> 0.04ms on a 929K row table) and ~17x faster
|
||||
for tenants with large backlogs (846ms -> 50ms on a 164K row table).
|
||||
|
||||
Revision ID: a6fcd3d631f9
|
||||
Revises: d129f37b3d87
|
||||
Create Date: 2026-04-17 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a6fcd3d631f9"
|
||||
down_revision = "d129f37b3d87"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
"ix_document_needs_sync",
|
||||
"document",
|
||||
["id"],
|
||||
postgresql_where=sa.text("last_modified > last_synced OR last_synced IS NULL"),
|
||||
)
|
||||
op.drop_index("ix_document_sync_status", table_name="document")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_index(
|
||||
"ix_document_sync_status",
|
||||
"document",
|
||||
["last_modified", "last_synced"],
|
||||
)
|
||||
op.drop_index("ix_document_needs_sync", table_name="document")
|
||||
@@ -1,10 +1,8 @@
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.client import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
@@ -18,56 +16,9 @@ from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_tenant_work_gating import cleanup_expired
|
||||
from onyx.redis.redis_tenant_work_gating import get_active_tenants
|
||||
from onyx.redis.redis_tenant_work_gating import observe_active_set_size
|
||||
from onyx.redis.redis_tenant_work_gating import record_full_fanout_cycle
|
||||
from onyx.redis.redis_tenant_work_gating import record_gate_decision
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
|
||||
|
||||
_FULL_FANOUT_TIMESTAMP_KEY_PREFIX = "tenant_work_gating_last_full_fanout_ms"
|
||||
|
||||
|
||||
def _should_bypass_gate_for_full_fanout(
|
||||
redis_client: Redis, task_name: str, interval_seconds: int
|
||||
) -> bool:
|
||||
"""True if at least `interval_seconds` have elapsed since the last
|
||||
full-fanout bypass for this task. On True, updates the stored timestamp
|
||||
atomically-enough (it's a best-effort counter, not a lock)."""
|
||||
key = f"{_FULL_FANOUT_TIMESTAMP_KEY_PREFIX}:{task_name}"
|
||||
now_ms = int(time.time() * 1000)
|
||||
threshold_ms = now_ms - (interval_seconds * 1000)
|
||||
|
||||
try:
|
||||
raw = cast(bytes | None, redis_client.get(key))
|
||||
except Exception:
|
||||
task_logger.exception(f"full-fanout timestamp read failed: task={task_name}")
|
||||
# Fail open: treat as "interval elapsed" so we don't skip every
|
||||
# tenant during a Redis hiccup.
|
||||
return True
|
||||
|
||||
if raw is None:
|
||||
# First invocation — bypass so the set seeds cleanly.
|
||||
elapsed = True
|
||||
else:
|
||||
try:
|
||||
last_ms = int(raw.decode())
|
||||
elapsed = last_ms <= threshold_ms
|
||||
except ValueError:
|
||||
elapsed = True
|
||||
|
||||
if elapsed:
|
||||
try:
|
||||
redis_client.set(key, str(now_ms))
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"full-fanout timestamp write failed: task={task_name}"
|
||||
)
|
||||
return elapsed
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
ignore_result=True,
|
||||
@@ -81,7 +32,6 @@ def cloud_beat_task_generator(
|
||||
priority: int = OnyxCeleryPriority.MEDIUM,
|
||||
expires: int = BEAT_EXPIRES_DEFAULT,
|
||||
skip_gated: bool = True,
|
||||
work_gated: bool = False,
|
||||
) -> bool | None:
|
||||
"""a lightweight task used to kick off individual beat tasks per tenant."""
|
||||
time_start = time.monotonic()
|
||||
@@ -101,56 +51,8 @@ def cloud_beat_task_generator(
|
||||
tenant_ids: list[str] = []
|
||||
num_processed_tenants = 0
|
||||
num_skipped_gated = 0
|
||||
num_would_skip_work_gate = 0
|
||||
num_skipped_work_gate = 0
|
||||
|
||||
# Tenant-work-gating read path. Resolve once per invocation.
|
||||
gate_enabled = False
|
||||
gate_enforce = False
|
||||
full_fanout_cycle = False
|
||||
active_tenants: set[str] | None = None
|
||||
|
||||
try:
|
||||
# Gating setup is inside the try block so any exception still
|
||||
# reaches the finally that releases the beat lock.
|
||||
if work_gated:
|
||||
try:
|
||||
gate_enabled = OnyxRuntime.get_tenant_work_gating_enabled()
|
||||
gate_enforce = OnyxRuntime.get_tenant_work_gating_enforce()
|
||||
except Exception:
|
||||
task_logger.exception("tenant work gating: runtime flag read failed")
|
||||
gate_enabled = False
|
||||
|
||||
if gate_enabled:
|
||||
redis_failed = False
|
||||
interval_s = (
|
||||
OnyxRuntime.get_tenant_work_gating_full_fanout_interval_seconds()
|
||||
)
|
||||
full_fanout_cycle = _should_bypass_gate_for_full_fanout(
|
||||
redis_client, task_name, interval_s
|
||||
)
|
||||
if full_fanout_cycle:
|
||||
record_full_fanout_cycle(task_name)
|
||||
try:
|
||||
ttl_s = OnyxRuntime.get_tenant_work_gating_ttl_seconds()
|
||||
cleanup_expired(ttl_s)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"tenant work gating: cleanup_expired failed"
|
||||
)
|
||||
else:
|
||||
ttl_s = OnyxRuntime.get_tenant_work_gating_ttl_seconds()
|
||||
active_tenants = get_active_tenants(ttl_s)
|
||||
if active_tenants is None:
|
||||
full_fanout_cycle = True
|
||||
record_full_fanout_cycle(task_name)
|
||||
redis_failed = True
|
||||
|
||||
# Only refresh the gauge when Redis is known-reachable —
|
||||
# skip the ZCARD if we just failed open due to a Redis error.
|
||||
if not redis_failed:
|
||||
observe_active_set_size()
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
# Per-task control over whether gated tenants are included. Most periodic tasks
|
||||
@@ -174,21 +76,6 @@ def cloud_beat_task_generator(
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
continue
|
||||
|
||||
# Tenant work gate: if the feature is on, check membership. Skip
|
||||
# unmarked tenants when enforce=True AND we're not in a full-
|
||||
# fanout cycle. Always log/emit the shadow counter.
|
||||
if work_gated and gate_enabled and not full_fanout_cycle:
|
||||
would_skip = (
|
||||
active_tenants is not None and tenant_id not in active_tenants
|
||||
)
|
||||
if would_skip:
|
||||
num_would_skip_work_gate += 1
|
||||
if gate_enforce:
|
||||
num_skipped_work_gate += 1
|
||||
record_gate_decision(task_name, skipped=True)
|
||||
continue
|
||||
record_gate_decision(task_name, skipped=False)
|
||||
|
||||
self.app.send_task(
|
||||
task_name,
|
||||
kwargs=dict(
|
||||
@@ -222,12 +109,6 @@ def cloud_beat_task_generator(
|
||||
f"task={task_name} "
|
||||
f"num_processed_tenants={num_processed_tenants} "
|
||||
f"num_skipped_gated={num_skipped_gated} "
|
||||
f"num_would_skip_work_gate={num_would_skip_work_gate} "
|
||||
f"num_skipped_work_gate={num_skipped_work_gate} "
|
||||
f"full_fanout_cycle={full_fanout_cycle} "
|
||||
f"work_gated={work_gated} "
|
||||
f"gate_enabled={gate_enabled} "
|
||||
f"gate_enforce={gate_enforce} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
|
||||
@@ -212,7 +212,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# Tenant-work-gating hook: refresh this tenant's active-set membership
|
||||
# whenever doc-permission sync has any due cc_pairs to dispatch.
|
||||
if cc_pair_ids_to_sync:
|
||||
maybe_mark_tenant_active(tenant_id, caller="doc_permission_sync")
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
|
||||
@@ -206,7 +206,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# Tenant-work-gating hook: refresh this tenant's active-set membership
|
||||
# whenever external-group sync has any due cc_pairs to dispatch.
|
||||
if cc_pair_ids_to_sync:
|
||||
maybe_mark_tenant_active(tenant_id, caller="external_group_sync")
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
@@ -506,18 +506,6 @@ def _perform_external_group_sync(
|
||||
|
||||
ext_group_sync_func = sync_config.group_sync_config.group_sync_func
|
||||
|
||||
# Clean up stale rows from previous cycle BEFORE marking new ones.
|
||||
# This ensures cleanup always runs regardless of whether the current
|
||||
# sync succeeds — previously, cleanup only ran at the END of the sync,
|
||||
# so if the sync failed (e.g. DB connection killed by
|
||||
# idle_in_transaction_session_timeout during long API calls), stale
|
||||
# rows would accumulate indefinitely.
|
||||
logger.info(
|
||||
f"Removing stale external groups from prior cycle for {source_type} "
|
||||
f"for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
remove_stale_external_groups(db_session, cc_pair_id)
|
||||
|
||||
logger.info(
|
||||
f"Marking old external groups as stale for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ from pydantic import BaseModel
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
@@ -78,15 +77,6 @@ def mark_old_external_groups_as_stale(
|
||||
.where(PublicExternalUserGroup.cc_pair_id == cc_pair_id)
|
||||
.values(stale=True)
|
||||
)
|
||||
# Commit immediately so the transaction closes before potentially long
|
||||
# external API calls (e.g. Google Drive folder iteration). Without this,
|
||||
# the DB connection sits idle-in-transaction during API calls and gets
|
||||
# killed by idle_in_transaction_session_timeout, causing the entire sync
|
||||
# to fail and stale cleanup to never run.
|
||||
db_session.commit()
|
||||
|
||||
|
||||
_UPSERT_BATCH_SIZE = 5000
|
||||
|
||||
|
||||
def upsert_external_groups(
|
||||
@@ -96,102 +86,91 @@ def upsert_external_groups(
|
||||
source: DocumentSource,
|
||||
) -> None:
|
||||
"""
|
||||
Batch upsert external user groups using INSERT ... ON CONFLICT DO UPDATE.
|
||||
- For existing rows (same user_id, external_user_group_id, cc_pair_id),
|
||||
sets stale=False
|
||||
- For new rows, inserts with stale=False
|
||||
- Same logic for PublicExternalUserGroup
|
||||
Performs a true upsert operation for external user groups:
|
||||
- For existing groups (same user_id, external_user_group_id, cc_pair_id), updates the stale flag to False
|
||||
- For new groups, inserts them with stale=False
|
||||
- For public groups, uses upsert logic as well
|
||||
"""
|
||||
# If there are no groups to add, return early
|
||||
if not external_groups:
|
||||
return
|
||||
|
||||
# Collect all emails from all groups to batch-add users at once
|
||||
all_group_member_emails: set[str] = set()
|
||||
# collect all emails from all groups to batch add all users at once for efficiency
|
||||
all_group_member_emails = set()
|
||||
for external_group in external_groups:
|
||||
all_group_member_emails.update(external_group.user_emails)
|
||||
for user_email in external_group.user_emails:
|
||||
all_group_member_emails.add(user_email)
|
||||
|
||||
# Batch add users if they don't exist and get their ids
|
||||
# batch add users if they don't exist and get their ids
|
||||
all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
# NOTE: this function handles case sensitivity for emails
|
||||
emails=list(all_group_member_emails),
|
||||
)
|
||||
|
||||
# map emails to ids
|
||||
email_id_map = {user.email.lower(): user.id for user in all_group_members}
|
||||
|
||||
# Build all user-group mappings and public-group mappings
|
||||
user_group_mappings: list[dict] = []
|
||||
public_group_mappings: list[dict] = []
|
||||
|
||||
# Process each external group
|
||||
for external_group in external_groups:
|
||||
external_group_id = build_ext_group_name_for_onyx(
|
||||
ext_group_name=external_group.id,
|
||||
source=source,
|
||||
)
|
||||
|
||||
# Handle user-group mappings
|
||||
for user_email in external_group.user_emails:
|
||||
user_id = email_id_map.get(user_email.lower())
|
||||
if user_id is None:
|
||||
logger.warning(
|
||||
f"User in group {external_group.id}"
|
||||
f" with email {user_email} not found"
|
||||
f"User in group {external_group.id} with email {user_email} not found"
|
||||
)
|
||||
continue
|
||||
|
||||
user_group_mappings.append(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"external_user_group_id": external_group_id,
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"stale": False,
|
||||
}
|
||||
# Check if the user-group mapping already exists
|
||||
existing_user_group = db_session.scalar(
|
||||
select(User__ExternalUserGroupId).where(
|
||||
User__ExternalUserGroupId.user_id == user_id,
|
||||
User__ExternalUserGroupId.external_user_group_id
|
||||
== external_group_id,
|
||||
User__ExternalUserGroupId.cc_pair_id == cc_pair_id,
|
||||
)
|
||||
)
|
||||
|
||||
if existing_user_group:
|
||||
# Update existing record
|
||||
existing_user_group.stale = False
|
||||
else:
|
||||
# Insert new record
|
||||
new_user_group = User__ExternalUserGroupId(
|
||||
user_id=user_id,
|
||||
external_user_group_id=external_group_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
stale=False,
|
||||
)
|
||||
db_session.add(new_user_group)
|
||||
|
||||
# Handle public group if needed
|
||||
if external_group.gives_anyone_access:
|
||||
public_group_mappings.append(
|
||||
{
|
||||
"external_user_group_id": external_group_id,
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"stale": False,
|
||||
}
|
||||
# Check if the public group already exists
|
||||
existing_public_group = db_session.scalar(
|
||||
select(PublicExternalUserGroup).where(
|
||||
PublicExternalUserGroup.external_user_group_id == external_group_id,
|
||||
PublicExternalUserGroup.cc_pair_id == cc_pair_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Deduplicate to avoid "ON CONFLICT DO UPDATE command cannot affect row
|
||||
# a second time" when duplicate emails or overlapping groups produce
|
||||
# identical (user_id, external_user_group_id, cc_pair_id) tuples.
|
||||
user_group_mappings_deduped = list(
|
||||
{
|
||||
(m["user_id"], m["external_user_group_id"], m["cc_pair_id"]): m
|
||||
for m in user_group_mappings
|
||||
}.values()
|
||||
)
|
||||
|
||||
# Batch upsert user-group mappings
|
||||
for i in range(0, len(user_group_mappings_deduped), _UPSERT_BATCH_SIZE):
|
||||
chunk = user_group_mappings_deduped[i : i + _UPSERT_BATCH_SIZE]
|
||||
stmt = pg_insert(User__ExternalUserGroupId).values(chunk)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["user_id", "external_user_group_id", "cc_pair_id"],
|
||||
set_={"stale": False},
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
|
||||
# Deduplicate public group mappings as well
|
||||
public_group_mappings_deduped = list(
|
||||
{
|
||||
(m["external_user_group_id"], m["cc_pair_id"]): m
|
||||
for m in public_group_mappings
|
||||
}.values()
|
||||
)
|
||||
|
||||
# Batch upsert public group mappings
|
||||
for i in range(0, len(public_group_mappings_deduped), _UPSERT_BATCH_SIZE):
|
||||
chunk = public_group_mappings_deduped[i : i + _UPSERT_BATCH_SIZE]
|
||||
stmt = pg_insert(PublicExternalUserGroup).values(chunk)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["external_user_group_id", "cc_pair_id"],
|
||||
set_={"stale": False},
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
if existing_public_group:
|
||||
# Update existing record
|
||||
existing_public_group.stale = False
|
||||
else:
|
||||
# Insert new record
|
||||
new_public_group = PublicExternalUserGroup(
|
||||
external_user_group_id=external_group_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
stale=False,
|
||||
)
|
||||
db_session.add(new_public_group)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ from shared_configs.configs import MIN_THREADS_ML_MODELS
|
||||
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
from shared_configs.configs import SENTRY_TRACES_SAMPLE_RATE
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
@@ -102,7 +101,7 @@ def get_model_app() -> FastAPI:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=SENTRY_TRACES_SAMPLE_RATE,
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
|
||||
@@ -76,34 +76,24 @@ async def verify_captcha_token(
|
||||
f"Captcha verification failed: {', '.join(error_codes)}"
|
||||
)
|
||||
|
||||
# Require v3 score. Google's public test secret returns no score
|
||||
# — that path must not be active in prod since it skips the only
|
||||
# human-vs-bot signal. A missing score here means captcha is
|
||||
# misconfigured (test secret in prod, or a v2 response slipped in
|
||||
# via an action mismatch).
|
||||
if result.score is None:
|
||||
logger.warning(
|
||||
"Captcha verification failed: siteverify returned no score (likely test secret in prod)"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: missing score"
|
||||
)
|
||||
# For reCAPTCHA v3, also check the score
|
||||
if result.score is not None:
|
||||
if result.score < RECAPTCHA_SCORE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: suspicious activity detected"
|
||||
)
|
||||
|
||||
if result.score < RECAPTCHA_SCORE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: suspicious activity detected"
|
||||
)
|
||||
|
||||
if result.action and result.action != expected_action:
|
||||
logger.warning(
|
||||
f"Captcha action mismatch: {result.action} != {expected_action}"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: action mismatch"
|
||||
)
|
||||
# Optionally verify the action matches
|
||||
if result.action and result.action != expected_action:
|
||||
logger.warning(
|
||||
f"Captcha action mismatch: {result.action} != {expected_action}"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: action mismatch"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Captcha verification passed: score={result.score}, action={result.action}"
|
||||
|
||||
@@ -30,7 +30,6 @@ from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFI
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
@@ -55,7 +54,6 @@ from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import DEV_LOGGING_ENABLED
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import SENTRY_CELERY_TRACES_SAMPLE_RATE
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
@@ -70,7 +68,7 @@ if SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=SENTRY_CELERY_TRACES_SAMPLE_RATE,
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
@@ -533,26 +531,23 @@ def reset_tenant_id(
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
|
||||
def wait_for_document_index_or_shutdown() -> None:
|
||||
"""
|
||||
Waits for all configured document indices to become ready subject to a
|
||||
timeout.
|
||||
def wait_for_vespa_or_shutdown(
|
||||
sender: Any, # noqa: ARG001
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None: # noqa: ARG001
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
Raises WorkerShutdown if the timeout is reached.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
logger.info(
|
||||
"DISABLE_VECTOR_DB is set — skipping Vespa/OpenSearch readiness check."
|
||||
)
|
||||
return
|
||||
|
||||
if not ONYX_DISABLE_VESPA:
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = (
|
||||
"[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
if not wait_for_opensearch_with_timeout():
|
||||
|
||||
@@ -105,7 +105,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -111,7 +111,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -97,7 +97,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -118,7 +118,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -124,7 +124,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
logger.info(f"Running as the primary celery worker: pid={os.getpid()}")
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -10,7 +10,6 @@ from onyx.configs.app_configs import DISABLE_OPENSEARCH_MIGRATION_TASK
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
@@ -68,7 +67,6 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -102,7 +100,6 @@ beat_task_templates: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
# Gated tenants may still have connectors awaiting deletion.
|
||||
"skip_gated": False,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -112,7 +109,6 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -122,7 +118,6 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -160,7 +155,6 @@ beat_task_templates: list[dict] = [
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.SANDBOX,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -185,7 +179,6 @@ if ENTERPRISE_EDITION_ENABLED:
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -195,7 +188,6 @@ if ENTERPRISE_EDITION_ENABLED:
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
]
|
||||
@@ -235,11 +227,7 @@ if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
)
|
||||
|
||||
# Add OpenSearch migration task if enabled.
|
||||
if (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and not DISABLE_OPENSEARCH_MIGRATION_TASK
|
||||
and not ONYX_DISABLE_VESPA
|
||||
):
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX and not DISABLE_OPENSEARCH_MIGRATION_TASK:
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "migrate-chunks-from-vespa-to-opensearch",
|
||||
@@ -292,7 +280,7 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
cloud_task["kwargs"] = {}
|
||||
cloud_task["kwargs"]["task_name"] = task["task"]
|
||||
|
||||
optional_fields = ["queue", "priority", "expires", "skip_gated", "work_gated"]
|
||||
optional_fields = ["queue", "priority", "expires", "skip_gated"]
|
||||
for field in optional_fields:
|
||||
if field in task["options"]:
|
||||
cloud_task["kwargs"][field] = task["options"][field]
|
||||
@@ -385,14 +373,12 @@ if not MULTI_TENANT:
|
||||
]
|
||||
)
|
||||
|
||||
# `skip_gated` and `work_gated` are cloud-only hints consumed by
|
||||
# `cloud_beat_task_generator`. Strip them before extending the self-hosted
|
||||
# schedule so they don't leak into apply_async as unrecognised options on
|
||||
# every fired task message.
|
||||
# `skip_gated` is a cloud-only hint consumed by `cloud_beat_task_generator`. Strip
|
||||
# it before extending the self-hosted schedule so it doesn't leak into apply_async
|
||||
# as an unrecognised option on every fired task message.
|
||||
for _template in beat_task_templates:
|
||||
_self_hosted_template = copy.deepcopy(_template)
|
||||
_self_hosted_template["options"].pop("skip_gated", None)
|
||||
_self_hosted_template["options"].pop("work_gated", None)
|
||||
tasks_to_schedule.append(_self_hosted_template)
|
||||
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
|
||||
# nearly every tenant in the active set since most have cc_pairs
|
||||
# but almost none are actively being deleted on any given cycle.
|
||||
if has_deleting_cc_pair:
|
||||
maybe_mark_tenant_active(tenant_id, caller="connector_deletion")
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
# try running cleanup on the cc_pair_ids
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
|
||||
@@ -37,7 +37,6 @@ from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import SENTRY_CELERY_TRACES_SAMPLE_RATE
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -141,7 +140,7 @@ def _docfetching_task(
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=SENTRY_CELERY_TRACES_SAMPLE_RATE,
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
|
||||
@@ -1020,7 +1020,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
# `tasks_created > 0` here gives us a "real work was done" signal
|
||||
# rather than just "tenant has a cc_pair somewhere."
|
||||
if tasks_created > 0:
|
||||
maybe_mark_tenant_active(tenant_id, caller="check_for_indexing")
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
# 2/3: VALIDATE
|
||||
# Check for inconsistent index attempts - active attempts without task IDs
|
||||
|
||||
@@ -263,7 +263,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# since most tenants have cc_pairs but almost none are due on
|
||||
# any given cycle.
|
||||
if prune_dispatched:
|
||||
maybe_mark_tenant_active(tenant_id, caller="check_for_pruning")
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
r.set(OnyxRedisSignals.BLOCK_PRUNING, 1, ex=_get_pruning_block_expiration())
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
|
||||
@@ -153,7 +153,7 @@ def try_generate_stale_document_sync_tasks(
|
||||
|
||||
# Tenant-work-gating hook: refresh this tenant's active-set membership
|
||||
# whenever vespa sync actually has stale docs to dispatch.
|
||||
maybe_mark_tenant_active(tenant_id, caller="vespa_sync")
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
logger.info(
|
||||
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks in one batch."
|
||||
|
||||
@@ -58,8 +58,6 @@ from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.file_store.staging import build_raw_file_callback
|
||||
from onyx.file_store.staging import RawFileCallback
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
@@ -92,7 +90,6 @@ def _get_connector_runner(
|
||||
end_time: datetime,
|
||||
include_permissions: bool,
|
||||
leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE,
|
||||
raw_file_callback: RawFileCallback | None = None,
|
||||
) -> ConnectorRunner:
|
||||
"""
|
||||
NOTE: `start_time` and `end_time` are only used for poll connectors
|
||||
@@ -111,7 +108,6 @@ def _get_connector_runner(
|
||||
input_type=task,
|
||||
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
credential=attempt.connector_credential_pair.credential,
|
||||
raw_file_callback=raw_file_callback,
|
||||
)
|
||||
|
||||
# validate the connector settings
|
||||
@@ -279,12 +275,6 @@ def run_docfetching_entrypoint(
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
|
||||
raw_file_callback = build_raw_file_callback(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=connector_credential_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
connector_document_extraction(
|
||||
app,
|
||||
index_attempt_id,
|
||||
@@ -292,7 +282,6 @@ def run_docfetching_entrypoint(
|
||||
attempt.search_settings_id,
|
||||
tenant_id,
|
||||
callback,
|
||||
raw_file_callback=raw_file_callback,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -312,7 +301,6 @@ def connector_document_extraction(
|
||||
search_settings_id: int,
|
||||
tenant_id: str,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
raw_file_callback: RawFileCallback | None = None,
|
||||
) -> None:
|
||||
"""Extract documents from connector and queue them for indexing pipeline processing.
|
||||
|
||||
@@ -463,7 +451,6 @@ def connector_document_extraction(
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
include_permissions=should_fetch_permissions_during_indexing,
|
||||
raw_file_callback=raw_file_callback,
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
|
||||
@@ -60,9 +60,7 @@ from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
@@ -76,17 +74,12 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
@@ -100,6 +93,7 @@ from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.multi_llm import LLMTimeoutError
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
@@ -129,7 +123,6 @@ from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -252,106 +245,22 @@ def _empty_extracted_context_files() -> ExtractedContextFiles:
|
||||
)
|
||||
|
||||
|
||||
def _fetch_cached_image_captions(
|
||||
user_file: UserFile | None,
|
||||
document_index: DocumentIndex | None,
|
||||
) -> list[str]:
|
||||
"""Read image-caption chunks for a user file from the document index.
|
||||
|
||||
During indexing, embedded images are summarized via a vision LLM and
|
||||
those summaries are stored as chunks whose `image_file_id` is set. Reading
|
||||
them back at chat time avoids re-running vision-LLM calls per turn.
|
||||
Returns an empty list if the index has no chunks yet (e.g. indexing is
|
||||
still in flight) or on any fetch failure.
|
||||
"""
|
||||
if user_file is None or document_index is None:
|
||||
return []
|
||||
try:
|
||||
chunks = document_index.id_based_retrieval(
|
||||
chunk_requests=[VespaChunkRequest(document_id=str(user_file.id))],
|
||||
filters=IndexFilters(
|
||||
access_control_list=None,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to fetch cached captions for user_file {user_file.id}",
|
||||
exc_info=True,
|
||||
)
|
||||
return []
|
||||
|
||||
# An image can be spread across multiple chunks; combine by image_file_id
|
||||
# so a single caption appears once in the context.
|
||||
combined: dict[str, list[str]] = {}
|
||||
for chunk in chunks:
|
||||
if chunk.image_file_id and chunk.content:
|
||||
combined.setdefault(chunk.image_file_id, []).append(chunk.content)
|
||||
return [
|
||||
f"[Image — {image_file_id}]\n" + "\n".join(contents)
|
||||
for image_file_id, contents in combined.items()
|
||||
]
|
||||
|
||||
|
||||
def _extract_text_from_in_memory_file(
|
||||
f: InMemoryChatFile,
|
||||
user_file: UserFile | None = None,
|
||||
document_index: DocumentIndex | None = None,
|
||||
) -> str | None:
|
||||
def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None:
|
||||
"""Extract text content from an InMemoryChatFile.
|
||||
|
||||
PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during
|
||||
ingestion — decode directly.
|
||||
DOC / CSV / other text types: the content is the original file bytes —
|
||||
use extract_file_text which handles encoding detection and format parsing.
|
||||
When image extraction is enabled and the file has embedded images, cached
|
||||
captions are pulled from the document index and appended to the text.
|
||||
The index fetch is skipped for files with no embedded images. We do not
|
||||
re-summarize images inline here — this path is hot and the indexing
|
||||
pipeline writes chunks atomically, so a missed caption means the file
|
||||
is mid-indexing and will be picked up on the next turn.
|
||||
"""
|
||||
try:
|
||||
if f.file_type == ChatFileType.PLAIN_TEXT:
|
||||
return f.content.decode("utf-8", errors="ignore").replace("\x00", "")
|
||||
|
||||
filename = f.filename or ""
|
||||
if not get_image_extraction_and_analysis_enabled():
|
||||
return extract_file_text(
|
||||
file=io.BytesIO(f.content),
|
||||
file_name=filename,
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
|
||||
extraction = extract_text_and_images(
|
||||
return extract_file_text(
|
||||
file=io.BytesIO(f.content),
|
||||
file_name=filename,
|
||||
file_name=f.filename or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
text = extraction.text_content
|
||||
has_text = bool(text.strip())
|
||||
has_images = bool(extraction.embedded_images)
|
||||
|
||||
if not has_text and not has_images:
|
||||
# extract_text_and_images has no is_text_file() fallback for
|
||||
# unknown extensions (.py/.rs/.md without a dedicated handler).
|
||||
# Defer to the legacy path so those files remain readable.
|
||||
return extract_file_text(
|
||||
file=io.BytesIO(f.content),
|
||||
file_name=filename,
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
|
||||
if not has_images:
|
||||
return text if has_text else None
|
||||
|
||||
cached_captions = _fetch_cached_image_captions(user_file, document_index)
|
||||
|
||||
parts: list[str] = []
|
||||
if has_text:
|
||||
parts.append(text)
|
||||
parts.extend(cached_captions)
|
||||
|
||||
return "\n\n".join(parts).strip() or None
|
||||
except Exception:
|
||||
logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True)
|
||||
return None
|
||||
@@ -433,23 +342,6 @@ def extract_context_files(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# The document index is used at chat time to read cached image captions
|
||||
# (produced during indexing) so vision-LLM calls don't re-run per turn.
|
||||
document_index: DocumentIndex | None = None
|
||||
if not DISABLE_VECTOR_DB and get_image_extraction_and_analysis_enabled():
|
||||
try:
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to construct document index for caption lookup",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
file_texts: list[str] = []
|
||||
image_files: list[ChatLoadedFile] = []
|
||||
file_metadata: list[ContextFileMetadata] = []
|
||||
@@ -470,9 +362,7 @@ def extract_context_files(
|
||||
continue
|
||||
tool_metadata.append(_build_tool_metadata(uf))
|
||||
elif f.file_type.is_text_file():
|
||||
text_content = _extract_text_from_in_memory_file(
|
||||
f, user_file=uf, document_index=document_index
|
||||
)
|
||||
text_content = _extract_text_from_in_memory_file(f)
|
||||
if not text_content:
|
||||
continue
|
||||
if not uf:
|
||||
@@ -1277,6 +1167,32 @@ def _run_models(
|
||||
else:
|
||||
if item is _MODEL_DONE:
|
||||
models_remaining -= 1
|
||||
elif isinstance(item, LLMTimeoutError):
|
||||
model_llm = setup.llms[model_idx]
|
||||
error_msg = (
|
||||
"The LLM took too long to respond. "
|
||||
"If you're running a local model, try increasing the "
|
||||
"LLM_SOCKET_READ_TIMEOUT environment variable "
|
||||
"(current default: 120 seconds)."
|
||||
)
|
||||
stack_trace = "".join(
|
||||
traceback.format_exception(type(item), item, item.__traceback__)
|
||||
)
|
||||
if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
|
||||
stack_trace = stack_trace.replace(
|
||||
model_llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code="CONNECTION_ERROR",
|
||||
is_retryable=True,
|
||||
details={
|
||||
"model": model_llm.config.model_name,
|
||||
"provider": model_llm.config.model_provider,
|
||||
"model_index": model_idx,
|
||||
},
|
||||
)
|
||||
elif isinstance(item, Exception):
|
||||
# Yield a tagged error for this model but keep the other models running.
|
||||
# Do NOT decrement models_remaining — _run_model's finally always posts
|
||||
|
||||
@@ -282,7 +282,6 @@ OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
|
||||
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
|
||||
)
|
||||
OPENSEARCH_USE_SSL = os.environ.get("OPENSEARCH_USE_SSL", "true").lower() == "true"
|
||||
USING_AWS_MANAGED_OPENSEARCH = (
|
||||
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
|
||||
)
|
||||
@@ -328,7 +327,6 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
DISABLE_OPENSEARCH_MIGRATION_TASK = (
|
||||
os.environ.get("DISABLE_OPENSEARCH_MIGRATION_TASK", "").lower() == "true"
|
||||
)
|
||||
ONYX_DISABLE_VESPA = os.environ.get("ONYX_DISABLE_VESPA", "").lower() == "true"
|
||||
# Whether we should check for and create an index if necessary every time we
|
||||
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
@@ -868,15 +866,6 @@ MAX_EMBEDDED_IMAGES_PER_UPLOAD = max(
|
||||
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_UPLOAD") or 1000)
|
||||
)
|
||||
|
||||
# Maximum non-empty cells to extract from a single xlsx worksheet. Protects
|
||||
# from OOM on honestly-huge spreadsheets: memory cost in the extractor is
|
||||
# roughly proportional to this count. Once exceeded, the scan stops and a
|
||||
# truncation marker row is appended to the sheet's CSV.
|
||||
# Peak Memory ~= 100 B * MAX_CELLS
|
||||
MAX_XLSX_CELLS_PER_SHEET = max(
|
||||
0, int(os.environ.get("MAX_XLSX_CELLS_PER_SHEET") or 10_000_000)
|
||||
)
|
||||
|
||||
# Use document summary for contextual rag
|
||||
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
|
||||
# Use chunk summary for contextual rag
|
||||
|
||||
@@ -372,7 +372,6 @@ class FileOrigin(str, Enum):
|
||||
CONNECTOR_METADATA = "connector_metadata"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
INDEXING_CHECKPOINT = "indexing_checkpoint"
|
||||
INDEXING_STAGING = "indexing_staging"
|
||||
PLAINTEXT_CACHE = "plaintext_cache"
|
||||
OTHER = "other"
|
||||
QUERY_HISTORY_CSV = "query_history_csv"
|
||||
|
||||
@@ -3,14 +3,7 @@ from onyx.server.settings.store import load_settings
|
||||
|
||||
|
||||
def get_image_extraction_and_analysis_enabled() -> bool:
|
||||
"""Return the workspace setting for image extraction/analysis.
|
||||
|
||||
The pydantic `Settings` model defaults this field to True, so production
|
||||
tenants get the feature on by default on first read. The fallback here
|
||||
stays False so environments where settings cannot be loaded at all
|
||||
(e.g. unit tests with no DB/Redis) don't trigger downstream vision-LLM
|
||||
code paths that assume the DB is reachable.
|
||||
"""
|
||||
"""Get image extraction and analysis enabled setting from workspace settings or fallback to False"""
|
||||
try:
|
||||
settings = load_settings()
|
||||
if settings.image_extraction_and_analysis_enabled is not None:
|
||||
|
||||
@@ -22,7 +22,6 @@ from onyx.db.credentials import backend_update_credential_json
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.models import Credential
|
||||
from onyx.file_store.staging import RawFileCallback
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
@@ -108,7 +107,6 @@ def instantiate_connector(
|
||||
input_type: InputType,
|
||||
connector_specific_config: dict[str, Any],
|
||||
credential: Credential,
|
||||
raw_file_callback: RawFileCallback | None = None,
|
||||
) -> BaseConnector:
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
|
||||
@@ -132,9 +130,6 @@ def instantiate_connector(
|
||||
|
||||
connector.set_allow_images(get_image_extraction_and_analysis_enabled())
|
||||
|
||||
if raw_file_callback is not None:
|
||||
connector.set_raw_file_callback(raw_file_callback)
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
|
||||
@@ -40,22 +40,6 @@ class GongConnectorCheckpoint(ConnectorCheckpoint):
|
||||
cursor: str | None = None
|
||||
# Cached time range — computed once, reused across checkpoint calls
|
||||
time_range: tuple[str, str] | None = None
|
||||
# Transcripts whose call details were not yet available from /v2/calls/extensive
|
||||
# (Gong has a known race where transcript call IDs take time to propagate).
|
||||
# Keyed by call_id. Retried on subsequent checkpoint invocations.
|
||||
#
|
||||
# Invariant: all entries share one resolution session — they're stashed
|
||||
# together from a single page and share the attempt counter and retry
|
||||
# deadline. load_from_checkpoint only fetches a new page when this dict
|
||||
# is empty, so entries from different pages can't mix.
|
||||
pending_transcripts: dict[str, dict[str, Any]] = {}
|
||||
# Number of resolution attempts made for pending_transcripts so far.
|
||||
pending_call_details_attempts: int = 0
|
||||
# Unix timestamp before which we should not retry pending_transcripts.
|
||||
# Enforces exponential backoff independent of worker cadence — Gong's
|
||||
# transcript-ID propagation race can take tens of seconds to minutes,
|
||||
# longer than typical worker reinvocation intervals.
|
||||
pending_retry_after: float | None = None
|
||||
|
||||
|
||||
class _TranscriptPage(BaseModel):
|
||||
@@ -78,15 +62,8 @@ class _CursorExpiredError(Exception):
|
||||
|
||||
class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
# Max number of attempts to resolve missing call details across checkpoint
|
||||
# invocations before giving up and emitting ConnectorFailure.
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
# Base delay for exponential backoff between pending-transcript retry
|
||||
# attempts. Delay before attempt N (N >= 2) is CALL_DETAILS_DELAY * 2^(N-2)
|
||||
# seconds (30, 60, 120, 240, 480 = ~15.5min total) — matching the original
|
||||
# blocking-retry schedule, but enforced via checkpoint deadline rather
|
||||
# than in-call time.sleep.
|
||||
CALL_DETAILS_DELAY = 30
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
@@ -210,6 +187,50 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
|
||||
|
||||
return call_to_metadata
|
||||
|
||||
def _fetch_call_details_with_retry(self, call_ids: list[str]) -> dict[str, Any]:
|
||||
"""Fetch call details with retry for the Gong API race condition.
|
||||
|
||||
The Gong API has a known race where transcript call IDs don't immediately
|
||||
appear in /v2/calls/extensive. Retries with exponential backoff, only
|
||||
re-requesting the missing IDs on each attempt.
|
||||
"""
|
||||
call_details_map = self._get_call_details_by_ids(call_ids)
|
||||
if set(call_ids) == set(call_details_map.keys()):
|
||||
return call_details_map
|
||||
|
||||
for attempt in range(2, self.MAX_CALL_DETAILS_ATTEMPTS + 1):
|
||||
missing_ids = list(set(call_ids) - set(call_details_map.keys()))
|
||||
logger.warning(
|
||||
f"_get_call_details_by_ids is missing call id's: current_attempt={attempt - 1} missing_call_ids={missing_ids}"
|
||||
)
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, attempt - 2)
|
||||
logger.warning(
|
||||
f"_get_call_details_by_ids waiting to retry: "
|
||||
f"wait={wait_seconds}s "
|
||||
f"current_attempt={attempt - 1} "
|
||||
f"next_attempt={attempt} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
)
|
||||
time.sleep(wait_seconds)
|
||||
|
||||
# Only re-fetch the missing IDs, merge into existing results
|
||||
new_details = self._get_call_details_by_ids(missing_ids)
|
||||
call_details_map.update(new_details)
|
||||
|
||||
if set(call_ids) == set(call_details_map.keys()):
|
||||
return call_details_map
|
||||
|
||||
missing_ids = list(set(call_ids) - set(call_details_map.keys()))
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(call_ids)} calls"
|
||||
)
|
||||
return call_details_map
|
||||
|
||||
@staticmethod
|
||||
def _parse_parties(parties: list[dict]) -> dict[str, str]:
|
||||
id_mapping = {}
|
||||
@@ -292,119 +313,87 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
|
||||
|
||||
return start_time, end_time
|
||||
|
||||
def _build_document(
|
||||
self,
|
||||
transcript: dict[str, Any],
|
||||
call_details: dict[str, Any],
|
||||
) -> Document:
|
||||
"""Build a single Document from a transcript and its resolved call details."""
|
||||
call_id = transcript["callId"]
|
||||
call_metadata = call_details["metaData"]
|
||||
|
||||
call_time_str = call_metadata["started"]
|
||||
call_title = call_metadata["title"]
|
||||
logger.info(
|
||||
f"Indexing Gong call id {call_id} from {call_time_str.split('T', 1)[0]}: {call_title}"
|
||||
)
|
||||
|
||||
call_parties = cast(list[dict] | None, call_details.get("parties"))
|
||||
if call_parties is None:
|
||||
logger.error(f"Couldn't get parties for Call ID: {call_id}")
|
||||
call_parties = []
|
||||
|
||||
id_to_name_map = self._parse_parties(call_parties)
|
||||
|
||||
speaker_to_name: dict[str, str] = {}
|
||||
|
||||
transcript_text = ""
|
||||
call_purpose = call_metadata["purpose"]
|
||||
if call_purpose:
|
||||
transcript_text += f"Call Description: {call_purpose}\n\n"
|
||||
|
||||
contents = transcript["transcript"]
|
||||
for segment in contents:
|
||||
speaker_id = segment.get("speakerId", "")
|
||||
if speaker_id not in speaker_to_name:
|
||||
if self.hide_user_info:
|
||||
speaker_to_name[speaker_id] = f"User {len(speaker_to_name) + 1}"
|
||||
else:
|
||||
speaker_to_name[speaker_id] = id_to_name_map.get(
|
||||
speaker_id, "Unknown"
|
||||
)
|
||||
|
||||
speaker_name = speaker_to_name[speaker_id]
|
||||
|
||||
sentences = segment.get("sentences", {})
|
||||
monolog = " ".join([sentence.get("text", "") for sentence in sentences])
|
||||
transcript_text += f"{speaker_name}: {monolog}\n\n"
|
||||
|
||||
return Document(
|
||||
id=call_id,
|
||||
sections=[TextSection(link=call_metadata["url"], text=transcript_text)],
|
||||
source=DocumentSource.GONG,
|
||||
semantic_identifier=call_title or "Untitled",
|
||||
doc_updated_at=datetime.fromisoformat(call_time_str).astimezone(
|
||||
timezone.utc
|
||||
),
|
||||
metadata={"client": call_metadata.get("system")},
|
||||
)
|
||||
|
||||
def _process_transcripts(
|
||||
self,
|
||||
transcripts: list[dict[str, Any]],
|
||||
checkpoint: GongConnectorCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, None]:
|
||||
"""Fetch call details for a page of transcripts and yield resulting
|
||||
Documents. Transcripts whose call details are missing (Gong race
|
||||
condition) are stashed into `checkpoint.pending_transcripts` for retry
|
||||
on a future checkpoint invocation rather than blocking here.
|
||||
"""
|
||||
"""Process a batch of transcripts into Documents or ConnectorFailures."""
|
||||
transcript_call_ids = cast(
|
||||
list[str],
|
||||
[t.get("callId") for t in transcripts if t.get("callId")],
|
||||
)
|
||||
|
||||
call_details_map = (
|
||||
self._get_call_details_by_ids(transcript_call_ids)
|
||||
if transcript_call_ids
|
||||
else {}
|
||||
)
|
||||
|
||||
newly_stashed: list[str] = []
|
||||
call_details_map = self._fetch_call_details_with_retry(transcript_call_ids)
|
||||
|
||||
for transcript in transcripts:
|
||||
call_id = transcript.get("callId")
|
||||
|
||||
if not call_id:
|
||||
logger.error(
|
||||
"Couldn't get call information for transcript missing callId"
|
||||
)
|
||||
if not call_id or call_id not in call_details_map:
|
||||
logger.error(f"Couldn't get call information for Call ID: {call_id}")
|
||||
if call_id:
|
||||
logger.error(
|
||||
f"Call debug info: call_id={call_id} "
|
||||
f"call_ids={transcript_call_ids} "
|
||||
f"call_details_map={call_details_map.keys()}"
|
||||
)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(document_id="unknown"),
|
||||
failure_message="Transcript missing callId",
|
||||
failed_document=DocumentFailure(
|
||||
document_id=call_id or "unknown",
|
||||
),
|
||||
failure_message=f"Couldn't get call information for Call ID: {call_id}",
|
||||
)
|
||||
continue
|
||||
|
||||
if call_id in call_details_map:
|
||||
yield self._build_document(transcript, call_details_map[call_id])
|
||||
continue
|
||||
call_details = call_details_map[call_id]
|
||||
call_metadata = call_details["metaData"]
|
||||
|
||||
# Details not available yet — stash for retry on next invocation.
|
||||
checkpoint.pending_transcripts[call_id] = transcript
|
||||
newly_stashed.append(call_id)
|
||||
|
||||
if newly_stashed:
|
||||
logger.warning(
|
||||
f"Gong call details not yet available (race condition); "
|
||||
f"deferring to next checkpoint invocation: "
|
||||
f"call_ids={newly_stashed}"
|
||||
call_time_str = call_metadata["started"]
|
||||
call_title = call_metadata["title"]
|
||||
logger.info(
|
||||
f"Indexing Gong call id {call_id} from {call_time_str.split('T', 1)[0]}: {call_title}"
|
||||
)
|
||||
|
||||
call_parties = cast(list[dict] | None, call_details.get("parties"))
|
||||
if call_parties is None:
|
||||
logger.error(f"Couldn't get parties for Call ID: {call_id}")
|
||||
call_parties = []
|
||||
|
||||
id_to_name_map = self._parse_parties(call_parties)
|
||||
|
||||
speaker_to_name: dict[str, str] = {}
|
||||
|
||||
transcript_text = ""
|
||||
call_purpose = call_metadata["purpose"]
|
||||
if call_purpose:
|
||||
transcript_text += f"Call Description: {call_purpose}\n\n"
|
||||
|
||||
contents = transcript["transcript"]
|
||||
for segment in contents:
|
||||
speaker_id = segment.get("speakerId", "")
|
||||
if speaker_id not in speaker_to_name:
|
||||
if self.hide_user_info:
|
||||
speaker_to_name[speaker_id] = f"User {len(speaker_to_name) + 1}"
|
||||
else:
|
||||
speaker_to_name[speaker_id] = id_to_name_map.get(
|
||||
speaker_id, "Unknown"
|
||||
)
|
||||
|
||||
speaker_name = speaker_to_name[speaker_id]
|
||||
|
||||
sentences = segment.get("sentences", {})
|
||||
monolog = " ".join([sentence.get("text", "") for sentence in sentences])
|
||||
transcript_text += f"{speaker_name}: {monolog}\n\n"
|
||||
|
||||
yield Document(
|
||||
id=call_id,
|
||||
sections=[TextSection(link=call_metadata["url"], text=transcript_text)],
|
||||
source=DocumentSource.GONG,
|
||||
semantic_identifier=call_title or "Untitled",
|
||||
doc_updated_at=datetime.fromisoformat(call_time_str).astimezone(
|
||||
timezone.utc
|
||||
),
|
||||
metadata={"client": call_metadata.get("system")},
|
||||
)
|
||||
# First attempt on any newly-stashed transcripts counts as attempt #1.
|
||||
# pending_call_details_attempts is guaranteed 0 here because
|
||||
# load_from_checkpoint only reaches _process_transcripts when
|
||||
# pending_transcripts was empty at entry (see early-return above).
|
||||
checkpoint.pending_call_details_attempts = 1
|
||||
checkpoint.pending_retry_after = time.time() + self._next_retry_delay(1)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
combined = (
|
||||
@@ -443,18 +432,6 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
|
||||
checkpoint.has_more = True
|
||||
return checkpoint
|
||||
|
||||
# Step 2: Resolve any transcripts stashed by a prior invocation whose
|
||||
# call details were missing due to Gong's propagation race. Worker
|
||||
# cadence between checkpoint calls provides the spacing between retry
|
||||
# attempts — no in-call sleep needed.
|
||||
if checkpoint.pending_transcripts:
|
||||
yield from self._resolve_pending_transcripts(checkpoint)
|
||||
# If pending still exists and we haven't exhausted attempts, defer
|
||||
# the rest of this invocation — _resolve_pending_transcripts set
|
||||
# has_more=True for us.
|
||||
if checkpoint.pending_transcripts:
|
||||
return checkpoint
|
||||
|
||||
workspace_ids = checkpoint.workspace_ids
|
||||
|
||||
# If we've exhausted all workspaces, we're done
|
||||
@@ -473,7 +450,7 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
|
||||
|
||||
workspace_id = workspace_ids[checkpoint.workspace_index]
|
||||
|
||||
# Step 3: Fetch one page of transcripts
|
||||
# Step 2: Fetch one page of transcripts
|
||||
try:
|
||||
page = self._fetch_transcript_page(
|
||||
start_datetime=start_time,
|
||||
@@ -496,102 +473,23 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
|
||||
checkpoint.has_more = True
|
||||
return checkpoint
|
||||
|
||||
# Step 4: Process transcripts into documents. Missing-details
|
||||
# transcripts get stashed into checkpoint.pending_transcripts.
|
||||
# Step 3: Process transcripts into documents
|
||||
if page.transcripts:
|
||||
yield from self._process_transcripts(page.transcripts, checkpoint)
|
||||
yield from self._process_transcripts(page.transcripts)
|
||||
|
||||
# Step 5: Update cursor/workspace state
|
||||
# Step 4: Update checkpoint state
|
||||
if page.next_cursor:
|
||||
# More pages in this workspace
|
||||
checkpoint.cursor = page.next_cursor
|
||||
checkpoint.has_more = True
|
||||
else:
|
||||
# This workspace is exhausted — advance to next
|
||||
checkpoint.workspace_index += 1
|
||||
checkpoint.cursor = None
|
||||
checkpoint.has_more = checkpoint.workspace_index < len(workspace_ids)
|
||||
|
||||
# If pending transcripts were stashed this invocation, we still have
|
||||
# work to do on a future invocation even if pagination is exhausted.
|
||||
if checkpoint.pending_transcripts:
|
||||
checkpoint.has_more = True
|
||||
|
||||
return checkpoint
|
||||
|
||||
def _next_retry_delay(self, attempts_done: int) -> float:
|
||||
"""Seconds to wait before attempt #(attempts_done + 1).
|
||||
Matches the original exponential backoff: 30, 60, 120, 240, 480.
|
||||
"""
|
||||
return self.CALL_DETAILS_DELAY * pow(2, attempts_done - 1)
|
||||
|
||||
def _resolve_pending_transcripts(
|
||||
self,
|
||||
checkpoint: GongConnectorCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, None]:
|
||||
"""Attempt to resolve transcripts whose call details were unavailable
|
||||
in a prior invocation. Mutates checkpoint in place: resolved transcripts
|
||||
are removed from pending_transcripts; on attempt exhaustion, emits
|
||||
ConnectorFailure for each unresolved call_id and clears pending state.
|
||||
|
||||
If the backoff deadline hasn't elapsed yet, returns without issuing
|
||||
any API call so the next invocation can try again later.
|
||||
"""
|
||||
if (
|
||||
checkpoint.pending_retry_after is not None
|
||||
and time.time() < checkpoint.pending_retry_after
|
||||
):
|
||||
# Backoff still in effect — defer to a later invocation without
|
||||
# burning an attempt or an API call.
|
||||
checkpoint.has_more = True
|
||||
return
|
||||
|
||||
pending_call_ids = list(checkpoint.pending_transcripts.keys())
|
||||
resolved = self._get_call_details_by_ids(pending_call_ids)
|
||||
|
||||
for call_id, details in resolved.items():
|
||||
transcript = checkpoint.pending_transcripts.pop(call_id, None)
|
||||
if transcript is None:
|
||||
continue
|
||||
yield self._build_document(transcript, details)
|
||||
|
||||
if not checkpoint.pending_transcripts:
|
||||
checkpoint.pending_call_details_attempts = 0
|
||||
checkpoint.pending_retry_after = None
|
||||
return
|
||||
|
||||
checkpoint.pending_call_details_attempts += 1
|
||||
logger.warning(
|
||||
f"Gong call details still missing after "
|
||||
f"{checkpoint.pending_call_details_attempts}/"
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={list(checkpoint.pending_transcripts.keys())}"
|
||||
)
|
||||
|
||||
if checkpoint.pending_call_details_attempts >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
logger.error(
|
||||
f"Giving up on missing Gong call details after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={list(checkpoint.pending_transcripts.keys())}"
|
||||
)
|
||||
for call_id in list(checkpoint.pending_transcripts.keys()):
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(document_id=call_id),
|
||||
failure_message=(
|
||||
f"Couldn't get call details after {self.MAX_CALL_DETAILS_ATTEMPTS} attempts for Call ID: {call_id}"
|
||||
),
|
||||
)
|
||||
checkpoint.pending_transcripts = {}
|
||||
checkpoint.pending_call_details_attempts = 0
|
||||
checkpoint.pending_retry_after = None
|
||||
# has_more is recomputed by the workspace iteration that follows;
|
||||
# reset to False here so a stale True from a prior invocation
|
||||
# can't leak out via any future early-return path.
|
||||
checkpoint.has_more = False
|
||||
else:
|
||||
checkpoint.pending_retry_after = time.time() + self._next_retry_delay(
|
||||
checkpoint.pending_call_details_attempts
|
||||
)
|
||||
checkpoint.has_more = True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
@@ -578,18 +578,8 @@ class GoogleDriveConnector(
|
||||
current_id, file.user_email, field_type, failed_folder_ids_by_email
|
||||
)
|
||||
if not folder:
|
||||
# Can't access this folder - stop climbing.
|
||||
# If the terminal node is a confirmed orphan, backfill all
|
||||
# intermediate folders into failed_folder_ids_by_email so
|
||||
# future files short-circuit via _get_folder_metadata's
|
||||
# cache check instead of re-climbing the whole chain.
|
||||
if failed_folder_ids_by_email is not None:
|
||||
for email in {file.user_email, self.primary_admin_email}:
|
||||
email_failed_ids = failed_folder_ids_by_email.get(email)
|
||||
if email_failed_ids and current_id in email_failed_ids:
|
||||
failed_folder_ids_by_email.setdefault(
|
||||
email, ThreadSafeSet()
|
||||
).update(set(node_ids_in_walk))
|
||||
# Can't access this folder - stop climbing
|
||||
# Don't mark as fully walked since we didn't reach root
|
||||
break
|
||||
|
||||
folder_parent_id = _get_parent_id_from_file(folder)
|
||||
|
||||
@@ -379,20 +379,10 @@ def _download_and_extract_sections_basic(
|
||||
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
or is_tabular_file(file_name)
|
||||
):
|
||||
# Google Drive doesn't enforce file extensions, so the filename may not
|
||||
# end in .xlsx even when the mime type says it's one. Synthesize the
|
||||
# extension so tabular_file_to_sections dispatches correctly.
|
||||
tabular_file_name = file_name
|
||||
if (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
and not is_tabular_file(file_name)
|
||||
):
|
||||
tabular_file_name = f"{file_name}.xlsx"
|
||||
return list(
|
||||
tabular_file_to_sections(
|
||||
io.BytesIO(response_call()),
|
||||
file_name=tabular_file_name,
|
||||
file_name=file_name,
|
||||
link=link,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -167,12 +167,9 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
default_factory=ThreadSafeSet
|
||||
)
|
||||
|
||||
# Maps email → set of folder IDs that email should skip when walking the
|
||||
# parent chain. Covers two cases:
|
||||
# 1. Folders where that email confirmed no accessible parent (true orphans).
|
||||
# 2. Intermediate folders on a path that dead-ended at a confirmed orphan —
|
||||
# backfilled so future walks short-circuit earlier in the chain.
|
||||
# In both cases _get_folder_metadata skips the API call and returns None.
|
||||
# Maps email → set of IDs of folders where that email confirmed no accessible parent.
|
||||
# Avoids redundant API calls when the same (folder, email) pair is
|
||||
# encountered again within the same retrieval run.
|
||||
failed_folder_ids_by_email: ThreadSafeDict[str, ThreadSafeSet[str]] = Field(
|
||||
default_factory=ThreadSafeDict
|
||||
)
|
||||
|
||||
@@ -15,7 +15,6 @@ from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.file_store.staging import RawFileCallback
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
@@ -43,9 +42,6 @@ class NormalizationResult(BaseModel):
|
||||
class BaseConnector(abc.ABC, Generic[CT]):
|
||||
REDIS_KEY_PREFIX = "da_connector_data:"
|
||||
|
||||
# Optional raw-file persistence hook to save original file
|
||||
raw_file_callback: RawFileCallback | None = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
raise NotImplementedError
|
||||
@@ -92,15 +88,6 @@ class BaseConnector(abc.ABC, Generic[CT]):
|
||||
"""Implement if the underlying connector wants to skip/allow image downloading
|
||||
based on the application level image analysis setting."""
|
||||
|
||||
def set_raw_file_callback(self, callback: RawFileCallback) -> None:
    """Store the raw-file persistence callback on this connector.

    The callback is kept in ``self.raw_file_callback``. Connectors that
    never read that attribute are unaffected by this call.

    Args:
        callback: The raw-file callback to attach to this connector.
    """
    self.raw_file_callback = callback
|
||||
|
||||
@classmethod
|
||||
def normalize_url(cls, url: str) -> "NormalizationResult": # noqa: ARG003
|
||||
"""Normalize a URL to match the canonical Document.id format used during ingestion.
|
||||
|
||||
@@ -62,19 +62,17 @@ def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
|
||||
def extract_text_from_adf(adf: dict | None) -> str:
|
||||
"""Extracts plain text from Atlassian Document Format:
|
||||
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
|
||||
|
||||
WARNING: This function is incomplete and will e.g. skip lists!
|
||||
"""
|
||||
texts: list[str] = []
|
||||
|
||||
def _extract(node: dict) -> None:
|
||||
if node.get("type") == "text":
|
||||
text = node.get("text", "")
|
||||
if text:
|
||||
texts.append(text)
|
||||
for child in node.get("content", []):
|
||||
_extract(child)
|
||||
|
||||
if adf is not None:
|
||||
_extract(adf)
|
||||
# TODO: complete this function
|
||||
texts = []
|
||||
if adf is not None and "content" in adf:
|
||||
for block in adf["content"]:
|
||||
if "content" in block:
|
||||
for item in block["content"]:
|
||||
if item["type"] == "text":
|
||||
texts.append(item["text"])
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
|
||||
@@ -231,8 +231,6 @@ class DocumentBase(BaseModel):
|
||||
# Set during docfetching after hierarchy nodes are cached
|
||||
parent_hierarchy_node_id: int | None = None
|
||||
|
||||
file_id: str | None = None
|
||||
|
||||
def get_title_for_document_index(
|
||||
self,
|
||||
) -> str | None:
|
||||
@@ -372,7 +370,6 @@ class Document(DocumentBase):
|
||||
secondary_owners=base.secondary_owners,
|
||||
title=base.title,
|
||||
from_ingestion_api=base.from_ingestion_api,
|
||||
file_id=base.file_id,
|
||||
)
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
|
||||
@@ -75,8 +75,6 @@ from onyx.file_processing.file_types import OnyxMimeTypes
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
|
||||
logger = setup_logger()
|
||||
SLIM_BATCH_SIZE = 1000
|
||||
@@ -983,42 +981,6 @@ class SharepointConnector(
|
||||
raise ConnectorValidationError(
|
||||
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
|
||||
)
|
||||
try:
|
||||
validate_outbound_http_url(site_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise ConnectorValidationError(
|
||||
f"Invalid site URL '{site_url}': {e}"
|
||||
) from e
|
||||
|
||||
# Probe RoleAssignments permission — required for permission sync.
|
||||
# Only runs when credentials have been loaded.
|
||||
if self.msal_app and self.sp_tenant_domain and self.sites:
|
||||
try:
|
||||
token_response = acquire_token_for_rest(
|
||||
self.msal_app,
|
||||
self.sp_tenant_domain,
|
||||
self.sharepoint_domain_suffix,
|
||||
)
|
||||
probe_url = (
|
||||
f"{self.sites[0].rstrip('/')}/_api/web/roleassignments?$top=1"
|
||||
)
|
||||
resp = requests.get(
|
||||
probe_url,
|
||||
headers={"Authorization": f"Bearer {token_response.accessToken}"},
|
||||
timeout=10,
|
||||
)
|
||||
if resp.status_code in (401, 403):
|
||||
raise ConnectorValidationError(
|
||||
"The Azure AD app registration is missing the required SharePoint permission "
|
||||
"to read role assignments. Please grant 'Sites.FullControl.All' "
|
||||
"(application permission) in the Azure portal and re-run admin consent."
|
||||
)
|
||||
except ConnectorValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"RoleAssignments permission probe failed (non-blocking): {e}"
|
||||
)
|
||||
|
||||
def _extract_tenant_domain_from_sites(self) -> str | None:
|
||||
"""Extract the tenant domain from configured site URLs.
|
||||
@@ -1914,22 +1876,16 @@ class SharepointConnector(
|
||||
logger.debug(
|
||||
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
|
||||
)
|
||||
try:
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
doc_batch.append(
|
||||
_convert_sitepage_to_slim_document(
|
||||
site_page,
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=site_descriptor.url,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process site page "
|
||||
f"{site_page.get('webUrl', site_page.get('name', 'Unknown'))}: {e}"
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
doc_batch.append(
|
||||
_convert_sitepage_to_slim_document(
|
||||
site_page,
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=site_descriptor.url,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
)
|
||||
if len(doc_batch) >= SLIM_BATCH_SIZE:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
@@ -2002,7 +1958,8 @@ class SharepointConnector(
|
||||
self._graph_client = GraphClient(
|
||||
_acquire_token_for_graph, environment=self._azure_environment
|
||||
)
|
||||
self.sp_tenant_domain = self._resolve_tenant_domain()
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
self.sp_tenant_domain = self._resolve_tenant_domain()
|
||||
return None
|
||||
|
||||
def _get_drive_names_for_site(self, site_url: str) -> list[str]:
|
||||
|
||||
@@ -19,7 +19,6 @@ from playwright.sync_api import Playwright
|
||||
from playwright.sync_api import sync_playwright
|
||||
from playwright.sync_api import TimeoutError
|
||||
from requests_oauthlib import OAuth2Session
|
||||
from typing_extensions import override
|
||||
from urllib3.exceptions import MaxRetryError
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -33,16 +32,11 @@ from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sitemap import list_pages_for_site
|
||||
from onyx.utils.web_content import extract_pdf_text
|
||||
@@ -61,6 +55,8 @@ class ScrapeSessionContext:
|
||||
self.visited_links: set[str] = set()
|
||||
self.content_hashes: set[int] = set()
|
||||
|
||||
self.doc_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
self.at_least_one_doc: bool = False
|
||||
self.last_error: str | None = None
|
||||
self.needs_retry: bool = False
|
||||
@@ -442,7 +438,7 @@ def _handle_cookies(context: BrowserContext, url: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
class WebConnector(LoadConnector, SlimConnector):
|
||||
class WebConnector(LoadConnector):
|
||||
MAX_RETRIES = 3
|
||||
|
||||
def __init__(
|
||||
@@ -497,14 +493,8 @@ class WebConnector(LoadConnector, SlimConnector):
|
||||
index: int,
|
||||
initial_url: str,
|
||||
session_ctx: ScrapeSessionContext,
|
||||
slim: bool = False,
|
||||
) -> ScrapeResult:
|
||||
"""Returns a ScrapeResult object with a doc and retry flag.
|
||||
|
||||
When slim=True, skips scroll, PDF content download, and content extraction.
|
||||
The bot-detection render wait (5s) fires on CF/403 responses regardless of slim.
|
||||
networkidle is always awaited so JS-rendered links are discovered correctly.
|
||||
"""
|
||||
"""Returns a ScrapeResult object with a doc and retry flag."""
|
||||
|
||||
if session_ctx.playwright is None:
|
||||
raise RuntimeError("scrape_context.playwright is None")
|
||||
@@ -525,16 +515,7 @@ class WebConnector(LoadConnector, SlimConnector):
|
||||
is_pdf = is_pdf_resource(initial_url, content_type)
|
||||
|
||||
if is_pdf:
|
||||
if slim:
|
||||
result.doc = Document(
|
||||
id=initial_url,
|
||||
sections=[],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=initial_url,
|
||||
metadata={},
|
||||
)
|
||||
return result
|
||||
|
||||
# PDF files are not checked for links
|
||||
response = requests.get(initial_url, headers=DEFAULT_HEADERS)
|
||||
page_text, metadata = extract_pdf_text(response.content)
|
||||
last_modified = response.headers.get("Last-Modified")
|
||||
@@ -565,20 +546,14 @@ class WebConnector(LoadConnector, SlimConnector):
|
||||
timeout=30000, # 30 seconds
|
||||
wait_until="commit", # Wait for navigation to commit
|
||||
)
|
||||
# Give the page a moment to start rendering after navigation commits.
|
||||
# Allows CloudFlare and other bot-detection challenges to complete.
|
||||
page.wait_for_timeout(PAGE_RENDER_TIMEOUT_MS)
|
||||
|
||||
# Bot-detection JS challenges (CloudFlare, Imperva, etc.) need a moment
|
||||
# to start network activity after commit before networkidle is meaningful.
|
||||
# We detect this via the cf-ray header (CloudFlare) or a 403 response,
|
||||
# which is the common entry point for JS-challenge-based bot detection.
|
||||
is_bot_challenge = page_response is not None and (
|
||||
page_response.header_value("cf-ray") is not None
|
||||
or page_response.status == 403
|
||||
)
|
||||
if is_bot_challenge:
|
||||
page.wait_for_timeout(PAGE_RENDER_TIMEOUT_MS)
|
||||
|
||||
# Wait for network activity to settle (handles SPAs, CF challenges, etc.)
|
||||
# Wait for network activity to settle so SPAs that fetch content
|
||||
# asynchronously after the initial JS bundle have time to render.
|
||||
try:
|
||||
# A bit of extra time to account for long-polling, websockets, etc.
|
||||
page.wait_for_load_state("networkidle", timeout=PAGE_RENDER_TIMEOUT_MS)
|
||||
except TimeoutError:
|
||||
pass
|
||||
@@ -601,7 +576,7 @@ class WebConnector(LoadConnector, SlimConnector):
|
||||
session_ctx.visited_links.add(initial_url)
|
||||
|
||||
# If we got here, the request was successful
|
||||
if not slim and self.scroll_before_scraping:
|
||||
if self.scroll_before_scraping:
|
||||
scroll_attempts = 0
|
||||
previous_height = page.evaluate("document.body.scrollHeight")
|
||||
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
|
||||
@@ -640,16 +615,6 @@ class WebConnector(LoadConnector, SlimConnector):
|
||||
result.retry = True
|
||||
return result
|
||||
|
||||
if slim:
|
||||
result.doc = Document(
|
||||
id=initial_url,
|
||||
sections=[],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=initial_url,
|
||||
metadata={},
|
||||
)
|
||||
return result
|
||||
|
||||
# after this point, we don't need the caller to retry
|
||||
parsed_html = web_html_cleanup(soup, self.mintlify_cleanup)
|
||||
|
||||
@@ -701,13 +666,9 @@ class WebConnector(LoadConnector, SlimConnector):
|
||||
|
||||
return result
|
||||
|
||||
def load_from_state(self, slim: bool = False) -> GenerateDocumentsOutput:
|
||||
"""Traverses through all pages found on the website and converts them into
|
||||
documents.
|
||||
|
||||
When slim=True, yields SlimDocument objects (URL id only, no content).
|
||||
Playwright is used in all modes — slim skips content extraction only.
|
||||
"""
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""Traverses through all pages found on the website
|
||||
and converts them into documents"""
|
||||
|
||||
if not self.to_visit_list:
|
||||
raise ValueError("No URLs to visit")
|
||||
@@ -718,8 +679,6 @@ class WebConnector(LoadConnector, SlimConnector):
|
||||
session_ctx = ScrapeSessionContext(base_url, self.to_visit_list)
|
||||
session_ctx.initialize()
|
||||
|
||||
batch: list[Document | SlimDocument | HierarchyNode] = []
|
||||
|
||||
while session_ctx.to_visit:
|
||||
initial_url = session_ctx.to_visit.pop()
|
||||
if initial_url in session_ctx.visited_links:
|
||||
@@ -734,9 +693,7 @@ class WebConnector(LoadConnector, SlimConnector):
|
||||
continue
|
||||
|
||||
index = len(session_ctx.visited_links)
|
||||
logger.info(
|
||||
f"{index}: {'Slim-visiting' if slim else 'Visiting'} {initial_url}"
|
||||
)
|
||||
logger.info(f"{index}: Visiting {initial_url}")
|
||||
|
||||
# Add retry mechanism with exponential backoff
|
||||
retry_count = 0
|
||||
@@ -751,14 +708,12 @@ class WebConnector(LoadConnector, SlimConnector):
|
||||
time.sleep(delay)
|
||||
|
||||
try:
|
||||
result = self._do_scrape(index, initial_url, session_ctx, slim=slim)
|
||||
result = self._do_scrape(index, initial_url, session_ctx)
|
||||
if result.retry:
|
||||
continue
|
||||
|
||||
if result.doc:
|
||||
batch.append(
|
||||
SlimDocument(id=result.doc.id) if slim else result.doc
|
||||
)
|
||||
session_ctx.doc_batch.append(result.doc)
|
||||
except Exception as e:
|
||||
session_ctx.last_error = f"Failed to fetch '{initial_url}': {e}"
|
||||
logger.exception(session_ctx.last_error)
|
||||
@@ -769,16 +724,16 @@ class WebConnector(LoadConnector, SlimConnector):
|
||||
|
||||
break # success / don't retry
|
||||
|
||||
if len(batch) >= self.batch_size:
|
||||
if len(session_ctx.doc_batch) >= self.batch_size:
|
||||
session_ctx.initialize()
|
||||
session_ctx.at_least_one_doc = True
|
||||
yield batch # ty: ignore[invalid-yield]
|
||||
batch = []
|
||||
yield session_ctx.doc_batch
|
||||
session_ctx.doc_batch = []
|
||||
|
||||
if batch:
|
||||
if session_ctx.doc_batch:
|
||||
session_ctx.stop()
|
||||
session_ctx.at_least_one_doc = True
|
||||
yield batch # ty: ignore[invalid-yield]
|
||||
yield session_ctx.doc_batch
|
||||
|
||||
if not session_ctx.at_least_one_doc:
|
||||
if session_ctx.last_error:
|
||||
@@ -787,22 +742,6 @@ class WebConnector(LoadConnector, SlimConnector):
|
||||
|
||||
session_ctx.stop()
|
||||
|
||||
@override
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""Yields SlimDocuments for all pages reachable from the configured URLs.
|
||||
|
||||
Uses the same Playwright crawl as full indexing but skips content extraction,
|
||||
scroll, and PDF downloads. The 5s render wait fires only on bot-detection
|
||||
responses (CloudFlare cf-ray header or HTTP 403).
|
||||
The start/end parameters are ignored — WEB connector has no incremental path.
|
||||
"""
|
||||
yield from self.load_from_state(slim=True) # ty: ignore[invalid-yield]
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
# Make sure we have at least one valid URL to check
|
||||
if not self.to_visit_list:
|
||||
|
||||
@@ -244,21 +244,13 @@ def fetch_latest_index_attempts_by_status(
|
||||
return query.all()
|
||||
|
||||
|
||||
_INTERNAL_ONLY_SOURCES = {
|
||||
# Used by the ingestion API, not a user-created connector.
|
||||
DocumentSource.INGESTION_API,
|
||||
# Backs the user library / build feature, not a connector users filter by.
|
||||
DocumentSource.CRAFT_FILE,
|
||||
}
|
||||
|
||||
|
||||
def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
    """Return the distinct sources of all connectors, minus internal-only ones.

    Args:
        db_session: Session used to query the distinct ``Connector.source``
            values.

    Returns:
        Every distinct connector source that is not listed in
        ``_INTERNAL_ONLY_SOURCES``.
    """
    distinct_sources = db_session.query(Connector.source).distinct().all()

    # Each row is indexed at position 0 to get the source value. A single
    # membership test is enough: DocumentSource.INGESTION_API is itself a
    # member of _INTERNAL_ONLY_SOURCES, so no separate comparison is needed.
    return [
        row[0] for row in distinct_sources if row[0] not in _INTERNAL_ONLY_SOURCES
    ]
|
||||
|
||||
@@ -696,7 +696,6 @@ def upsert_documents(
|
||||
else {}
|
||||
),
|
||||
doc_metadata=doc.doc_metadata,
|
||||
file_id=doc.file_id,
|
||||
)
|
||||
)
|
||||
for doc in seen_documents.values()
|
||||
@@ -713,7 +712,6 @@ def upsert_documents(
|
||||
"secondary_owners": insert_stmt.excluded.secondary_owners,
|
||||
"doc_metadata": insert_stmt.excluded.doc_metadata,
|
||||
"parent_hierarchy_node_id": insert_stmt.excluded.parent_hierarchy_node_id,
|
||||
"file_id": insert_stmt.excluded.file_id,
|
||||
}
|
||||
if includes_permissions:
|
||||
# Use COALESCE to preserve existing permissions when new values are NULL.
|
||||
|
||||
@@ -62,21 +62,6 @@ def delete_filerecord_by_file_id(
|
||||
db_session.query(FileRecord).filter_by(file_id=file_id).delete()
|
||||
|
||||
|
||||
def update_filerecord_origin(
    file_id: str,
    from_origin: FileOrigin,
    to_origin: FileOrigin,
    db_session: Session,
) -> None:
    """Move a file record from ``from_origin`` to ``to_origin``.

    Only rows matching both ``file_id`` and ``from_origin`` are updated, so
    calling this again after the origin has already changed matches nothing.
    No commit is issued here; the caller is responsible for committing.

    Args:
        file_id: Identifier of the file record to update.
        from_origin: Origin the record must currently have.
        to_origin: Origin to assign to the record.
        db_session: Session the update is executed on.
    """
    records_with_current_origin = db_session.query(FileRecord).filter(
        FileRecord.file_id == file_id,
        FileRecord.file_origin == from_origin,
    )
    records_with_current_origin.update({FileRecord.file_origin: to_origin})
|
||||
|
||||
|
||||
def upsert_filerecord(
|
||||
file_id: str,
|
||||
display_name: str,
|
||||
|
||||
@@ -952,7 +952,6 @@ class Document(Base):
|
||||
semantic_id: Mapped[str] = mapped_column(NullFilteredString)
|
||||
# First Section's link
|
||||
link: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
|
||||
file_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# The updated time is also used as a measure of the last successful state of the doc
|
||||
# pulled from the source (to help skip reindexing already updated docs in case of
|
||||
@@ -1055,9 +1054,9 @@ class Document(Base):
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_document_needs_sync",
|
||||
"id",
|
||||
postgresql_where=text("last_modified > last_synced OR last_synced IS NULL"),
|
||||
"ix_document_sync_status",
|
||||
last_modified,
|
||||
last_synced,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
|
||||
)
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.db.enums import OpenSearchDocumentMigrationStatus
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import OpenSearchDocumentMigrationRecord
|
||||
@@ -413,11 +412,7 @@ def get_opensearch_retrieval_state(
|
||||
|
||||
If the tenant migration record is not found, defaults to
|
||||
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX.
|
||||
|
||||
If ONYX_DISABLE_VESPA is True, always returns True.
|
||||
"""
|
||||
if ONYX_DISABLE_VESPA:
|
||||
return True
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
return ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
|
||||
@@ -129,13 +129,9 @@ def update_sync_record_status(
|
||||
"""
|
||||
sync_record = fetch_latest_sync_record(db_session, entity_id, sync_type)
|
||||
if sync_record is None:
|
||||
logger.warning(
|
||||
f"No sync record found for entity_id={entity_id} "
|
||||
f"sync_type={sync_type} — skipping status update. "
|
||||
f"This typically means the record was never created "
|
||||
f"(insert_sync_record failed silently) or was cleaned up."
|
||||
raise ValueError(
|
||||
f"No sync record found for entity_id={entity_id} sync_type={sync_type}"
|
||||
)
|
||||
return
|
||||
|
||||
sync_record.sync_status = sync_status
|
||||
if num_docs_synced is not None:
|
||||
|
||||
@@ -2,17 +2,11 @@ import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import FileRecord
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import UserFile
|
||||
@@ -114,73 +108,13 @@ def update_last_accessed_at_for_user_files(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_file_id_by_user_file_id(
    user_file_id: str,
    db_session: Session,
    user_id: UUID | None = None,
) -> str | None:
    """Look up the storage ``file_id`` of a ``UserFile`` by its id.

    Args:
        user_file_id: Primary key of the ``UserFile`` row.
        db_session: Session used for the lookup.
        user_id: Optional owner to scope the lookup to. When given, only a
            ``UserFile`` whose ``user_id`` matches is returned. When omitted,
            no ownership check is performed.

    Returns:
        The ``file_id`` of the matching row, or None if there is no match.
    """
    query = db_session.query(UserFile).filter(UserFile.id == user_file_id)
    if user_id is not None:
        # Restrict to files owned by the given user so one user cannot
        # resolve another user's file through this lookup.
        query = query.filter(UserFile.user_id == user_id)

    user_file = query.first()
    if user_file:
        return user_file.file_id
    return None
|
||||
|
||||
|
||||
def user_can_access_chat_file(file_id: str, user_id: UUID, db_session: Session) -> bool:
|
||||
"""Return True if `user_id` is allowed to read the raw `file_id` served by
|
||||
`GET /chat/file/{file_id}`. Access is granted when any of:
|
||||
|
||||
- The `file_id` is the storage id of a `UserFile` owned by the user.
|
||||
- The `file_id` is a persona avatar (`Persona.uploaded_image_id`); avatars
|
||||
are visible to any authenticated user.
|
||||
- The `file_id` appears in a `ChatMessage.files` descriptor of a chat
|
||||
session the user owns or a session publicly shared via
|
||||
`ChatSessionSharedStatus.PUBLIC`.
|
||||
"""
|
||||
owns_user_file = db_session.query(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.file_id == file_id, UserFile.user_id == user_id)
|
||||
.exists()
|
||||
).scalar()
|
||||
if owns_user_file:
|
||||
return True
|
||||
|
||||
# TODO: move persona avatars to a dedicated endpoint (e.g.
|
||||
# /assistants/{id}/avatar) so this branch can be removed. /chat/file is
|
||||
# currently overloaded with multiple asset classes (user files, chat
|
||||
# attachments, tool outputs, avatars), forcing this access-check fan-out.
|
||||
#
|
||||
# Restrict the avatar path to CHAT_UPLOAD-origin files so an attacker
|
||||
# cannot bind another user's USER_FILE (or any other origin) to their
|
||||
# own persona and read it through this check.
|
||||
is_persona_avatar = db_session.query(
|
||||
select(Persona.id)
|
||||
.join(FileRecord, FileRecord.file_id == Persona.uploaded_image_id)
|
||||
.where(
|
||||
Persona.uploaded_image_id == file_id,
|
||||
FileRecord.file_origin == FileOrigin.CHAT_UPLOAD,
|
||||
)
|
||||
.exists()
|
||||
).scalar()
|
||||
if is_persona_avatar:
|
||||
return True
|
||||
|
||||
chat_file_stmt = (
|
||||
select(ChatMessage.id)
|
||||
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(ChatMessage.files.op("@>")([{"id": file_id}]))
|
||||
.where(
|
||||
or_(
|
||||
ChatSession.user_id == user_id,
|
||||
ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC,
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
return db_session.execute(chat_file_stmt).first() is not None
|
||||
|
||||
|
||||
def get_file_ids_by_user_file_ids(
|
||||
user_file_ids: list[UUID], db_session: Session
|
||||
) -> list[str]:
|
||||
|
||||
@@ -3,7 +3,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.opensearch_migration import get_opensearch_retrieval_state
|
||||
from onyx.document_index.disabled import DisabledDocumentIndex
|
||||
@@ -49,11 +48,6 @@ def get_default_document_index(
|
||||
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
|
||||
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if ONYX_DISABLE_VESPA:
|
||||
if not opensearch_retrieval_enabled:
|
||||
raise ValueError(
|
||||
"BUG: ONYX_DISABLE_VESPA is set but opensearch_retrieval_enabled is not set."
|
||||
)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
@@ -125,32 +119,21 @@ def get_all_document_indices(
|
||||
)
|
||||
]
|
||||
|
||||
result: list[DocumentIndex] = []
|
||||
|
||||
if ONYX_DISABLE_VESPA:
|
||||
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
raise ValueError(
|
||||
"ONYX_DISABLE_VESPA is set but ENABLE_OPENSEARCH_INDEXING_FOR_ONYX is not set."
|
||||
)
|
||||
else:
|
||||
vespa_document_index = VespaIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=(
|
||||
secondary_search_settings.large_chunks_enabled
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
result.append(vespa_document_index)
|
||||
|
||||
vespa_document_index = VespaIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name if secondary_search_settings else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=(
|
||||
secondary_search_settings.large_chunks_enabled
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
@@ -186,6 +169,7 @@ def get_all_document_indices(
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
result: list[DocumentIndex] = [vespa_document_index]
|
||||
if opensearch_document_index:
|
||||
result.append(opensearch_document_index)
|
||||
|
||||
return result
|
||||
|
||||
@@ -98,9 +98,6 @@ class DocumentMetadata:
|
||||
# The resolved database ID of the parent hierarchy node (folder/container)
|
||||
parent_hierarchy_node_id: int | None = None
|
||||
|
||||
# Opt-in pointer to the persisted raw file for this document (file_store id).
|
||||
file_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VespaDocumentFields:
|
||||
|
||||
@@ -17,7 +17,6 @@ from onyx.configs.app_configs import OPENSEARCH_ADMIN_PASSWORD
|
||||
from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME
|
||||
from onyx.configs.app_configs import OPENSEARCH_HOST
|
||||
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
|
||||
from onyx.configs.app_configs import OPENSEARCH_USE_SSL
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import OpenSearchSearchType
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
@@ -133,7 +132,7 @@ class OpenSearchClient(AbstractContextManager):
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
use_ssl: bool = OPENSEARCH_USE_SSL,
|
||||
use_ssl: bool = True,
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
@@ -303,7 +302,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
use_ssl: bool = OPENSEARCH_USE_SSL,
|
||||
use_ssl: bool = True,
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
@@ -508,55 +507,8 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
Raises:
|
||||
Exception: There was an error updating the settings of the index.
|
||||
"""
|
||||
logger.debug(f"Updating settings of index {self._index_name} with {settings}.")
|
||||
response = self._client.indices.put_settings(
|
||||
index=self._index_name, body=settings
|
||||
)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(
|
||||
f"Failed to update settings of index {self._index_name}."
|
||||
)
|
||||
logger.debug(f"Settings of index {self._index_name} updated successfully.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def get_settings(self) -> dict[str, Any]:
|
||||
"""Gets the settings of the index.
|
||||
|
||||
Returns:
|
||||
The settings of the index.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error getting the settings of the index.
|
||||
"""
|
||||
logger.debug(f"Getting settings of index {self._index_name}.")
|
||||
response = self._client.indices.get_settings(index=self._index_name)
|
||||
return response[self._index_name]["settings"]
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def open_index(self) -> None:
|
||||
"""Opens the index.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error opening the index.
|
||||
"""
|
||||
logger.debug(f"Opening index {self._index_name}.")
|
||||
response = self._client.indices.open(index=self._index_name)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to open index {self._index_name}.")
|
||||
logger.debug(f"Index {self._index_name} opened successfully.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close_index(self) -> None:
|
||||
"""Closes the index.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the index.
|
||||
"""
|
||||
logger.debug(f"Closing index {self._index_name}.")
|
||||
response = self._client.indices.close(index=self._index_name)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to close index {self._index_name}.")
|
||||
logger.debug(f"Index {self._index_name} closed successfully.")
|
||||
# TODO(andrei): Implement this.
|
||||
raise NotImplementedError
|
||||
|
||||
@log_function_time(
|
||||
print_only=True,
|
||||
|
||||
@@ -12,7 +12,6 @@ from email.parser import Parser as EmailParser
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import IO
|
||||
from typing import NamedTuple
|
||||
from typing import Optional
|
||||
@@ -21,11 +20,10 @@ from zipfile import BadZipFile
|
||||
|
||||
import chardet
|
||||
import openpyxl
|
||||
from openpyxl.worksheet._read_only import ReadOnlyWorksheet
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.app_configs import MAX_XLSX_CELLS_PER_SHEET
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
@@ -50,8 +48,6 @@ KNOWN_OPENPYXL_BUGS = [
|
||||
"File contains no valid workbook part",
|
||||
"Unable to read workbook: could not read stylesheet from None",
|
||||
"Colors must be aRGB hex values",
|
||||
"Max value is",
|
||||
"There is no item named",
|
||||
]
|
||||
|
||||
|
||||
@@ -368,40 +364,6 @@ def extract_docx_images(docx_bytes: IO[Any]) -> Iterator[tuple[bytes, str]]:
|
||||
logger.exception("Failed to extract all docx images")
|
||||
|
||||
|
||||
def count_docx_embedded_images(file: IO[Any], cap: int) -> int:
|
||||
"""Return the number of embedded images in a docx, short-circuiting at cap+1.
|
||||
|
||||
Mirrors count_pdf_embedded_images so upload validation can apply the same
|
||||
per-file/per-batch caps. Returns a value > cap once the count exceeds the
|
||||
cap so callers do not iterate every media entry just to report a number.
|
||||
Always restores the file pointer to its original position before returning.
|
||||
"""
|
||||
try:
|
||||
start_pos = file.tell()
|
||||
except Exception:
|
||||
start_pos = None
|
||||
try:
|
||||
if start_pos is not None:
|
||||
file.seek(0)
|
||||
count = 0
|
||||
with zipfile.ZipFile(file) as z:
|
||||
for name in z.namelist():
|
||||
if name.startswith("word/media/"):
|
||||
count += 1
|
||||
if count > cap:
|
||||
return count
|
||||
return count
|
||||
except Exception:
|
||||
logger.warning("Failed to count embedded images in docx", exc_info=True)
|
||||
return 0
|
||||
finally:
|
||||
if start_pos is not None:
|
||||
try:
|
||||
file.seek(start_pos)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def read_docx_file(
|
||||
file: IO[Any],
|
||||
file_name: str = "",
|
||||
@@ -484,83 +446,104 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
return presentation.markdown
|
||||
|
||||
|
||||
def _columns_to_keep(col_has_data: bytearray, max_empty: int) -> list[int]:
|
||||
"""Keep non-empty columns, plus runs of up to ``max_empty`` empty columns
|
||||
between them. Trailing empty columns are dropped."""
|
||||
def _worksheet_to_matrix(
|
||||
worksheet: Worksheet,
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Converts a singular worksheet to a matrix of values.
|
||||
|
||||
Rows are padded to a uniform width. In openpyxl's read_only mode,
|
||||
iter_rows can yield rows of differing lengths (trailing empty cells
|
||||
are sometimes omitted), and downstream column cleanup assumes a
|
||||
rectangular matrix.
|
||||
"""
|
||||
rows: list[list[str]] = []
|
||||
max_len = 0
|
||||
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
|
||||
row = ["" if cell is None else str(cell) for cell in worksheet_row]
|
||||
if len(row) > max_len:
|
||||
max_len = len(row)
|
||||
rows.append(row)
|
||||
|
||||
for row in rows:
|
||||
if len(row) < max_len:
|
||||
row.extend([""] * (max_len - len(row)))
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def _clean_worksheet_matrix(matrix: list[list[str]]) -> list[list[str]]:
|
||||
"""
|
||||
Cleans a worksheet matrix by removing rows if there are N consecutive empty
|
||||
rows and removing cols if there are M consecutive empty columns
|
||||
"""
|
||||
MAX_EMPTY_ROWS = 2 # Runs longer than this are capped to max_empty; shorter runs are preserved as-is
|
||||
MAX_EMPTY_COLS = 2
|
||||
|
||||
# Row cleanup
|
||||
matrix = _remove_empty_runs(matrix, max_empty=MAX_EMPTY_ROWS)
|
||||
|
||||
if not matrix:
|
||||
return matrix
|
||||
|
||||
# Column cleanup — determine which columns to keep without transposing.
|
||||
num_cols = len(matrix[0])
|
||||
keep_cols = _columns_to_keep(matrix, num_cols, max_empty=MAX_EMPTY_COLS)
|
||||
if len(keep_cols) < num_cols:
|
||||
matrix = [[row[c] for c in keep_cols] for row in matrix]
|
||||
|
||||
return matrix
|
||||
|
||||
|
||||
def _columns_to_keep(
|
||||
matrix: list[list[str]], num_cols: int, max_empty: int
|
||||
) -> list[int]:
|
||||
"""Return the indices of columns to keep after removing empty-column runs.
|
||||
|
||||
Uses the same logic as ``_remove_empty_runs`` but operates on column
|
||||
indices so no transpose is needed.
|
||||
"""
|
||||
kept: list[int] = []
|
||||
empty_buffer: list[int] = []
|
||||
for c, has in enumerate(col_has_data):
|
||||
if has:
|
||||
kept.extend(empty_buffer[:max_empty])
|
||||
kept.append(c)
|
||||
empty_buffer = []
|
||||
|
||||
for col_idx in range(num_cols):
|
||||
col_is_empty = all(not row[col_idx] for row in matrix)
|
||||
if col_is_empty:
|
||||
empty_buffer.append(col_idx)
|
||||
else:
|
||||
empty_buffer.append(c)
|
||||
kept.extend(empty_buffer[:max_empty])
|
||||
kept.append(col_idx)
|
||||
empty_buffer = []
|
||||
|
||||
return kept
|
||||
|
||||
|
||||
def _sheet_to_csv(rows: Iterator[tuple[Any, ...]]) -> str:
|
||||
"""Stream worksheet rows into CSV text without materializing a dense matrix.
|
||||
def _remove_empty_runs(
|
||||
rows: list[list[str]],
|
||||
max_empty: int,
|
||||
) -> list[list[str]]:
|
||||
"""Removes entire runs of empty rows when the run length exceeds max_empty.
|
||||
|
||||
Empty rows are never stored. Column occupancy is tracked as a ``bytearray``
|
||||
bitmap so column trimming needs no transpose or copy. Runs of empty
|
||||
rows/columns longer than 2 are collapsed; shorter runs are preserved.
|
||||
|
||||
Scanning stops once ``MAX_XLSX_CELLS_PER_SHEET`` non-empty cells have been
|
||||
seen; the output gets a truncation marker row appended so downstream
|
||||
indexing sees that the sheet was cut off.
|
||||
Leading empty runs are capped to max_empty, just like interior runs.
|
||||
Trailing empty rows are always dropped since there is no subsequent
|
||||
non-empty row to flush them.
|
||||
"""
|
||||
MAX_EMPTY_ROWS_IN_OUTPUT = 2
|
||||
MAX_EMPTY_COLS_IN_OUTPUT = 2
|
||||
TRUNCATION_MARKER = "[truncated: sheet exceeded cell limit]"
|
||||
result: list[list[str]] = []
|
||||
empty_buffer: list[list[str]] = []
|
||||
|
||||
non_empty_rows: list[tuple[int, list[str]]] = []
|
||||
col_has_data = bytearray()
|
||||
total_non_empty = 0
|
||||
truncated = False
|
||||
for row in rows:
|
||||
# Check if empty
|
||||
if not any(row):
|
||||
if len(empty_buffer) < max_empty:
|
||||
empty_buffer.append(row)
|
||||
else:
|
||||
# Add upto max empty rows onto the result - that's what we allow
|
||||
result.extend(empty_buffer[:max_empty])
|
||||
# Add the new non-empty row
|
||||
result.append(row)
|
||||
empty_buffer = []
|
||||
|
||||
for row_idx, row_vals in enumerate(rows):
|
||||
# Fast-reject empty rows before allocating a list of "".
|
||||
if not any(v is not None and v != "" for v in row_vals):
|
||||
continue
|
||||
|
||||
cells = ["" if v is None else str(v) for v in row_vals]
|
||||
non_empty_rows.append((row_idx, cells))
|
||||
|
||||
if len(cells) > len(col_has_data):
|
||||
col_has_data.extend(b"\x00" * (len(cells) - len(col_has_data)))
|
||||
for i, v in enumerate(cells):
|
||||
if v:
|
||||
col_has_data[i] = 1
|
||||
total_non_empty += 1
|
||||
|
||||
if total_non_empty > MAX_XLSX_CELLS_PER_SHEET:
|
||||
truncated = True
|
||||
break
|
||||
|
||||
if not non_empty_rows:
|
||||
return ""
|
||||
|
||||
keep_cols = _columns_to_keep(col_has_data, MAX_EMPTY_COLS_IN_OUTPUT)
|
||||
if not keep_cols:
|
||||
return ""
|
||||
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf, lineterminator="\n")
|
||||
blank_row = [""] * len(keep_cols)
|
||||
last_idx = -1
|
||||
for row_idx, cells in non_empty_rows:
|
||||
gap = row_idx - last_idx - 1
|
||||
if gap > 0:
|
||||
for _ in range(min(gap, MAX_EMPTY_ROWS_IN_OUTPUT)):
|
||||
writer.writerow(blank_row)
|
||||
writer.writerow([cells[c] if c < len(cells) else "" for c in keep_cols])
|
||||
last_idx = row_idx
|
||||
|
||||
if truncated:
|
||||
writer.writerow([TRUNCATION_MARKER])
|
||||
|
||||
return buf.getvalue().rstrip("\n")
|
||||
return result
|
||||
|
||||
|
||||
def xlsx_sheet_extraction(file: IO[Any], file_name: str = "") -> list[tuple[str, str]]:
|
||||
@@ -588,24 +571,20 @@ def xlsx_sheet_extraction(file: IO[Any], file_name: str = "") -> list[tuple[str,
|
||||
raise
|
||||
|
||||
sheets: list[tuple[str, str]] = []
|
||||
try:
|
||||
for sheet in workbook.worksheets:
|
||||
# Declared dimensions can be different to what is actually there
|
||||
ro_sheet = cast(ReadOnlyWorksheet, sheet)
|
||||
ro_sheet.reset_dimensions()
|
||||
csv_text = _sheet_to_csv(ro_sheet.iter_rows(values_only=True))
|
||||
sheets.append((csv_text.strip(), ro_sheet.title))
|
||||
finally:
|
||||
workbook.close()
|
||||
|
||||
for sheet in workbook.worksheets:
|
||||
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf, lineterminator="\n")
|
||||
writer.writerows(sheet_matrix)
|
||||
csv_text = buf.getvalue().rstrip("\n")
|
||||
if csv_text.strip():
|
||||
sheets.append((csv_text, sheet.title))
|
||||
return sheets
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
sheets = xlsx_sheet_extraction(file, file_name)
|
||||
return TEXT_SECTION_SEPARATOR.join(
|
||||
csv_text for csv_text, _title in sheets if csv_text
|
||||
)
|
||||
return TEXT_SECTION_SEPARATOR.join(csv_text for csv_text, _title in sheets)
|
||||
|
||||
|
||||
def eml_to_text(file: IO[Any]) -> str:
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import IO
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.file_record import update_filerecord_origin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# (content, content_type) -> file_id
|
||||
RawFileCallback = Callable[[IO[bytes], str], str]
|
||||
|
||||
|
||||
def stage_raw_file(
|
||||
content: IO,
|
||||
content_type: str,
|
||||
*,
|
||||
metadata: dict[str, Any],
|
||||
) -> str:
|
||||
"""Persist raw bytes to the file store with FileOrigin.INDEXING_STAGING.
|
||||
|
||||
`metadata` is attached to the file_record so that downstream promotion
|
||||
(in docprocessing) and orphan reaping (TTL janitor) can locate the file
|
||||
by its originating context.
|
||||
"""
|
||||
file_store = get_default_file_store()
|
||||
file_id = file_store.save_file(
|
||||
content=content,
|
||||
display_name=None,
|
||||
file_origin=FileOrigin.INDEXING_STAGING,
|
||||
file_type=content_type,
|
||||
file_metadata=metadata,
|
||||
)
|
||||
return file_id
|
||||
|
||||
|
||||
def build_raw_file_callback(
|
||||
*,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str,
|
||||
) -> RawFileCallback:
|
||||
"""Build a per-attempt callback that connectors can invoke to opt in to
|
||||
raw-file persistence. The closure binds the attempt-level context as the
|
||||
staging metadata so the connector only needs to pass per-call info
|
||||
(bytes, content_type) and gets back a file_id to attach to its Document.
|
||||
"""
|
||||
metadata: dict[str, Any] = {
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
|
||||
def _callback(content: IO[bytes], content_type: str) -> str:
|
||||
return stage_raw_file(
|
||||
content=content,
|
||||
content_type=content_type,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return _callback
|
||||
|
||||
|
||||
def promote_staged_file(db_session: Session, file_id: str) -> None:
|
||||
"""Mark a previously-staged file as `FileOrigin.CONNECTOR`."""
|
||||
update_filerecord_origin(
|
||||
file_id=file_id,
|
||||
from_origin=FileOrigin.INDEXING_STAGING,
|
||||
to_origin=FileOrigin.CONNECTOR,
|
||||
db_session=db_session,
|
||||
)
|
||||
@@ -30,7 +30,6 @@ from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SectionType
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
@@ -50,7 +49,6 @@ from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.staging import promote_staged_file
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
@@ -156,7 +154,6 @@ def _upsert_documents_in_db(
|
||||
doc_metadata=doc.doc_metadata,
|
||||
# parent_hierarchy_node_id is resolved in docfetching using Redis cache
|
||||
parent_hierarchy_node_id=doc.parent_hierarchy_node_id,
|
||||
file_id=doc.file_id,
|
||||
)
|
||||
document_metadata_list.append(db_doc_metadata)
|
||||
|
||||
@@ -367,45 +364,6 @@ def index_doc_batch_with_handler(
|
||||
return index_pipeline_result
|
||||
|
||||
|
||||
def _promote_new_staged_files(
|
||||
documents: list[Document],
|
||||
previous_file_ids: dict[str, str],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Queue STAGING → CONNECTOR origin flips for every new file_id in the batch.
|
||||
|
||||
Intended to run immediately before `_upsert_documents_in_db` so the origin
|
||||
flip lands in the same commit as the `Document.file_id` write. Does not
|
||||
commit — the caller's next commit flushes these UPDATEs.
|
||||
"""
|
||||
for doc in documents:
|
||||
new_file_id = doc.file_id
|
||||
if new_file_id is None or new_file_id == previous_file_ids.get(doc.id):
|
||||
continue
|
||||
promote_staged_file(db_session=db_session, file_id=new_file_id)
|
||||
|
||||
|
||||
def _delete_replaced_files(
|
||||
documents: list[Document],
|
||||
previous_file_ids: dict[str, str],
|
||||
) -> None:
|
||||
"""Best-effort blob deletes for file_ids replaced in this batch.
|
||||
|
||||
Must run AFTER `Document.file_id` has been committed to the new
|
||||
file_id.
|
||||
"""
|
||||
file_store = get_default_file_store()
|
||||
for doc in documents:
|
||||
new_file_id = doc.file_id
|
||||
old_file_id = previous_file_ids.get(doc.id)
|
||||
if old_file_id is None or old_file_id == new_file_id:
|
||||
continue
|
||||
try:
|
||||
file_store.delete_file(old_file_id, error_on_missing=False)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to delete replaced file_id={old_file_id}.")
|
||||
|
||||
|
||||
def index_doc_batch_prepare(
|
||||
documents: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
@@ -424,11 +382,6 @@ def index_doc_batch_prepare(
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
# Capture previous file_ids BEFORE any writes so we know what to reap.
|
||||
previous_file_ids: dict[str, str] = {
|
||||
db_doc.id: db_doc.file_id for db_doc in db_docs if db_doc.file_id is not None
|
||||
}
|
||||
|
||||
updatable_docs = (
|
||||
get_doc_ids_to_update(documents=documents, db_docs=db_docs)
|
||||
if not ignore_time_skip
|
||||
@@ -446,24 +399,11 @@ def index_doc_batch_prepare(
|
||||
# for all updatable docs, upsert into the DB
|
||||
# Does not include doc_updated_at which is also used to indicate a successful update
|
||||
if updatable_docs:
|
||||
# Queue the STAGING → CONNECTOR origin flips BEFORE the Document upsert
|
||||
# so `upsert_documents`' commit flushes Document.file_id and the origin
|
||||
# flip atomically
|
||||
_promote_new_staged_files(
|
||||
documents=updatable_docs,
|
||||
previous_file_ids=previous_file_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
_upsert_documents_in_db(
|
||||
documents=updatable_docs,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
db_session=db_session,
|
||||
)
|
||||
# Blob deletes run only after Document.file_id is durable.
|
||||
_delete_replaced_files(
|
||||
documents=updatable_docs,
|
||||
previous_file_ids=previous_file_ids,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Upserted {len(updatable_docs)} changed docs out of {len(documents)} total docs into the DB"
|
||||
@@ -590,15 +530,8 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
Returns:
|
||||
List of IndexingDocument objects with processed_sections as list[Section]
|
||||
"""
|
||||
# Check if image extraction and analysis is enabled before trying to get a vision LLM.
|
||||
# Use section.type rather than isinstance because sections can round-trip
|
||||
# through pydantic as base Section instances (not the concrete subclass).
|
||||
has_image_section = any(
|
||||
section.type == SectionType.IMAGE
|
||||
for document in documents
|
||||
for section in document.sections
|
||||
)
|
||||
if not get_image_extraction_and_analysis_enabled() or not has_image_section:
|
||||
# Check if image extraction and analysis is enabled before trying to get a vision LLM
|
||||
if not get_image_extraction_and_analysis_enabled():
|
||||
llm = None
|
||||
else:
|
||||
# Only get the vision LLM if image processing is enabled
|
||||
|
||||
@@ -290,7 +290,11 @@ def litellm_exception_to_error_msg(
|
||||
error_code = "BUDGET_EXCEEDED"
|
||||
is_retryable = False
|
||||
elif isinstance(core_exception, Timeout):
|
||||
error_msg = "Request timed out: The operation took too long to complete. Please try again."
|
||||
error_msg = (
|
||||
"The LLM took too long to respond. "
|
||||
"If you're running a local model, try increasing the "
|
||||
"LLM_SOCKET_READ_TIMEOUT environment variable (current default: 120 seconds)."
|
||||
)
|
||||
error_code = "CONNECTION_ERROR"
|
||||
is_retryable = True
|
||||
elif isinstance(core_exception, APIError):
|
||||
|
||||
@@ -166,7 +166,6 @@ from shared_configs.configs import CORS_ALLOWED_ORIGIN
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
from shared_configs.configs import SENTRY_TRACES_SAMPLE_RATE
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
warnings.filterwarnings(
|
||||
@@ -440,7 +439,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=SENTRY_TRACES_SAMPLE_RATE,
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
|
||||
@@ -19,14 +19,9 @@ from onyx.configs.app_configs import MCP_SERVER_CORS_ORIGINS
|
||||
from onyx.mcp_server.auth import OnyxTokenVerifier
|
||||
from onyx.mcp_server.utils import shutdown_http_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Initialize EE flag at module import so it's set regardless of the entry point
|
||||
# (python -m onyx.mcp_server_main, uvicorn onyx.mcp_server.api:mcp_app, etc.).
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
logger.info("Creating Onyx MCP Server...")
|
||||
|
||||
mcp_server = FastMCP(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Resource registrations for the Onyx MCP server."""
|
||||
|
||||
# Import resource modules so decorators execute when the package loads.
|
||||
from onyx.mcp_server.resources import document_sets # noqa: F401
|
||||
from onyx.mcp_server.resources import indexed_sources # noqa: F401
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
"""Resource exposing document sets available to the current user."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from onyx.mcp_server.api import mcp_server
|
||||
from onyx.mcp_server.utils import get_accessible_document_sets
|
||||
from onyx.mcp_server.utils import require_access_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@mcp_server.resource(
|
||||
"resource://document_sets",
|
||||
name="document_sets",
|
||||
description=(
|
||||
"Enumerate the Document Sets accessible to the current user. Use the "
|
||||
"returned `name` values with the `document_set_names` filter of the "
|
||||
"`search_indexed_documents` tool to scope searches to a specific set."
|
||||
),
|
||||
mime_type="application/json",
|
||||
)
|
||||
async def document_sets_resource() -> str:
|
||||
"""Return the list of document sets the user can filter searches by."""
|
||||
|
||||
access_token = require_access_token()
|
||||
|
||||
document_sets = sorted(
|
||||
await get_accessible_document_sets(access_token), key=lambda entry: entry.name
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Onyx MCP Server: document_sets resource returning %s entries",
|
||||
len(document_sets),
|
||||
)
|
||||
|
||||
# FastMCP 3.2+ requires str/bytes/list[ResourceContent] — it no longer
|
||||
# auto-serializes; serialize to JSON ourselves.
|
||||
return json.dumps([entry.model_dump(mode="json") for entry in document_sets])
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from onyx.mcp_server.api import mcp_server
|
||||
from onyx.mcp_server.utils import get_indexed_sources
|
||||
@@ -21,7 +21,7 @@ logger = setup_logger()
|
||||
),
|
||||
mime_type="application/json",
|
||||
)
|
||||
async def indexed_sources_resource() -> str:
|
||||
async def indexed_sources_resource() -> dict[str, Any]:
|
||||
"""Return the list of indexed source types for search filtering."""
|
||||
|
||||
access_token = require_access_token()
|
||||
@@ -33,6 +33,6 @@ async def indexed_sources_resource() -> str:
|
||||
len(sources),
|
||||
)
|
||||
|
||||
# FastMCP 3.2+ requires str/bytes/list[ResourceContent] — it no longer
|
||||
# auto-serializes; serialize to JSON ourselves.
|
||||
return json.dumps(sorted(sources))
|
||||
return {
|
||||
"indexed_sources": sorted(sources),
|
||||
}
|
||||
|
||||
@@ -4,23 +4,12 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastmcp.server.auth.auth import AccessToken
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.mcp_server.api import mcp_server
|
||||
from onyx.mcp_server.utils import get_http_client
|
||||
from onyx.mcp_server.utils import get_indexed_sources
|
||||
from onyx.mcp_server.utils import require_access_token
|
||||
from onyx.server.features.web_search.models import OpenUrlsToolRequest
|
||||
from onyx.server.features.web_search.models import OpenUrlsToolResponse
|
||||
from onyx.server.features.web_search.models import WebSearchToolRequest
|
||||
from onyx.server.features.web_search.models import WebSearchToolResponse
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
@@ -28,43 +17,6 @@ from onyx.utils.variable_functionality import global_version
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# CE search falls through to the chat endpoint, which invokes an LLM — the
|
||||
# default 60s client timeout is not enough for a real RAG-backed response.
|
||||
_CE_SEARCH_TIMEOUT_SECONDS = 300.0
|
||||
|
||||
|
||||
async def _post_model(
|
||||
url: str,
|
||||
body: BaseModel,
|
||||
access_token: AccessToken,
|
||||
timeout: float | None = None,
|
||||
) -> httpx.Response:
|
||||
"""POST a Pydantic model as JSON to the Onyx backend."""
|
||||
return await get_http_client().post(
|
||||
url,
|
||||
content=body.model_dump_json(exclude_unset=True),
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token.token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
timeout=timeout if timeout is not None else httpx.USE_CLIENT_DEFAULT,
|
||||
)
|
||||
|
||||
|
||||
def _project_doc(doc: SearchDoc, content: str | None) -> dict[str, Any]:
|
||||
"""Project a backend search doc into the MCP wire shape.
|
||||
|
||||
Accepts SearchDocWithContent (EE) too since it extends SearchDoc.
|
||||
"""
|
||||
return {
|
||||
"semantic_identifier": doc.semantic_identifier,
|
||||
"content": content,
|
||||
"source_type": doc.source_type.value,
|
||||
"link": doc.link,
|
||||
"score": doc.score,
|
||||
}
|
||||
|
||||
|
||||
def _extract_error_detail(response: httpx.Response) -> str:
|
||||
"""Extract a human-readable error message from a failed backend response.
|
||||
|
||||
@@ -84,7 +36,6 @@ def _extract_error_detail(response: httpx.Response) -> str:
|
||||
async def search_indexed_documents(
|
||||
query: str,
|
||||
source_types: list[str] | None = None,
|
||||
document_set_names: list[str] | None = None,
|
||||
time_cutoff: str | None = None,
|
||||
limit: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
@@ -102,10 +53,6 @@ async def search_indexed_documents(
|
||||
In EE mode, the dedicated search endpoint is used instead.
|
||||
|
||||
To find a list of available sources, use the `indexed_sources` resource.
|
||||
`document_set_names` restricts results to documents belonging to the named
|
||||
Document Sets — useful for scoping queries to a curated subset of the
|
||||
knowledge base (e.g. to isolate knowledge between agents). Use the
|
||||
`document_sets` resource to discover accessible set names.
|
||||
Returns chunks of text as search results with snippets, scores, and metadata.
|
||||
|
||||
Example usage:
|
||||
@@ -113,23 +60,15 @@ async def search_indexed_documents(
|
||||
{
|
||||
"query": "What is the latest status of PROJ-1234 and what is the next development item?",
|
||||
"source_types": ["jira", "google_drive", "github"],
|
||||
"document_set_names": ["Engineering Wiki"],
|
||||
"time_cutoff": "2025-11-24T00:00:00Z",
|
||||
"limit": 10,
|
||||
}
|
||||
```
|
||||
"""
|
||||
logger.info(
|
||||
f"Onyx MCP Server: document search: query='{query}', sources={source_types}, "
|
||||
f"document_sets={document_set_names}, limit={limit}"
|
||||
f"Onyx MCP Server: document search: query='{query}', sources={source_types}, limit={limit}"
|
||||
)
|
||||
|
||||
# Normalize empty list inputs to None so downstream filter construction is
|
||||
# consistent — BaseFilters treats [] as "match zero" which differs from
|
||||
# "no filter" (None).
|
||||
source_types = source_types or None
|
||||
document_set_names = document_set_names or None
|
||||
|
||||
# Parse time_cutoff string to datetime if provided
|
||||
time_cutoff_dt: datetime | None = None
|
||||
if time_cutoff:
|
||||
@@ -142,6 +81,9 @@ async def search_indexed_documents(
|
||||
# Continue with no time_cutoff instead of returning an error
|
||||
time_cutoff_dt = None
|
||||
|
||||
# Initialize source_type_enums early to avoid UnboundLocalError
|
||||
source_type_enums: list[DocumentSource] | None = None
|
||||
|
||||
# Get authenticated user from FastMCP's access token
|
||||
access_token = require_access_token()
|
||||
|
||||
@@ -175,7 +117,6 @@ async def search_indexed_documents(
|
||||
|
||||
# Convert source_types strings to DocumentSource enums if provided
|
||||
# Invalid values will be handled by the API server
|
||||
source_type_enums: list[DocumentSource] | None = None
|
||||
if source_types is not None:
|
||||
source_type_enums = []
|
||||
for src in source_types:
|
||||
@@ -186,83 +127,83 @@ async def search_indexed_documents(
|
||||
f"Onyx MCP Server: Invalid source type '{src}' - will be ignored by server"
|
||||
)
|
||||
|
||||
filters: BaseFilters | None = None
|
||||
if source_type_enums or document_set_names or time_cutoff_dt:
|
||||
filters = BaseFilters(
|
||||
source_type=source_type_enums,
|
||||
document_set=document_set_names,
|
||||
time_cutoff=time_cutoff_dt,
|
||||
)
|
||||
# Build filters dict only with non-None values
|
||||
filters: dict[str, Any] | None = None
|
||||
if source_type_enums or time_cutoff_dt:
|
||||
filters = {}
|
||||
if source_type_enums:
|
||||
filters["source_type"] = [src.value for src in source_type_enums]
|
||||
if time_cutoff_dt:
|
||||
filters["time_cutoff"] = time_cutoff_dt.isoformat()
|
||||
|
||||
base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True)
|
||||
is_ee = global_version.is_ee_version()
|
||||
base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True)
|
||||
auth_headers = {"Authorization": f"Bearer {access_token.token}"}
|
||||
|
||||
request: BaseModel
|
||||
search_request: dict[str, Any]
|
||||
if is_ee:
|
||||
# EE: use the dedicated search endpoint (no LLM invocation).
|
||||
# Lazy import so CE deployments that strip ee/ never load this module.
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
|
||||
request = SendSearchQueryRequest(
|
||||
search_query=query,
|
||||
filters=filters,
|
||||
num_docs_fed_to_llm_selection=limit,
|
||||
run_query_expansion=False,
|
||||
include_content=True,
|
||||
stream=False,
|
||||
)
|
||||
# EE: use the dedicated search endpoint (no LLM invocation)
|
||||
search_request = {
|
||||
"search_query": query,
|
||||
"filters": filters,
|
||||
"num_docs_fed_to_llm_selection": limit,
|
||||
"run_query_expansion": False,
|
||||
"include_content": True,
|
||||
"stream": False,
|
||||
}
|
||||
endpoint = f"{base_url}/search/send-search-message"
|
||||
error_key = "error"
|
||||
docs_key = "search_docs"
|
||||
content_field = "content"
|
||||
else:
|
||||
# CE: fall back to the chat endpoint (invokes LLM, consumes tokens)
|
||||
request = SendMessageRequest(
|
||||
message=query,
|
||||
stream=False,
|
||||
chat_session_info=ChatSessionCreationRequest(),
|
||||
internal_search_filters=filters,
|
||||
)
|
||||
search_request = {
|
||||
"message": query,
|
||||
"stream": False,
|
||||
"chat_session_info": {},
|
||||
}
|
||||
if filters:
|
||||
search_request["internal_search_filters"] = filters
|
||||
endpoint = f"{base_url}/chat/send-chat-message"
|
||||
error_key = "error_msg"
|
||||
docs_key = "top_documents"
|
||||
content_field = "blurb"
|
||||
|
||||
try:
|
||||
response = await _post_model(
|
||||
response = await get_http_client().post(
|
||||
endpoint,
|
||||
request,
|
||||
access_token,
|
||||
timeout=None if is_ee else _CE_SEARCH_TIMEOUT_SECONDS,
|
||||
json=search_request,
|
||||
headers=auth_headers,
|
||||
)
|
||||
if not response.is_success:
|
||||
error_detail = _extract_error_detail(response)
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": _extract_error_detail(response),
|
||||
"error": error_detail,
|
||||
}
|
||||
result = response.json()
|
||||
|
||||
# Check for error in response
|
||||
if result.get(error_key):
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": result.get(error_key),
|
||||
}
|
||||
|
||||
if is_ee:
|
||||
from ee.onyx.server.query_and_chat.models import SearchFullResponse
|
||||
|
||||
ee_payload = SearchFullResponse.model_validate_json(response.content)
|
||||
if ee_payload.error:
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": ee_payload.error,
|
||||
}
|
||||
documents = [
|
||||
_project_doc(doc, doc.content) for doc in ee_payload.search_docs
|
||||
]
|
||||
else:
|
||||
ce_payload = ChatFullResponse.model_validate_json(response.content)
|
||||
if ce_payload.error_msg:
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": ce_payload.error_msg,
|
||||
}
|
||||
documents = [
|
||||
_project_doc(doc, doc.blurb) for doc in ce_payload.top_documents
|
||||
]
|
||||
documents = [
|
||||
{
|
||||
"semantic_identifier": doc.get("semantic_identifier"),
|
||||
"content": doc.get(content_field),
|
||||
"source_type": doc.get("source_type"),
|
||||
"link": doc.get("link"),
|
||||
"score": doc.get("score"),
|
||||
}
|
||||
for doc in result.get(docs_key, [])
|
||||
]
|
||||
|
||||
# NOTE: search depth is controlled by the backend persona defaults, not `limit`.
|
||||
# `limit` only caps the returned list; fewer results may be returned if the
|
||||
@@ -311,20 +252,23 @@ async def search_web(
|
||||
access_token = require_access_token()
|
||||
|
||||
try:
|
||||
response = await _post_model(
|
||||
request_payload = {"queries": [query], "max_results": limit}
|
||||
response = await get_http_client().post(
|
||||
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/web-search/search-lite",
|
||||
WebSearchToolRequest(queries=[query], max_results=limit),
|
||||
access_token,
|
||||
json=request_payload,
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
)
|
||||
if not response.is_success:
|
||||
error_detail = _extract_error_detail(response)
|
||||
return {
|
||||
"error": _extract_error_detail(response),
|
||||
"error": error_detail,
|
||||
"results": [],
|
||||
"query": query,
|
||||
}
|
||||
payload = WebSearchToolResponse.model_validate_json(response.content)
|
||||
response_payload = response.json()
|
||||
results = response_payload.get("results", [])
|
||||
return {
|
||||
"results": [result.model_dump(mode="json") for result in payload.results],
|
||||
"results": results,
|
||||
"query": query,
|
||||
}
|
||||
except Exception as e:
|
||||
@@ -361,19 +305,21 @@ async def open_urls(
|
||||
access_token = require_access_token()
|
||||
|
||||
try:
|
||||
response = await _post_model(
|
||||
response = await get_http_client().post(
|
||||
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/web-search/open-urls",
|
||||
OpenUrlsToolRequest(urls=urls),
|
||||
access_token,
|
||||
json={"urls": urls},
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
)
|
||||
if not response.is_success:
|
||||
error_detail = _extract_error_detail(response)
|
||||
return {
|
||||
"error": _extract_error_detail(response),
|
||||
"error": error_detail,
|
||||
"results": [],
|
||||
}
|
||||
payload = OpenUrlsToolResponse.model_validate_json(response.content)
|
||||
response_payload = response.json()
|
||||
results = response_payload.get("results", [])
|
||||
return {
|
||||
"results": [result.model_dump(mode="json") for result in payload.results],
|
||||
"results": results,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Onyx MCP Server: URL fetch error: {e}", exc_info=True)
|
||||
|
||||
@@ -5,24 +5,10 @@ from __future__ import annotations
|
||||
import httpx
|
||||
from fastmcp.server.auth.auth import AccessToken
|
||||
from fastmcp.server.dependencies import get_access_token
|
||||
from pydantic import BaseModel
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
|
||||
|
||||
|
||||
class DocumentSetEntry(BaseModel):
|
||||
"""Minimal document-set shape surfaced to MCP clients.
|
||||
|
||||
Projected from the backend's DocumentSetSummary to avoid coupling MCP to
|
||||
admin-only fields (cc-pair summaries, federated connectors, etc.).
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Shared HTTP client reused across requests
|
||||
@@ -98,32 +84,3 @@ async def get_indexed_sources(
|
||||
exc_info=True,
|
||||
)
|
||||
raise RuntimeError(f"Failed to fetch indexed sources: {exc}") from exc
|
||||
|
||||
|
||||
_DOCUMENT_SET_ENTRIES_ADAPTER = TypeAdapter(list[DocumentSetEntry])
|
||||
|
||||
|
||||
async def get_accessible_document_sets(
|
||||
access_token: AccessToken,
|
||||
) -> list[DocumentSetEntry]:
|
||||
"""Fetch document sets accessible to the current user."""
|
||||
headers = {"Authorization": f"Bearer {access_token.token}"}
|
||||
try:
|
||||
response = await get_http_client().get(
|
||||
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/manage/document-set",
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return _DOCUMENT_SET_ENTRIES_ADAPTER.validate_json(response.content)
|
||||
except (httpx.HTTPStatusError, httpx.RequestError, ValueError):
|
||||
logger.error(
|
||||
"Onyx MCP Server: Failed to fetch document sets",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Onyx MCP Server: Unexpected error fetching document sets",
|
||||
exc_info=True,
|
||||
)
|
||||
raise RuntimeError(f"Failed to fetch document sets: {exc}") from exc
|
||||
|
||||
@@ -5,7 +5,6 @@ import uvicorn
|
||||
from onyx.configs.app_configs import MCP_SERVER_ENABLED
|
||||
from onyx.configs.app_configs import MCP_SERVER_HOST
|
||||
from onyx.configs.app_configs import MCP_SERVER_PORT
|
||||
from onyx.tracing.setup import setup_tracing
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
@@ -19,7 +18,6 @@ def main() -> None:
|
||||
return
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
setup_tracing()
|
||||
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
|
||||
|
||||
from onyx.mcp_server.api import mcp_app
|
||||
|
||||
@@ -11,8 +11,6 @@ All public functions no-op in single-tenant mode (`MULTI_TENANT=False`).
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
from redis.client import Redis
|
||||
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
@@ -28,40 +26,6 @@ logger = setup_logger()
|
||||
_SET_KEY = "active_tenants"
|
||||
|
||||
|
||||
# --- Prometheus metrics ---
|
||||
|
||||
_active_set_size = Gauge(
|
||||
"onyx_tenant_work_gating_active_set_size",
|
||||
"Current cardinality of the active_tenants sorted set (updated once per "
|
||||
"generator invocation when the gate reads it).",
|
||||
)
|
||||
|
||||
_marked_total = Counter(
|
||||
"onyx_tenant_work_gating_marked_total",
|
||||
"Writes into active_tenants, labelled by caller.",
|
||||
["caller"],
|
||||
)
|
||||
|
||||
_skipped_total = Counter(
|
||||
"onyx_tenant_work_gating_skipped_total",
|
||||
"Per-tenant fanouts skipped by the gate (enforce mode only), by task.",
|
||||
["task"],
|
||||
)
|
||||
|
||||
_would_skip_total = Counter(
|
||||
"onyx_tenant_work_gating_would_skip_total",
|
||||
"Per-tenant fanouts that would have been skipped if enforce were on "
|
||||
"(shadow counter), by task.",
|
||||
["task"],
|
||||
)
|
||||
|
||||
_full_fanout_total = Counter(
|
||||
"onyx_tenant_work_gating_full_fanout_total",
|
||||
"Generator invocations that bypassed the gate for a full fanout cycle, by task.",
|
||||
["task"],
|
||||
)
|
||||
|
||||
|
||||
def _now_ms() -> int:
|
||||
return int(time.time() * 1000)
|
||||
|
||||
@@ -90,14 +54,10 @@ def mark_tenant_active(tenant_id: str) -> None:
|
||||
logger.exception(f"mark_tenant_active failed: tenant_id={tenant_id}")
|
||||
|
||||
|
||||
def maybe_mark_tenant_active(tenant_id: str, caller: str = "unknown") -> None:
|
||||
def maybe_mark_tenant_active(tenant_id: str) -> None:
|
||||
"""Convenience wrapper for writer call sites: records the tenant only
|
||||
when the feature flag is on. Fully defensive — never raises, so a Redis
|
||||
outage or flag-read failure can't abort the calling task.
|
||||
|
||||
`caller` labels the Prometheus counter so a dashboard can show which
|
||||
consumer is firing the hook most.
|
||||
"""
|
||||
outage or flag-read failure can't abort the calling task."""
|
||||
try:
|
||||
# Local import to avoid a module-load cycle: OnyxRuntime imports
|
||||
# onyx.redis.redis_pool, so a top-level import here would wedge on
|
||||
@@ -107,44 +67,10 @@ def maybe_mark_tenant_active(tenant_id: str, caller: str = "unknown") -> None:
|
||||
if not OnyxRuntime.get_tenant_work_gating_enabled():
|
||||
return
|
||||
mark_tenant_active(tenant_id)
|
||||
_marked_total.labels(caller=caller).inc()
|
||||
except Exception:
|
||||
logger.exception(f"maybe_mark_tenant_active failed: tenant_id={tenant_id}")
|
||||
|
||||
|
||||
def observe_active_set_size() -> int | None:
|
||||
"""Return `ZCARD active_tenants` and update the Prometheus gauge. Call
|
||||
from the gate generator once per invocation so the dashboard has a
|
||||
live reading.
|
||||
|
||||
Returns `None` on Redis error or in single-tenant mode; callers can
|
||||
tolerate that (gauge simply doesn't update)."""
|
||||
if not MULTI_TENANT:
|
||||
return None
|
||||
try:
|
||||
size = cast(int, _client().zcard(_SET_KEY))
|
||||
_active_set_size.set(size)
|
||||
return size
|
||||
except Exception:
|
||||
logger.exception("observe_active_set_size failed")
|
||||
return None
|
||||
|
||||
|
||||
def record_gate_decision(task_name: str, skipped: bool) -> None:
|
||||
"""Increment skip counters from the gate generator. Called once per
|
||||
tenant that the gate would skip. Always increments the shadow counter;
|
||||
increments the enforced counter only when `skipped=True`."""
|
||||
_would_skip_total.labels(task=task_name).inc()
|
||||
if skipped:
|
||||
_skipped_total.labels(task=task_name).inc()
|
||||
|
||||
|
||||
def record_full_fanout_cycle(task_name: str) -> None:
|
||||
"""Increment the full-fanout counter. Called once per generator
|
||||
invocation where the gate is bypassed (interval elapsed OR fail-open)."""
|
||||
_full_fanout_total.labels(task=task_name).inc()
|
||||
|
||||
|
||||
def get_active_tenants(ttl_seconds: int) -> set[str] | None:
|
||||
"""Return tenants whose last-seen timestamp is within `ttl_seconds` of
|
||||
now.
|
||||
|
||||
@@ -584,7 +584,7 @@ def associate_credential_to_connector(
|
||||
|
||||
# Tenant-work-gating lifecycle hook: keep new-tenant latency to
|
||||
# seconds instead of one full-fanout interval.
|
||||
maybe_mark_tenant_active(tenant_id, caller="cc_pair_lifecycle")
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
# trigger indexing immediately
|
||||
client_app.send_task(
|
||||
|
||||
@@ -1643,7 +1643,7 @@ def create_connector_with_mock_credential(
|
||||
|
||||
# Tenant-work-gating lifecycle hook: keep new-tenant latency to
|
||||
# seconds instead of one full-fanout interval.
|
||||
maybe_mark_tenant_active(tenant_id, caller="cc_pair_lifecycle")
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
# trigger indexing immediately
|
||||
client_app.send_task(
|
||||
|
||||
@@ -113,7 +113,7 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
|
||||
|
||||
# Tenant-work-gating hook: refresh this tenant's active-set
|
||||
# membership whenever sandbox cleanup has work to do.
|
||||
maybe_mark_tenant_active(tenant_id, caller="sandbox_cleanup")
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
task_logger.info(
|
||||
f"Found {len(idle_sandboxes)} idle sandboxes to put to sleep"
|
||||
|
||||
@@ -11,9 +11,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import count_docx_embedded_images
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -200,9 +198,6 @@ def categorize_uploaded_files(
|
||||
# rejected even if they'd individually fit under MAX_EMBEDDED_IMAGES_PER_FILE.
|
||||
batch_image_total = 0
|
||||
|
||||
# Hoisted out of the loop to avoid a KV-store lookup per file.
|
||||
image_extraction_enabled = get_image_extraction_and_analysis_enabled()
|
||||
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
@@ -265,33 +260,28 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
# Reject documents with an unreasonable number of embedded
|
||||
# images (either per-file or accumulated across this upload
|
||||
# batch). A file with thousands of embedded images can OOM the
|
||||
# Reject PDFs with an unreasonable number of embedded images
|
||||
# (either per-file or accumulated across this upload batch).
|
||||
# A PDF with thousands of embedded images can OOM the
|
||||
# user-file-processing celery worker because every image is
|
||||
# decoded with PIL and then sent to the vision LLM.
|
||||
count: int = 0
|
||||
image_bearing_ext = extension in (".pdf", ".docx")
|
||||
if image_bearing_ext:
|
||||
if extension == ".pdf":
|
||||
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
# Use the larger of the two caps as the short-circuit
|
||||
# threshold so we get a useful count for both checks.
|
||||
# These helpers restore the stream position.
|
||||
counter = (
|
||||
count_pdf_embedded_images
|
||||
if extension == ".pdf"
|
||||
else count_docx_embedded_images
|
||||
# count_pdf_embedded_images restores the stream position.
|
||||
count = count_pdf_embedded_images(
|
||||
upload.file, max(file_cap, batch_cap)
|
||||
)
|
||||
count = counter(upload.file, max(file_cap, batch_cap))
|
||||
if count > file_cap:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=(
|
||||
f"Document contains too many embedded "
|
||||
f"images (more than {file_cap}). Try "
|
||||
f"splitting it into smaller files."
|
||||
f"PDF contains too many embedded images "
|
||||
f"(more than {file_cap}). Try splitting "
|
||||
f"the document into smaller files."
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -318,21 +308,6 @@ def categorize_uploaded_files(
|
||||
extension=extension,
|
||||
)
|
||||
if not text_content:
|
||||
# Documents with embedded images (e.g. scans) have no
|
||||
# extractable text but can still be indexed via the
|
||||
# vision-LLM captioning path when image analysis is
|
||||
# enabled.
|
||||
if image_bearing_ext and count > 0 and image_extraction_enabled:
|
||||
results.acceptable.append(upload)
|
||||
results.acceptable_file_to_token_count[filename] = 0
|
||||
try:
|
||||
upload.file.seek(0)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to reset file pointer for '{filename}': {str(e)}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.warning(f"No text content extracted from '{filename}'")
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
|
||||
@@ -3,7 +3,6 @@ from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
@@ -50,7 +49,6 @@ def get_opensearch_retrieval_status(
|
||||
enable_opensearch_retrieval = get_opensearch_retrieval_state(db_session)
|
||||
return OpenSearchRetrievalStatusResponse(
|
||||
enable_opensearch_retrieval=enable_opensearch_retrieval,
|
||||
toggling_retrieval_is_disabled=ONYX_DISABLE_VESPA,
|
||||
)
|
||||
|
||||
|
||||
@@ -65,5 +63,4 @@ def set_opensearch_retrieval_status(
|
||||
)
|
||||
return OpenSearchRetrievalStatusResponse(
|
||||
enable_opensearch_retrieval=request.enable_opensearch_retrieval,
|
||||
toggling_retrieval_is_disabled=ONYX_DISABLE_VESPA,
|
||||
)
|
||||
|
||||
@@ -19,4 +19,3 @@ class OpenSearchRetrievalStatusRequest(BaseModel):
|
||||
class OpenSearchRetrievalStatusResponse(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
enable_opensearch_retrieval: bool
|
||||
toggling_retrieval_is_disabled: bool = False
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
"""Prometheus metrics for embedding generation latency and throughput.
|
||||
|
||||
Tracks client-side round-trip latency (as seen by callers of
|
||||
``EmbeddingModel.encode``) and server-side execution time (as measured inside
|
||||
the model server for the local-model path). Both API-provider and local-model
|
||||
paths flow through the client-side metric; only the local path populates the
|
||||
server-side metric.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LOCAL_PROVIDER_LABEL = "local"
|
||||
|
||||
_EMBEDDING_LATENCY_BUCKETS = (
|
||||
0.005,
|
||||
0.01,
|
||||
0.025,
|
||||
0.05,
|
||||
0.1,
|
||||
0.25,
|
||||
0.5,
|
||||
1.0,
|
||||
2.5,
|
||||
5.0,
|
||||
10.0,
|
||||
25.0,
|
||||
)
|
||||
|
||||
PROVIDER_LABEL_NAME = "provider"
|
||||
TEXT_TYPE_LABEL_NAME = "text_type"
|
||||
STATUS_LABEL_NAME = "status"
|
||||
|
||||
_client_duration = Histogram(
|
||||
"onyx_embedding_client_duration_seconds",
|
||||
"Client-side end-to-end latency of an embedding batch as seen by the caller.",
|
||||
[PROVIDER_LABEL_NAME, TEXT_TYPE_LABEL_NAME],
|
||||
buckets=_EMBEDDING_LATENCY_BUCKETS,
|
||||
)
|
||||
|
||||
_embedding_requests_total = Counter(
|
||||
"onyx_embedding_requests_total",
|
||||
"Total embedding batch requests, labeled by outcome.",
|
||||
[PROVIDER_LABEL_NAME, TEXT_TYPE_LABEL_NAME, STATUS_LABEL_NAME],
|
||||
)
|
||||
|
||||
_embedding_texts_total = Counter(
|
||||
"onyx_embedding_texts_total",
|
||||
"Total number of individual texts submitted for embedding.",
|
||||
[PROVIDER_LABEL_NAME, TEXT_TYPE_LABEL_NAME],
|
||||
)
|
||||
|
||||
_embedding_input_chars_total = Counter(
|
||||
"onyx_embedding_input_chars_total",
|
||||
"Total number of input characters submitted for embedding.",
|
||||
[PROVIDER_LABEL_NAME, TEXT_TYPE_LABEL_NAME],
|
||||
)
|
||||
|
||||
_embeddings_in_progress = Gauge(
|
||||
"onyx_embeddings_in_progress",
|
||||
"Number of embedding batches currently in-flight.",
|
||||
[PROVIDER_LABEL_NAME, TEXT_TYPE_LABEL_NAME],
|
||||
)
|
||||
|
||||
|
||||
def provider_label(provider: EmbeddingProvider | None) -> str:
|
||||
if provider is None:
|
||||
return LOCAL_PROVIDER_LABEL
|
||||
return provider.value
|
||||
|
||||
|
||||
def observe_embedding_client(
|
||||
provider: EmbeddingProvider | None,
|
||||
text_type: EmbedTextType,
|
||||
duration_s: float,
|
||||
num_texts: int,
|
||||
num_chars: int,
|
||||
success: bool,
|
||||
) -> None:
|
||||
"""Records a completed embedding batch.
|
||||
|
||||
Args:
|
||||
provider: The embedding provider, or ``None`` for the local model path.
|
||||
text_type: Whether this was a query- or passage-style embedding.
|
||||
duration_s: Wall-clock duration measured on the client side, in seconds.
|
||||
num_texts: Number of texts in the batch.
|
||||
num_chars: Total number of input characters in the batch.
|
||||
success: Whether the embedding call succeeded.
|
||||
"""
|
||||
try:
|
||||
provider_lbl = provider_label(provider)
|
||||
text_type_lbl = text_type.value
|
||||
status_lbl = "success" if success else "failure"
|
||||
|
||||
_embedding_requests_total.labels(
|
||||
provider=provider_lbl, text_type=text_type_lbl, status=status_lbl
|
||||
).inc()
|
||||
_client_duration.labels(provider=provider_lbl, text_type=text_type_lbl).observe(
|
||||
duration_s
|
||||
)
|
||||
if success:
|
||||
_embedding_texts_total.labels(
|
||||
provider=provider_lbl, text_type=text_type_lbl
|
||||
).inc(num_texts)
|
||||
_embedding_input_chars_total.labels(
|
||||
provider=provider_lbl, text_type=text_type_lbl
|
||||
).inc(num_chars)
|
||||
except Exception:
|
||||
logger.warning("Failed to record embedding client metrics.", exc_info=True)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def track_embedding_in_progress(
|
||||
provider: EmbeddingProvider | None,
|
||||
text_type: EmbedTextType,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Context manager that tracks in-flight embedding batches via a Gauge."""
|
||||
incremented = False
|
||||
provider_lbl = provider_label(provider)
|
||||
text_type_lbl = text_type.value
|
||||
try:
|
||||
_embeddings_in_progress.labels(
|
||||
provider=provider_lbl, text_type=text_type_lbl
|
||||
).inc()
|
||||
incremented = True
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to increment in-progress embedding gauge.", exc_info=True
|
||||
)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if incremented:
|
||||
try:
|
||||
_embeddings_in_progress.labels(
|
||||
provider=provider_lbl, text_type=text_type_lbl
|
||||
).dec()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to decrement in-progress embedding gauge.", exc_info=True
|
||||
)
|
||||
@@ -395,15 +395,6 @@ class WorkerHealthCollector(_CachedCollector):
|
||||
|
||||
Reads worker status from ``WorkerHeartbeatMonitor`` which listens
|
||||
to the Celery event stream via a single persistent connection.
|
||||
|
||||
TODO: every monitoring pod subscribes to the cluster-wide Celery event
|
||||
stream, so each replica reports health for *all* workers in the cluster,
|
||||
not just itself. Prometheus distinguishes the replicas via the ``instance``
|
||||
label, so this doesn't break scraping, but it means N monitoring replicas
|
||||
do N× the work and may emit slightly inconsistent snapshots of the same
|
||||
cluster. The proper fix is to have each worker expose its own health (or
|
||||
to elect a single monitoring replica as the reporter) rather than
|
||||
broadcasting the full cluster view from every monitoring pod.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = 30.0) -> None:
|
||||
@@ -422,16 +413,10 @@ class WorkerHealthCollector(_CachedCollector):
|
||||
"onyx_celery_active_worker_count",
|
||||
"Number of active Celery workers with recent heartbeats",
|
||||
)
|
||||
# Celery hostnames are ``{worker_type}@{nodename}`` (see supervisord.conf).
|
||||
# Emitting only the worker_type as a label causes N replicas of the same
|
||||
# type to collapse into identical timeseries within a single scrape,
|
||||
# which Prometheus rejects as "duplicate sample for timestamp". Split
|
||||
# the pieces into separate labels so each replica is distinct; callers
|
||||
# can still ``sum by (worker_type)`` to recover the old aggregated view.
|
||||
worker_up = GaugeMetricFamily(
|
||||
"onyx_celery_worker_up",
|
||||
"Whether a specific Celery worker is alive (1=up, 0=down)",
|
||||
labels=["worker_type", "hostname"],
|
||||
labels=["worker"],
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -439,15 +424,11 @@ class WorkerHealthCollector(_CachedCollector):
|
||||
alive_count = sum(1 for alive in status.values() if alive)
|
||||
active_workers.add_metric([], alive_count)
|
||||
|
||||
for full_hostname in sorted(status):
|
||||
worker_type, sep, host = full_hostname.partition("@")
|
||||
if not sep:
|
||||
# Hostname didn't contain "@" — fall back to using the
|
||||
# whole string as the hostname with an empty type.
|
||||
worker_type, host = "", full_hostname
|
||||
worker_up.add_metric(
|
||||
[worker_type, host], 1 if status[full_hostname] else 0
|
||||
)
|
||||
for hostname in sorted(status):
|
||||
# Use short name (before @) for single-host deployments,
|
||||
# full hostname when multiple hosts share a worker type.
|
||||
label = hostname.split("@")[0]
|
||||
worker_up.add_metric([label], 1 if status[hostname] else 0)
|
||||
except Exception:
|
||||
logger.debug("Failed to collect worker health metrics", exc_info=True)
|
||||
|
||||
|
||||
@@ -63,7 +63,6 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.db.user_file import user_can_access_chat_file
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -867,18 +866,14 @@ def seed_chat_from_slack(
|
||||
def fetch_chat_file(
|
||||
file_id: str,
|
||||
request: Request,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
|
||||
# For user files, we need to get the file id from the user file id
|
||||
file_id_from_user_file = get_file_id_by_user_file_id(file_id, user.id, db_session)
|
||||
file_id_from_user_file = get_file_id_by_user_file_id(file_id, db_session)
|
||||
if file_id_from_user_file:
|
||||
file_id = file_id_from_user_file
|
||||
elif not user_can_access_chat_file(file_id, user.id, db_session):
|
||||
# Return 404 (rather than 403) so callers cannot probe for file
|
||||
# existence across ownership boundaries.
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "File not found")
|
||||
|
||||
file_store = get_default_file_store()
|
||||
file_record = file_store.read_file_record(file_id)
|
||||
|
||||
@@ -80,7 +80,7 @@ class Settings(BaseModel):
|
||||
query_history_type: QueryHistoryType | None = None
|
||||
|
||||
# Image processing settings
|
||||
image_extraction_and_analysis_enabled: bool | None = True
|
||||
image_extraction_and_analysis_enabled: bool | None = False
|
||||
search_time_image_analysis_enabled: bool | None = False
|
||||
image_analysis_max_size_mb: int | None = 20
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.configs.embedding_configs import SUPPORTED_EMBEDDING_MODELS
|
||||
@@ -127,11 +126,10 @@ def setup_onyx(
|
||||
"DISABLE_VECTOR_DB is set — skipping document index setup and embedding model warm-up."
|
||||
)
|
||||
else:
|
||||
# Ensure the document indices are setup correctly. This step is
|
||||
# relatively near the end because Vespa takes a bit of time to start up.
|
||||
# Ensure Vespa is setup correctly, this step is relatively near the end
|
||||
# because Vespa takes a bit of time to start up
|
||||
logger.notice("Verifying Document Index(s) is/are available.")
|
||||
# This flow is for setting up the document index so we get all indices
|
||||
# here.
|
||||
# This flow is for setting up the document index so we get all indices here.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings,
|
||||
secondary_search_settings,
|
||||
@@ -337,7 +335,7 @@ def setup_multitenant_onyx() -> None:
|
||||
|
||||
# For Managed Vespa, the schema is sent over via the Vespa Console manually.
|
||||
# NOTE: Pretty sure this code is never hit in any production environment.
|
||||
if not MANAGED_VESPA and not ONYX_DISABLE_VESPA:
|
||||
if not MANAGED_VESPA:
|
||||
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)
|
||||
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ stop_and_remove_containers
|
||||
# Start the PostgreSQL container with optional volume
|
||||
echo "Starting PostgreSQL container..."
|
||||
if [[ -n "$POSTGRES_VOLUME" ]]; then
|
||||
docker run -p 5432:5432 --name onyx_postgres -e POSTGRES_PASSWORD=password -d -v "$POSTGRES_VOLUME":/var/lib/postgresql/data postgres -c max_connections=250
|
||||
docker run -p 5432:5432 --name onyx_postgres -e POSTGRES_PASSWORD=password -d -v $POSTGRES_VOLUME:/var/lib/postgresql/data postgres -c max_connections=250
|
||||
else
|
||||
docker run -p 5432:5432 --name onyx_postgres -e POSTGRES_PASSWORD=password -d postgres -c max_connections=250
|
||||
fi
|
||||
@@ -54,7 +54,7 @@ fi
|
||||
# Start the Vespa container with optional volume
|
||||
echo "Starting Vespa container..."
|
||||
if [[ -n "$VESPA_VOLUME" ]]; then
|
||||
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 -v "$VESPA_VOLUME":/opt/vespa/var vespaengine/vespa:8
|
||||
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 -v $VESPA_VOLUME:/opt/vespa/var vespaengine/vespa:8
|
||||
else
|
||||
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8
|
||||
fi
|
||||
@@ -85,7 +85,7 @@ docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-en
|
||||
# Start the Redis container with optional volume
|
||||
echo "Starting Redis container..."
|
||||
if [[ -n "$REDIS_VOLUME" ]]; then
|
||||
docker run --detach --name onyx_redis --publish 6379:6379 -v "$REDIS_VOLUME":/data redis
|
||||
docker run --detach --name onyx_redis --publish 6379:6379 -v $REDIS_VOLUME:/data redis
|
||||
else
|
||||
docker run --detach --name onyx_redis --publish 6379:6379 redis
|
||||
fi
|
||||
@@ -93,7 +93,7 @@ fi
|
||||
# Start the MinIO container with optional volume
|
||||
echo "Starting MinIO container..."
|
||||
if [[ -n "$MINIO_VOLUME" ]]; then
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin -v "$MINIO_VOLUME":/data minio/minio server /data --console-address ":9001"
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin -v $MINIO_VOLUME:/data minio/minio server /data --console-address ":9001"
|
||||
else
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
|
||||
fi
|
||||
@@ -111,7 +111,6 @@ sleep 1
|
||||
|
||||
# Alembic should be configured in the virtualenv for this repo
|
||||
if [[ -f "../.venv/bin/activate" ]]; then
|
||||
# shellcheck source=/dev/null
|
||||
source ../.venv/bin/activate
|
||||
else
|
||||
echo "Warning: Python virtual environment not found at .venv/bin/activate; alembic may not work."
|
||||
|
||||
@@ -99,14 +99,6 @@ STRICT_CHUNK_TOKEN_LIMIT = (
|
||||
# Set up Sentry integration (for error logging)
|
||||
SENTRY_DSN = os.environ.get("SENTRY_DSN")
|
||||
|
||||
# Celery task spans dominate ingestion volume (~94%), so default celery
|
||||
# tracing to 0. Web/API traces stay at a small non-zero rate so http.server
|
||||
# traces remain available. Both are env-tunable without a code change.
|
||||
SENTRY_TRACES_SAMPLE_RATE = float(os.environ.get("SENTRY_TRACES_SAMPLE_RATE", "0.01"))
|
||||
SENTRY_CELERY_TRACES_SAMPLE_RATE = float(
|
||||
os.environ.get("SENTRY_CELERY_TRACES_SAMPLE_RATE", "0.0")
|
||||
)
|
||||
|
||||
|
||||
# Fields which should only be set on new search setting
|
||||
PRESERVED_SEARCH_FIELDS = [
|
||||
|
||||
@@ -9,10 +9,8 @@ import pytest
|
||||
|
||||
from onyx.configs.constants import BlobType
|
||||
from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
@@ -113,18 +111,15 @@ def test_blob_s3_connector(
|
||||
|
||||
for doc in all_docs:
|
||||
section = doc.sections[0]
|
||||
|
||||
if is_tabular_file(doc.semantic_identifier):
|
||||
assert isinstance(section, TabularSection)
|
||||
assert len(section.text) > 0
|
||||
continue
|
||||
|
||||
assert isinstance(section, TextSection)
|
||||
|
||||
file_extension = get_file_ext(doc.semantic_identifier)
|
||||
if file_extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS:
|
||||
assert len(section.text) > 0
|
||||
else:
|
||||
assert len(section.text) == 0
|
||||
continue
|
||||
|
||||
# unknown extension
|
||||
assert len(section.text) == 0
|
||||
|
||||
|
||||
@patch(
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
"""External dependency tests for onyx.file_store.staging.
|
||||
|
||||
Exercises the raw-file persistence hook used by the docfetching pipeline
|
||||
against a real file store (Postgres + MinIO/S3), since mocking the store
|
||||
would defeat the point of verifying that metadata round-trips through
|
||||
FileRecord.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.db.file_record import delete_filerecord_by_file_id
|
||||
from onyx.db.file_record import get_filerecord_by_file_id
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.staging import build_raw_file_callback
|
||||
from onyx.file_store.staging import stage_raw_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def cleanup_file_ids(
|
||||
db_session: Session,
|
||||
) -> Generator[list[str], None, None]:
|
||||
created: list[str] = []
|
||||
yield created
|
||||
file_store = get_default_file_store()
|
||||
for fid in created:
|
||||
try:
|
||||
file_store.delete_file(fid)
|
||||
except Exception:
|
||||
delete_filerecord_by_file_id(file_id=fid, db_session=db_session)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def test_stage_raw_file_persists_with_origin_and_metadata(
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG001
|
||||
initialize_file_store: None, # noqa: ARG001
|
||||
cleanup_file_ids: list[str],
|
||||
) -> None:
|
||||
"""stage_raw_file writes a FileRecord with INDEXING_STAGING origin and
|
||||
round-trips the provided metadata verbatim."""
|
||||
metadata: dict[str, Any] = {
|
||||
"index_attempt_id": 42,
|
||||
"cc_pair_id": 7,
|
||||
"tenant_id": "tenant-abc",
|
||||
"extra": "payload",
|
||||
}
|
||||
content_bytes = b"hello raw file"
|
||||
content_type = "application/pdf"
|
||||
|
||||
file_id = stage_raw_file(
|
||||
content=BytesIO(content_bytes),
|
||||
content_type=content_type,
|
||||
metadata=metadata,
|
||||
)
|
||||
cleanup_file_ids.append(file_id)
|
||||
db_session.commit()
|
||||
|
||||
record = get_filerecord_by_file_id(file_id=file_id, db_session=db_session)
|
||||
assert record.file_origin == FileOrigin.INDEXING_STAGING
|
||||
assert record.file_type == content_type
|
||||
assert record.file_metadata == metadata
|
||||
|
||||
|
||||
def test_build_raw_file_callback_binds_attempt_context_per_call(
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG001
|
||||
initialize_file_store: None, # noqa: ARG001
|
||||
cleanup_file_ids: list[str],
|
||||
) -> None:
|
||||
"""The callback returned by build_raw_file_callback must bind the
|
||||
attempt-level context into every FileRecord it produces, without
|
||||
leaking state across invocations."""
|
||||
callback = build_raw_file_callback(
|
||||
index_attempt_id=1001,
|
||||
cc_pair_id=202,
|
||||
tenant_id="tenant-xyz",
|
||||
)
|
||||
|
||||
file_id_a = callback(BytesIO(b"alpha"), "text/plain")
|
||||
file_id_b = callback(BytesIO(b"beta"), "application/octet-stream")
|
||||
cleanup_file_ids.extend([file_id_a, file_id_b])
|
||||
db_session.commit()
|
||||
|
||||
assert file_id_a != file_id_b
|
||||
|
||||
for fid, expected_content_type in (
|
||||
(file_id_a, "text/plain"),
|
||||
(file_id_b, "application/octet-stream"),
|
||||
):
|
||||
record = get_filerecord_by_file_id(file_id=fid, db_session=db_session)
|
||||
assert record.file_origin == FileOrigin.INDEXING_STAGING
|
||||
assert record.file_type == expected_content_type
|
||||
assert record.file_metadata == {
|
||||
"index_attempt_id": 1001,
|
||||
"cc_pair_id": 202,
|
||||
"tenant_id": "tenant-xyz",
|
||||
}
|
||||
|
||||
|
||||
def test_set_raw_file_callback_on_base_connector() -> None:
|
||||
"""set_raw_file_callback must install the callback as an instance
|
||||
attribute usable by the connector."""
|
||||
|
||||
class _MinimalConnector(BaseConnector):
|
||||
def load_credentials(
|
||||
self,
|
||||
credentials: dict[str, Any], # noqa: ARG002
|
||||
) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
connector = _MinimalConnector()
|
||||
assert connector.raw_file_callback is None
|
||||
|
||||
sentinel_file_id = f"sentinel-{uuid4().hex[:8]}"
|
||||
|
||||
def _fake_callback(_content: Any, _content_type: str) -> str:
|
||||
return sentinel_file_id
|
||||
|
||||
connector.set_raw_file_callback(_fake_callback)
|
||||
|
||||
assert connector.raw_file_callback is _fake_callback
|
||||
assert connector.raw_file_callback(BytesIO(b""), "text/plain") == sentinel_file_id
|
||||
@@ -1,405 +0,0 @@
|
||||
"""External dependency unit tests for `index_doc_batch_prepare`.
|
||||
|
||||
Validates the file_id lifecycle that runs alongside the document upsert:
|
||||
|
||||
* `document.file_id` is written on insert AND on conflict (upsert path)
|
||||
* Newly-staged files get promoted from INDEXING_STAGING -> CONNECTOR
|
||||
* Replaced files are deleted from both `file_record` and S3
|
||||
* No-op when the file_id is unchanged
|
||||
|
||||
Uses real PostgreSQL + real S3/MinIO via the file store.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.file_record import get_filerecord_by_file_id_optional
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Document as DBDocument
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import FileRecord
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_doc(doc_id: str, file_id: str | None = None) -> Document:
|
||||
"""Minimal Document for indexing-pipeline tests. MOCK_CONNECTOR avoids
|
||||
triggering the hierarchy-node linking branch (NOTION/CONFLUENCE only)."""
|
||||
return Document(
|
||||
id=doc_id,
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
semantic_identifier=f"semantic-{doc_id}",
|
||||
sections=[TextSection(text="content", link=None)],
|
||||
metadata={},
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
|
||||
def _stage_file(content: bytes = b"raw bytes") -> str:
|
||||
"""Write bytes to the file store as INDEXING_STAGING and return the file_id.
|
||||
|
||||
Mirrors what the connector raw_file_callback would do during fetch.
|
||||
"""
|
||||
return get_default_file_store().save_file(
|
||||
content=BytesIO(content),
|
||||
display_name=None,
|
||||
file_origin=FileOrigin.INDEXING_STAGING,
|
||||
file_type="application/octet-stream",
|
||||
file_metadata={"test": True},
|
||||
)
|
||||
|
||||
|
||||
def _get_doc_row(db_session: Session, doc_id: str) -> DBDocument | None:
|
||||
"""Reload the document row fresh from DB so we see post-upsert state."""
|
||||
db_session.expire_all()
|
||||
return db_session.query(DBDocument).filter(DBDocument.id == doc_id).one_or_none()
|
||||
|
||||
|
||||
def _get_filerecord(db_session: Session, file_id: str) -> FileRecord | None:
|
||||
db_session.expire_all()
|
||||
return get_filerecord_by_file_id_optional(file_id=file_id, db_session=db_session)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cc_pair(
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG001
|
||||
initialize_file_store: None, # noqa: ARG001
|
||||
) -> Generator[ConnectorCredentialPair, None, None]:
|
||||
"""Create a connector + credential + cc_pair backing the index attempt.
|
||||
|
||||
Teardown sweeps everything the test created under this cc_pair: the
|
||||
`document_by_connector_credential_pair` join rows, the `Document` rows
|
||||
they point at, the `FileRecord` + blob for each doc's `file_id`, and
|
||||
finally the cc_pair / connector / credential themselves. Without this,
|
||||
every run would leave orphan rows in the dev DB and orphan blobs in
|
||||
MinIO.
|
||||
"""
|
||||
connector = Connector(
|
||||
name=f"test-connector-{uuid4().hex[:8]}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={},
|
||||
refresh_freq=None,
|
||||
prune_freq=None,
|
||||
indexing_start=None,
|
||||
)
|
||||
db_session.add(connector)
|
||||
db_session.flush()
|
||||
|
||||
credential = Credential(
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
credential_json={},
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
|
||||
pair = ConnectorCredentialPair(
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
name=f"test-cc-pair-{uuid4().hex[:8]}",
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
access_type=AccessType.PUBLIC,
|
||||
auto_sync_options=None,
|
||||
)
|
||||
db_session.add(pair)
|
||||
db_session.commit()
|
||||
db_session.refresh(pair)
|
||||
|
||||
connector_id = pair.connector_id
|
||||
credential_id = pair.credential_id
|
||||
|
||||
try:
|
||||
yield pair
|
||||
finally:
|
||||
db_session.expire_all()
|
||||
|
||||
# Collect every doc indexed under this cc_pair so we can delete its
|
||||
# file_record + blob before dropping the Document row itself.
|
||||
doc_ids: list[str] = [
|
||||
row[0]
|
||||
for row in db_session.query(DocumentByConnectorCredentialPair.id)
|
||||
.filter(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
.all()
|
||||
]
|
||||
file_ids: list[str] = [
|
||||
row[0]
|
||||
for row in db_session.query(DBDocument.file_id)
|
||||
.filter(DBDocument.id.in_(doc_ids), DBDocument.file_id.isnot(None))
|
||||
.all()
|
||||
]
|
||||
|
||||
file_store = get_default_file_store()
|
||||
for fid in file_ids:
|
||||
try:
|
||||
file_store.delete_file(fid, error_on_missing=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if doc_ids:
|
||||
db_session.query(DocumentByConnectorCredentialPair).filter(
|
||||
DocumentByConnectorCredentialPair.id.in_(doc_ids)
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(DBDocument).filter(DBDocument.id.in_(doc_ids)).delete(
|
||||
synchronize_session="fetch"
|
||||
)
|
||||
|
||||
db_session.query(ConnectorCredentialPair).filter(
|
||||
ConnectorCredentialPair.id == pair.id
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(Connector).filter(Connector.id == connector_id).delete(
|
||||
synchronize_session="fetch"
|
||||
)
|
||||
db_session.query(Credential).filter(Credential.id == credential_id).delete(
|
||||
synchronize_session="fetch"
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def attempt_metadata(cc_pair: ConnectorCredentialPair) -> IndexAttemptMetadata:
|
||||
return IndexAttemptMetadata(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
attempt_id=None,
|
||||
request_id="test-request",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNewDocuments:
|
||||
"""First-time inserts — no previous file_id to reconcile against."""
|
||||
|
||||
def test_new_doc_without_file_id(
|
||||
self,
|
||||
db_session: Session,
|
||||
attempt_metadata: IndexAttemptMetadata,
|
||||
) -> None:
|
||||
doc = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=None)
|
||||
|
||||
index_doc_batch_prepare(
|
||||
documents=[doc],
|
||||
index_attempt_metadata=attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
row = _get_doc_row(db_session, doc.id)
|
||||
assert row is not None
|
||||
assert row.file_id is None
|
||||
|
||||
def test_new_doc_with_staged_file_id_promotes_to_connector(
|
||||
self,
|
||||
db_session: Session,
|
||||
attempt_metadata: IndexAttemptMetadata,
|
||||
) -> None:
|
||||
file_id = _stage_file()
|
||||
doc = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=file_id)
|
||||
|
||||
index_doc_batch_prepare(
|
||||
documents=[doc],
|
||||
index_attempt_metadata=attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
row = _get_doc_row(db_session, doc.id)
|
||||
assert row is not None and row.file_id == file_id
|
||||
|
||||
record = _get_filerecord(db_session, file_id)
|
||||
assert record is not None
|
||||
assert record.file_origin == FileOrigin.CONNECTOR
|
||||
|
||||
|
||||
class TestExistingDocuments:
|
||||
"""Re-index path — a `document` row already exists with some file_id."""
|
||||
|
||||
def test_unchanged_file_id_is_noop(
|
||||
self,
|
||||
db_session: Session,
|
||||
attempt_metadata: IndexAttemptMetadata,
|
||||
) -> None:
|
||||
file_id = _stage_file()
|
||||
doc = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=file_id)
|
||||
|
||||
# First pass: inserts the row + promotes the file.
|
||||
index_doc_batch_prepare(
|
||||
documents=[doc],
|
||||
index_attempt_metadata=attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# Second pass with the same file_id — should not delete or re-promote.
|
||||
index_doc_batch_prepare(
|
||||
documents=[doc],
|
||||
index_attempt_metadata=attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
record = _get_filerecord(db_session, file_id)
|
||||
assert record is not None
|
||||
assert record.file_origin == FileOrigin.CONNECTOR
|
||||
|
||||
row = _get_doc_row(db_session, doc.id)
|
||||
assert row is not None and row.file_id == file_id
|
||||
|
||||
def test_swapping_file_id_promotes_new_and_deletes_old(
|
||||
self,
|
||||
db_session: Session,
|
||||
attempt_metadata: IndexAttemptMetadata,
|
||||
) -> None:
|
||||
old_file_id = _stage_file(content=b"old bytes")
|
||||
doc = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=old_file_id)
|
||||
|
||||
index_doc_batch_prepare(
|
||||
documents=[doc],
|
||||
index_attempt_metadata=attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# Re-fetch produces a new staged file_id for the same doc.
|
||||
new_file_id = _stage_file(content=b"new bytes")
|
||||
doc_v2 = _make_doc(doc.id, file_id=new_file_id)
|
||||
|
||||
index_doc_batch_prepare(
|
||||
documents=[doc_v2],
|
||||
index_attempt_metadata=attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
row = _get_doc_row(db_session, doc.id)
|
||||
assert row is not None and row.file_id == new_file_id
|
||||
|
||||
new_record = _get_filerecord(db_session, new_file_id)
|
||||
assert new_record is not None
|
||||
assert new_record.file_origin == FileOrigin.CONNECTOR
|
||||
|
||||
# Old file_record + S3 object are gone.
|
||||
assert _get_filerecord(db_session, old_file_id) is None
|
||||
|
||||
def test_clearing_file_id_deletes_old_and_nulls_column(
|
||||
self,
|
||||
db_session: Session,
|
||||
attempt_metadata: IndexAttemptMetadata,
|
||||
) -> None:
|
||||
old_file_id = _stage_file()
|
||||
doc = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=old_file_id)
|
||||
|
||||
index_doc_batch_prepare(
|
||||
documents=[doc],
|
||||
index_attempt_metadata=attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# Connector opts out on next run — yields the doc without a file_id.
|
||||
doc_v2 = _make_doc(doc.id, file_id=None)
|
||||
|
||||
index_doc_batch_prepare(
|
||||
documents=[doc_v2],
|
||||
index_attempt_metadata=attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
row = _get_doc_row(db_session, doc.id)
|
||||
assert row is not None and row.file_id is None
|
||||
assert _get_filerecord(db_session, old_file_id) is None
|
||||
|
||||
|
||||
class TestBatchHandling:
|
||||
"""Mixed batches — multiple docs at different lifecycle states in one call."""
|
||||
|
||||
def test_mixed_batch_each_doc_handled_independently(
|
||||
self,
|
||||
db_session: Session,
|
||||
attempt_metadata: IndexAttemptMetadata,
|
||||
) -> None:
|
||||
# Pre-seed an existing doc with a file_id we'll swap.
|
||||
existing_old_id = _stage_file(content=b"existing-old")
|
||||
existing_doc = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=existing_old_id)
|
||||
index_doc_batch_prepare(
|
||||
documents=[existing_doc],
|
||||
index_attempt_metadata=attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# Now: swap the existing one, add a brand-new doc with file_id, and a
|
||||
# brand-new doc without file_id.
|
||||
swap_new_id = _stage_file(content=b"existing-new")
|
||||
new_with_file_id = _stage_file(content=b"new-with-file")
|
||||
existing_v2 = _make_doc(existing_doc.id, file_id=swap_new_id)
|
||||
new_with = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=new_with_file_id)
|
||||
new_without = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=None)
|
||||
|
||||
index_doc_batch_prepare(
|
||||
documents=[existing_v2, new_with, new_without],
|
||||
index_attempt_metadata=attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# Existing doc was swapped: old file gone, new file promoted.
|
||||
existing_row = _get_doc_row(db_session, existing_doc.id)
|
||||
assert existing_row is not None and existing_row.file_id == swap_new_id
|
||||
assert _get_filerecord(db_session, existing_old_id) is None
|
||||
swap_record = _get_filerecord(db_session, swap_new_id)
|
||||
assert swap_record is not None
|
||||
assert swap_record.file_origin == FileOrigin.CONNECTOR
|
||||
|
||||
# New doc with file_id: row exists, file promoted.
|
||||
new_with_row = _get_doc_row(db_session, new_with.id)
|
||||
assert new_with_row is not None and new_with_row.file_id == new_with_file_id
|
||||
new_with_record = _get_filerecord(db_session, new_with_file_id)
|
||||
assert new_with_record is not None
|
||||
assert new_with_record.file_origin == FileOrigin.CONNECTOR
|
||||
|
||||
# New doc without file_id: row exists, no file_record involvement.
|
||||
new_without_row = _get_doc_row(db_session, new_without.id)
|
||||
assert new_without_row is not None and new_without_row.file_id is None
|
||||
@@ -446,107 +446,10 @@ class TestOpenSearchClient:
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
def test_update_settings(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests updating index settings on an existing index."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
# Assert that the current number of replicas is not the desired test
|
||||
# number we are updating to.
|
||||
test_num_replicas = 0
|
||||
current_settings = test_client.get_settings()
|
||||
assert current_settings["index"]["number_of_replicas"] != f"{test_num_replicas}"
|
||||
|
||||
# Under test.
|
||||
# Should not raise. number_of_replicas is a dynamic setting that can be
|
||||
# changed without closing the index.
|
||||
test_client.update_settings(
|
||||
settings={"index": {"number_of_replicas": test_num_replicas}}
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
current_settings = test_client.get_settings()
|
||||
assert current_settings["index"]["number_of_replicas"] == f"{test_num_replicas}"
|
||||
|
||||
def test_update_settings_on_nonexistent_index(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests updating settings on a nonexistent index raises an error."""
|
||||
"""Tests that update_settings raises NotImplementedError."""
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(Exception, match="index_not_found_exception|404"):
|
||||
test_client.update_settings(settings={"index": {"number_of_replicas": 0}})
|
||||
|
||||
def test_get_settings(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests getting index settings."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
current_settings = test_client.get_settings()
|
||||
|
||||
# Postcondition.
|
||||
assert "index" in current_settings
|
||||
# These are always present for any index.
|
||||
assert "number_of_shards" in current_settings["index"]
|
||||
assert "number_of_replicas" in current_settings["index"]
|
||||
assert current_settings["index"]["provided_name"] == test_client._index_name
|
||||
|
||||
def test_get_settings_on_nonexistent_index(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests getting settings on a nonexistent index raises an error."""
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(Exception, match="index_not_found_exception|404"):
|
||||
test_client.get_settings()
|
||||
|
||||
def test_close_and_open_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests closing and reopening an index."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
# Closing should not raise.
|
||||
test_client.close_index()
|
||||
|
||||
# Postcondition.
|
||||
# Searches on a closed index should fail.
|
||||
with pytest.raises(Exception, match="index_closed_exception|closed"):
|
||||
test_client.search_for_document_ids(
|
||||
body={"_source": False, "query": {"match_all": {}}}
|
||||
)
|
||||
|
||||
# Under test.
|
||||
# Reopening should not raise.
|
||||
test_client.open_index()
|
||||
|
||||
# Postcondition.
|
||||
# Searches should work again after reopening.
|
||||
result = test_client.search_for_document_ids(
|
||||
body={"_source": False, "query": {"match_all": {}}}
|
||||
)
|
||||
assert result == []
|
||||
|
||||
def test_close_nonexistent_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests closing a nonexistent index raises an error."""
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(Exception, match="index_not_found_exception|404"):
|
||||
test_client.close_index()
|
||||
|
||||
def test_open_nonexistent_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests opening a nonexistent index raises an error."""
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(Exception, match="index_not_found_exception|404"):
|
||||
test_client.open_index()
|
||||
with pytest.raises(NotImplementedError):
|
||||
test_client.update_settings(settings={})
|
||||
|
||||
def test_create_and_delete_search_pipeline(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
|
||||
@@ -1,210 +0,0 @@
|
||||
"""Tests for `cloud_beat_task_generator`'s tenant work-gating logic.
|
||||
|
||||
Exercises the gate-read path end-to-end against real Redis. The Celery
|
||||
`.app.send_task` is mocked so we can count dispatches without actually
|
||||
sending messages.
|
||||
|
||||
Requires a running Redis instance. Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest \
|
||||
backend/tests/external_dependency_unit/tenant_work_gating/test_gate_generator.py
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.background.celery.tasks.cloud import tasks as cloud_tasks
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.redis import redis_tenant_work_gating as twg
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_tenant_work_gating import _SET_KEY
|
||||
from onyx.redis.redis_tenant_work_gating import mark_tenant_active
|
||||
|
||||
|
||||
_TENANT_A = "tenant_aaaa0000-0000-0000-0000-000000000001"
|
||||
_TENANT_B = "tenant_bbbb0000-0000-0000-0000-000000000002"
|
||||
_TENANT_C = "tenant_cccc0000-0000-0000-0000-000000000003"
|
||||
_ALL_TEST_TENANTS = [_TENANT_A, _TENANT_B, _TENANT_C]
|
||||
_FANOUT_KEY_PREFIX = cloud_tasks._FULL_FANOUT_TIMESTAMP_KEY_PREFIX
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _multi_tenant_true() -> Generator[None, None, None]:
|
||||
with patch.object(twg, "MULTI_TENANT", True):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_redis() -> Generator[None, None, None]:
|
||||
"""Clear the active set AND the per-task full-fanout timestamp so each
|
||||
test starts fresh."""
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
r.delete(_SET_KEY)
|
||||
r.delete(f"{_FANOUT_KEY_PREFIX}:test_task")
|
||||
r.delete("runtime:tenant_work_gating:enabled")
|
||||
r.delete("runtime:tenant_work_gating:enforce")
|
||||
yield
|
||||
r.delete(_SET_KEY)
|
||||
r.delete(f"{_FANOUT_KEY_PREFIX}:test_task")
|
||||
r.delete("runtime:tenant_work_gating:enabled")
|
||||
r.delete("runtime:tenant_work_gating:enforce")
|
||||
|
||||
|
||||
def _invoke_generator(
|
||||
*,
|
||||
work_gated: bool,
|
||||
enabled: bool,
|
||||
enforce: bool,
|
||||
tenant_ids: list[str],
|
||||
full_fanout_interval_seconds: int = 1200,
|
||||
ttl_seconds: int = 1800,
|
||||
) -> MagicMock:
|
||||
"""Helper: call the generator with runtime flags fixed and the Celery
|
||||
app mocked. Returns the mock so callers can assert on send_task calls."""
|
||||
mock_app = MagicMock()
|
||||
# The task binds `self` = the task itself when invoked via `.run()`;
|
||||
# patch its `.app` so `self.app.send_task` routes to our mock.
|
||||
with (
|
||||
patch.object(cloud_tasks.cloud_beat_task_generator, "app", mock_app),
|
||||
patch.object(cloud_tasks, "get_all_tenant_ids", return_value=list(tenant_ids)),
|
||||
patch.object(cloud_tasks, "get_gated_tenants", return_value=set()),
|
||||
patch(
|
||||
"onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_enabled",
|
||||
return_value=enabled,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_enforce",
|
||||
return_value=enforce,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_full_fanout_interval_seconds",
|
||||
return_value=full_fanout_interval_seconds,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_ttl_seconds",
|
||||
return_value=ttl_seconds,
|
||||
),
|
||||
):
|
||||
cloud_tasks.cloud_beat_task_generator.run(
|
||||
task_name="test_task",
|
||||
work_gated=work_gated,
|
||||
)
|
||||
return mock_app
|
||||
|
||||
|
||||
def _dispatched_tenants(mock_app: MagicMock) -> list[str]:
|
||||
"""Pull tenant_ids out of each send_task call for assertion."""
|
||||
return [c.kwargs["kwargs"]["tenant_id"] for c in mock_app.send_task.call_args_list]
|
||||
|
||||
|
||||
def _seed_recent_full_fanout_timestamp() -> None:
|
||||
"""Pre-seed the per-task timestamp so the interval-elapsed branch
|
||||
reports False, i.e. the gate enforces normally instead of going into
|
||||
full-fanout on first invocation."""
|
||||
import time as _t
|
||||
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
r.set(f"{_FANOUT_KEY_PREFIX}:test_task", str(int(_t.time() * 1000)))
|
||||
|
||||
|
||||
def test_enforce_skips_unmarked_tenants() -> None:
|
||||
"""With enable+enforce on (interval NOT elapsed), only tenants in the
|
||||
active set get dispatched."""
|
||||
mark_tenant_active(_TENANT_A)
|
||||
_seed_recent_full_fanout_timestamp()
|
||||
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=True,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
full_fanout_interval_seconds=3600,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert dispatched == [_TENANT_A]
|
||||
|
||||
|
||||
def test_shadow_mode_dispatches_all_tenants() -> None:
|
||||
"""enabled=True, enforce=False: gate computes skip but still dispatches."""
|
||||
mark_tenant_active(_TENANT_A)
|
||||
_seed_recent_full_fanout_timestamp()
|
||||
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=True,
|
||||
enforce=False,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
full_fanout_interval_seconds=3600,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
|
||||
|
||||
def test_full_fanout_cycle_dispatches_all_tenants() -> None:
|
||||
"""First invocation (no prior timestamp → interval considered elapsed)
|
||||
counts as full-fanout; every tenant gets dispatched even under enforce."""
|
||||
mark_tenant_active(_TENANT_A)
|
||||
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=True,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
|
||||
|
||||
def test_redis_unavailable_fails_open() -> None:
|
||||
"""When `get_active_tenants` returns None (simulated Redis outage) the
|
||||
gate treats the invocation as full-fanout and dispatches everyone —
|
||||
even when the interval hasn't elapsed and enforce is on."""
|
||||
mark_tenant_active(_TENANT_A)
|
||||
_seed_recent_full_fanout_timestamp()
|
||||
|
||||
with patch.object(cloud_tasks, "get_active_tenants", return_value=None):
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=True,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
full_fanout_interval_seconds=3600,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
|
||||
|
||||
def test_work_gated_false_bypasses_gate_entirely() -> None:
|
||||
"""Beat templates that don't opt in (`work_gated=False`) never consult
|
||||
the set — no matter the flag state."""
|
||||
# Even with enforce on and nothing in the set, all tenants dispatch.
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=False,
|
||||
enabled=True,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
|
||||
|
||||
def test_gate_disabled_dispatches_everyone_regardless_of_enforce() -> None:
|
||||
"""enabled=False means the gate isn't computed — dispatch is unchanged."""
|
||||
# Intentionally don't add anyone to the set.
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=False,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
@@ -247,7 +247,7 @@ class DATestSettings(BaseModel):
|
||||
gpu_enabled: bool | None = None
|
||||
product_gating: DATestGatingType = DATestGatingType.NONE
|
||||
anonymous_user_enabled: bool | None = None
|
||||
image_extraction_and_analysis_enabled: bool | None = True
|
||||
image_extraction_and_analysis_enabled: bool | None = False
|
||||
search_time_image_analysis_enabled: bool | None = False
|
||||
|
||||
|
||||
|
||||
@@ -16,14 +16,12 @@ from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.types import CallToolResult
|
||||
from mcp.types import TextContent
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from onyx.db.enums import AccessType
|
||||
from tests.integration.common_utils.constants import MCP_SERVER_URL
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.document_set import DocumentSetManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.pat import PATManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
@@ -36,7 +34,6 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
# Constants
|
||||
MCP_SEARCH_TOOL = "search_indexed_documents"
|
||||
INDEXED_SOURCES_RESOURCE_URI = "resource://indexed_sources"
|
||||
DOCUMENT_SETS_RESOURCE_URI = "resource://document_sets"
|
||||
DEFAULT_SEARCH_LIMIT = 5
|
||||
STREAMABLE_HTTP_URL = f"{MCP_SERVER_URL.rstrip('/')}/?transportType=streamable-http"
|
||||
|
||||
@@ -76,22 +73,19 @@ def _extract_tool_payload(result: CallToolResult) -> dict[str, Any]:
|
||||
|
||||
|
||||
def _call_search_tool(
|
||||
headers: dict[str, str],
|
||||
query: str,
|
||||
limit: int = DEFAULT_SEARCH_LIMIT,
|
||||
document_set_names: list[str] | None = None,
|
||||
headers: dict[str, str], query: str, limit: int = DEFAULT_SEARCH_LIMIT
|
||||
) -> CallToolResult:
|
||||
"""Call the search_indexed_documents tool via MCP."""
|
||||
|
||||
async def _action(session: ClientSession) -> CallToolResult:
|
||||
await session.initialize()
|
||||
arguments: dict[str, Any] = {
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
}
|
||||
if document_set_names is not None:
|
||||
arguments["document_set_names"] = document_set_names
|
||||
return await session.call_tool(MCP_SEARCH_TOOL, arguments)
|
||||
return await session.call_tool(
|
||||
MCP_SEARCH_TOOL,
|
||||
{
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
|
||||
return _run_with_mcp_session(headers, _action)
|
||||
|
||||
@@ -244,106 +238,3 @@ def test_mcp_search_respects_acl_filters(
|
||||
blocked_payload = _extract_tool_payload(blocked_result)
|
||||
assert blocked_payload["total_results"] == 0
|
||||
assert blocked_payload["documents"] == []
|
||||
|
||||
|
||||
def test_mcp_search_filters_by_document_set(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Passing document_set_names should scope results to the named set."""
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
api_key = APIKeyManager.create(user_performing_action=admin_user)
|
||||
cc_pair_in_set = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
cc_pair_out_of_set = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
shared_phrase = "document-set-filter-shared-phrase"
|
||||
in_set_content = f"{shared_phrase} inside curated set"
|
||||
out_of_set_content = f"{shared_phrase} outside curated set"
|
||||
|
||||
_seed_document_and_wait_for_indexing(
|
||||
cc_pair=cc_pair_in_set,
|
||||
content=in_set_content,
|
||||
api_key=api_key,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
_seed_document_and_wait_for_indexing(
|
||||
cc_pair=cc_pair_out_of_set,
|
||||
content=out_of_set_content,
|
||||
api_key=api_key,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
doc_set = DocumentSetManager.create(
|
||||
cc_pair_ids=[cc_pair_in_set.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
DocumentSetManager.wait_for_sync(
|
||||
user_performing_action=admin_user,
|
||||
document_sets_to_check=[doc_set],
|
||||
)
|
||||
|
||||
headers = _auth_headers(admin_user, name="mcp-doc-set-filter")
|
||||
|
||||
# The document_sets resource should surface the newly created set so MCP
|
||||
# clients can discover which values to pass to document_set_names.
|
||||
async def _list_resources(session: ClientSession) -> Any:
|
||||
await session.initialize()
|
||||
resources = await session.list_resources()
|
||||
contents = await session.read_resource(AnyUrl(DOCUMENT_SETS_RESOURCE_URI))
|
||||
return resources, contents
|
||||
|
||||
resources_result, doc_sets_contents = _run_with_mcp_session(
|
||||
headers, _list_resources
|
||||
)
|
||||
resource_uris = {str(resource.uri) for resource in resources_result.resources}
|
||||
assert DOCUMENT_SETS_RESOURCE_URI in resource_uris
|
||||
doc_sets_payload = json.loads(doc_sets_contents.contents[0].text)
|
||||
exposed_names = {entry["name"] for entry in doc_sets_payload}
|
||||
assert doc_set.name in exposed_names
|
||||
|
||||
# Without the filter both documents are visible.
|
||||
unfiltered_payload = _extract_tool_payload(
|
||||
_call_search_tool(headers, shared_phrase, limit=10)
|
||||
)
|
||||
unfiltered_contents = [
|
||||
doc.get("content") or "" for doc in unfiltered_payload["documents"]
|
||||
]
|
||||
assert any(in_set_content in content for content in unfiltered_contents)
|
||||
assert any(out_of_set_content in content for content in unfiltered_contents)
|
||||
|
||||
# With the document set filter only the in-set document is returned.
|
||||
filtered_payload = _extract_tool_payload(
|
||||
_call_search_tool(
|
||||
headers,
|
||||
shared_phrase,
|
||||
limit=10,
|
||||
document_set_names=[doc_set.name],
|
||||
)
|
||||
)
|
||||
filtered_contents = [
|
||||
doc.get("content") or "" for doc in filtered_payload["documents"]
|
||||
]
|
||||
assert filtered_payload["total_results"] >= 1
|
||||
assert any(in_set_content in content for content in filtered_contents)
|
||||
assert all(out_of_set_content not in content for content in filtered_contents)
|
||||
|
||||
# An empty document_set_names should behave like "no filter" (normalized
|
||||
# to None), not "match zero sets".
|
||||
empty_list_payload = _extract_tool_payload(
|
||||
_call_search_tool(
|
||||
headers,
|
||||
shared_phrase,
|
||||
limit=10,
|
||||
document_set_names=[],
|
||||
)
|
||||
)
|
||||
empty_list_contents = [
|
||||
doc.get("content") or "" for doc in empty_list_payload["documents"]
|
||||
]
|
||||
assert any(in_set_content in content for content in empty_list_contents)
|
||||
assert any(out_of_set_content in content for content in empty_list_contents)
|
||||
|
||||
@@ -8,10 +8,8 @@ import io
|
||||
from typing import NamedTuple
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.file import FileManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
@@ -121,31 +119,3 @@ def test_public_assistant_with_user_files(
|
||||
assert (
|
||||
len(chat_history) >= 2
|
||||
), "Expected at least 2 messages (user message and assistant response)"
|
||||
|
||||
|
||||
def test_cannot_download_other_users_file_via_chat_file_endpoint(
|
||||
user_file_setup: UserFileTestSetup,
|
||||
) -> None:
|
||||
storage_file_id = user_file_setup.user1_file_descriptor["id"]
|
||||
user_file_id = user_file_setup.user1_file_id
|
||||
|
||||
owner_response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/file/{storage_file_id}",
|
||||
headers=user_file_setup.user1_file_owner.headers,
|
||||
)
|
||||
assert owner_response.status_code == 200
|
||||
assert owner_response.content, "Owner should receive the file contents"
|
||||
|
||||
for file_id in (storage_file_id, user_file_id):
|
||||
user2_response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/file/{file_id}",
|
||||
headers=user_file_setup.user2_non_owner.headers,
|
||||
)
|
||||
assert user2_response.status_code in (
|
||||
403,
|
||||
404,
|
||||
), (
|
||||
f"Expected access denied for non-owner, got {user2_response.status_code} "
|
||||
f"when fetching file_id={file_id}"
|
||||
)
|
||||
assert user2_response.content != owner_response.content
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Unit tests for the require-score check in verify_captcha_token."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth import captcha as captcha_module
|
||||
from onyx.auth.captcha import CaptchaVerificationError
|
||||
from onyx.auth.captcha import verify_captcha_token
|
||||
|
||||
|
||||
def _fake_httpx_client_returning(payload: dict) -> MagicMock:
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status = MagicMock()
|
||||
resp.json = MagicMock(return_value=payload)
|
||||
client = MagicMock()
|
||||
client.post = AsyncMock(return_value=resp)
|
||||
client.__aenter__ = AsyncMock(return_value=client)
|
||||
client.__aexit__ = AsyncMock(return_value=None)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_when_score_missing() -> None:
|
||||
"""Siteverify response with no score field is rejected outright —
|
||||
closes the accidental 'test secret in prod' bypass path."""
|
||||
client = _fake_httpx_client_returning(
|
||||
{"success": True, "hostname": "testkey.google.com"}
|
||||
)
|
||||
with (
|
||||
patch.object(captcha_module, "is_captcha_enabled", return_value=True),
|
||||
patch.object(captcha_module.httpx, "AsyncClient", return_value=client),
|
||||
):
|
||||
with pytest.raises(CaptchaVerificationError, match="missing score"):
|
||||
await verify_captcha_token("test-token", expected_action="signup")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accepts_when_score_present_and_above_threshold() -> None:
|
||||
"""Sanity check the happy path still works with the tighter score rule."""
|
||||
client = _fake_httpx_client_returning(
|
||||
{
|
||||
"success": True,
|
||||
"score": 0.9,
|
||||
"action": "signup",
|
||||
"hostname": "cloud.onyx.app",
|
||||
}
|
||||
)
|
||||
with (
|
||||
patch.object(captcha_module, "is_captcha_enabled", return_value=True),
|
||||
patch.object(captcha_module.httpx, "AsyncClient", return_value=client),
|
||||
):
|
||||
# Should not raise.
|
||||
await verify_captcha_token("fresh-token", expected_action="signup")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_when_score_below_threshold() -> None:
|
||||
"""A score present but below threshold still rejects (existing behavior,
|
||||
guarding against regression from this PR's restructure)."""
|
||||
client = _fake_httpx_client_returning(
|
||||
{
|
||||
"success": True,
|
||||
"score": 0.1,
|
||||
"action": "signup",
|
||||
"hostname": "cloud.onyx.app",
|
||||
}
|
||||
)
|
||||
with (
|
||||
patch.object(captcha_module, "is_captcha_enabled", return_value=True),
|
||||
patch.object(captcha_module.httpx, "AsyncClient", return_value=client),
|
||||
):
|
||||
with pytest.raises(
|
||||
CaptchaVerificationError, match="suspicious activity detected"
|
||||
):
|
||||
await verify_captcha_token("low-score-token", expected_action="signup")
|
||||
@@ -69,9 +69,6 @@ class TestGongConnectorCheckpoint:
|
||||
workspace_ids=["ws1", None],
|
||||
workspace_index=1,
|
||||
cursor="abc123",
|
||||
pending_transcripts={"call1": _make_transcript("call1")},
|
||||
pending_call_details_attempts=2,
|
||||
pending_retry_after=1234567890.5,
|
||||
)
|
||||
json_str = original.model_dump_json()
|
||||
restored = connector.validate_checkpoint_json(json_str)
|
||||
@@ -234,11 +231,7 @@ class TestGongConnectorCheckpoint:
|
||||
mock_request: MagicMock,
|
||||
connector: GongConnector,
|
||||
) -> None:
|
||||
"""Missing call details persist across checkpoint invocations and
|
||||
eventually yield ConnectorFailure once MAX_CALL_DETAILS_ATTEMPTS is hit.
|
||||
No in-call sleep — retries happen on subsequent invocations, gated by
|
||||
the wall-clock retry-after deadline on the checkpoint.
|
||||
"""
|
||||
"""When call details are missing after retries, yield ConnectorFailure."""
|
||||
transcript_response = MagicMock()
|
||||
transcript_response.status_code = 200
|
||||
transcript_response.json.return_value = {
|
||||
@@ -264,42 +257,23 @@ class TestGongConnectorCheckpoint:
|
||||
failures: list[ConnectorFailure] = []
|
||||
docs: list[Document] = []
|
||||
|
||||
# Jump the clock past any retry deadline on each invocation so we
|
||||
# exercise the retry path without real sleeping. The test for the
|
||||
# backoff-gate itself lives in test_backoff_gate_prevents_retry_too_soon.
|
||||
fake_now = [1_000_000.0]
|
||||
|
||||
def _advance_clock() -> float:
|
||||
fake_now[0] += 10_000.0
|
||||
return fake_now[0]
|
||||
|
||||
invocation_cap = GongConnector.MAX_CALL_DETAILS_ATTEMPTS + 5
|
||||
with patch(
|
||||
"onyx.connectors.gong.connector.time.time", side_effect=_advance_clock
|
||||
):
|
||||
for _ in range(invocation_cap):
|
||||
if not checkpoint.has_more:
|
||||
break
|
||||
generator = connector.load_from_checkpoint(0, fake_now[0], checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(generator)
|
||||
if isinstance(item, ConnectorFailure):
|
||||
failures.append(item)
|
||||
elif isinstance(item, Document):
|
||||
docs.append(item)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
with patch("onyx.connectors.gong.connector.time.sleep"):
|
||||
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(generator)
|
||||
if isinstance(item, ConnectorFailure):
|
||||
failures.append(item)
|
||||
elif isinstance(item, Document):
|
||||
docs.append(item)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
|
||||
assert len(docs) == 0
|
||||
assert len(failures) == 1
|
||||
assert failures[0].failed_document is not None
|
||||
assert failures[0].failed_document.document_id == "call1"
|
||||
assert checkpoint.has_more is False
|
||||
assert checkpoint.pending_transcripts == {}
|
||||
assert checkpoint.pending_call_details_attempts == 0
|
||||
assert checkpoint.pending_retry_after is None
|
||||
assert mock_request.call_count == 1 + GongConnector.MAX_CALL_DETAILS_ATTEMPTS
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_multi_workspace_iteration(
|
||||
@@ -407,14 +381,12 @@ class TestGongConnectorCheckpoint:
|
||||
assert checkpoint.workspace_index == 1
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_partial_details_defers_and_resolves_next_invocation(
|
||||
def test_retry_only_fetches_missing_ids(
|
||||
self,
|
||||
mock_request: MagicMock,
|
||||
connector: GongConnector,
|
||||
) -> None:
|
||||
"""A transcript whose call details are missing gets stashed into
|
||||
pending_transcripts and resolves on a later checkpoint invocation.
|
||||
Resolved docs are yielded in the order they become available."""
|
||||
"""Retry for missing call details should only re-request the missing IDs."""
|
||||
transcript_response = MagicMock()
|
||||
transcript_response.status_code = 200
|
||||
transcript_response.json.return_value = {
|
||||
@@ -432,7 +404,7 @@ class TestGongConnectorCheckpoint:
|
||||
"calls": [_make_call_detail("call1", "Call One")]
|
||||
}
|
||||
|
||||
# Second fetch (next invocation): returns call2
|
||||
# Second fetch (retry): returns call2
|
||||
missing_details = MagicMock()
|
||||
missing_details.status_code = 200
|
||||
missing_details.json.return_value = {
|
||||
@@ -452,48 +424,19 @@ class TestGongConnectorCheckpoint:
|
||||
)
|
||||
|
||||
docs: list[Document] = []
|
||||
|
||||
fake_now = [1_000_000.0]
|
||||
|
||||
def _advance_clock() -> float:
|
||||
fake_now[0] += 10_000.0
|
||||
return fake_now[0]
|
||||
|
||||
with patch(
|
||||
"onyx.connectors.gong.connector.time.time", side_effect=_advance_clock
|
||||
):
|
||||
# Invocation 1: fetches page + details, yields call1, stashes call2
|
||||
generator = connector.load_from_checkpoint(0, fake_now[0], checkpoint)
|
||||
with patch("onyx.connectors.gong.connector.time.sleep"):
|
||||
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(generator)
|
||||
if isinstance(item, Document):
|
||||
docs.append(item)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].semantic_identifier == "Call One"
|
||||
assert "call2" in checkpoint.pending_transcripts
|
||||
assert checkpoint.pending_call_details_attempts == 1
|
||||
assert checkpoint.pending_retry_after is not None
|
||||
assert checkpoint.has_more is True
|
||||
|
||||
# Invocation 2: retries missing (only call2), yields it, clears pending
|
||||
generator = connector.load_from_checkpoint(0, fake_now[0], checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(generator)
|
||||
if isinstance(item, Document):
|
||||
docs.append(item)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
assert len(docs) == 2
|
||||
assert docs[0].semantic_identifier == "Call One"
|
||||
assert docs[1].semantic_identifier == "Call Two"
|
||||
assert checkpoint.pending_transcripts == {}
|
||||
assert checkpoint.pending_call_details_attempts == 0
|
||||
assert checkpoint.pending_retry_after is None
|
||||
|
||||
# Verify: 3 API calls total (1 transcript + 1 full details + 1 retry for missing only)
|
||||
assert mock_request.call_count == 3
|
||||
@@ -501,107 +444,6 @@ class TestGongConnectorCheckpoint:
|
||||
retry_call_body = mock_request.call_args_list[2][1]["json"]
|
||||
assert retry_call_body["filter"]["callIds"] == ["call2"]
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_backoff_gate_prevents_retry_too_soon(
|
||||
self,
|
||||
mock_request: MagicMock,
|
||||
connector: GongConnector,
|
||||
) -> None:
|
||||
"""If the retry-after deadline hasn't elapsed, _resolve_pending must
|
||||
NOT issue a /v2/calls/extensive request. Prevents burning through
|
||||
MAX_CALL_DETAILS_ATTEMPTS when workers re-invoke tightly.
|
||||
"""
|
||||
pending_transcript = _make_transcript("call1")
|
||||
fixed_now = 1_000_000.0
|
||||
# Deadline is 30s in the future from fixed_now
|
||||
retry_after = fixed_now + 30
|
||||
|
||||
checkpoint = GongConnectorCheckpoint(
|
||||
has_more=True,
|
||||
workspace_ids=[None],
|
||||
workspace_index=0,
|
||||
pending_transcripts={"call1": pending_transcript},
|
||||
pending_call_details_attempts=1,
|
||||
pending_retry_after=retry_after,
|
||||
)
|
||||
|
||||
with patch("onyx.connectors.gong.connector.time.time", return_value=fixed_now):
|
||||
generator = connector.load_from_checkpoint(0, fixed_now, checkpoint)
|
||||
try:
|
||||
while True:
|
||||
next(generator)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
|
||||
# No API calls should have been made — we were inside the backoff window
|
||||
mock_request.assert_not_called()
|
||||
# Pending state preserved for later retry
|
||||
assert "call1" in checkpoint.pending_transcripts
|
||||
assert checkpoint.pending_call_details_attempts == 1
|
||||
assert checkpoint.pending_retry_after == retry_after
|
||||
assert checkpoint.has_more is True
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_pending_retry_does_not_block_on_time_sleep(
|
||||
self,
|
||||
mock_request: MagicMock,
|
||||
connector: GongConnector,
|
||||
) -> None:
|
||||
"""Pending-transcript retry must never call time.sleep() with a
|
||||
non-trivial delay — spacing between retries is enforced via the
|
||||
wall-clock retry-after deadline stored on the checkpoint, not by
|
||||
blocking inside load_from_checkpoint.
|
||||
"""
|
||||
transcript_response = MagicMock()
|
||||
transcript_response.status_code = 200
|
||||
transcript_response.json.return_value = {
|
||||
"callTranscripts": [_make_transcript("call1")],
|
||||
"records": {},
|
||||
}
|
||||
empty_details = MagicMock()
|
||||
empty_details.status_code = 200
|
||||
empty_details.json.return_value = {"calls": []}
|
||||
|
||||
mock_request.side_effect = [transcript_response] + [
|
||||
empty_details
|
||||
] * GongConnector.MAX_CALL_DETAILS_ATTEMPTS
|
||||
|
||||
checkpoint = GongConnectorCheckpoint(
|
||||
has_more=True,
|
||||
workspace_ids=[None],
|
||||
workspace_index=0,
|
||||
)
|
||||
|
||||
fake_now = [1_000_000.0]
|
||||
|
||||
def _advance_clock() -> float:
|
||||
fake_now[0] += 10_000.0
|
||||
return fake_now[0]
|
||||
|
||||
with (
|
||||
patch("onyx.connectors.gong.connector.time.sleep") as mock_sleep,
|
||||
patch(
|
||||
"onyx.connectors.gong.connector.time.time", side_effect=_advance_clock
|
||||
),
|
||||
):
|
||||
invocation_cap = GongConnector.MAX_CALL_DETAILS_ATTEMPTS + 5
|
||||
for _ in range(invocation_cap):
|
||||
if not checkpoint.has_more:
|
||||
break
|
||||
generator = connector.load_from_checkpoint(0, fake_now[0], checkpoint)
|
||||
try:
|
||||
while True:
|
||||
next(generator)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
|
||||
# The only legitimate sleep is the sub-second throttle in
|
||||
# _throttled_request (<= MIN_REQUEST_INTERVAL). Assert we never
|
||||
# sleep for anything close to the per-retry backoff delays.
|
||||
for call in mock_sleep.call_args_list:
|
||||
delay_arg = call.args[0] if call.args else 0
|
||||
assert delay_arg <= GongConnector.MIN_REQUEST_INTERVAL
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_expired_cursor_restarts_workspace(
|
||||
self,
|
||||
|
||||
@@ -287,140 +287,3 @@ class TestFailedFolderIdsByEmail:
|
||||
)
|
||||
|
||||
assert len(failed_map) == 0
|
||||
|
||||
|
||||
class TestOrphanedPathBackfill:
|
||||
def _make_failed_map(
|
||||
self, entries: dict[str, set[str]]
|
||||
) -> ThreadSafeDict[str, ThreadSafeSet[str]]:
|
||||
return ThreadSafeDict({k: ThreadSafeSet(v) for k, v in entries.items()})
|
||||
|
||||
def _make_file(self, parent_id: str) -> MagicMock:
|
||||
file = MagicMock()
|
||||
file.user_email = "retriever@example.com"
|
||||
file.drive_file = {"parents": [parent_id]}
|
||||
return file
|
||||
|
||||
def test_backfills_intermediate_folders_into_failed_map(self) -> None:
|
||||
"""When a walk dead-ends at a confirmed orphan, all intermediate folder
|
||||
IDs must be added to failed_folder_ids_by_email for both emails so
|
||||
future files short-circuit via _get_folder_metadata's cache check."""
|
||||
connector = _make_connector()
|
||||
|
||||
# Chain: folderA -> folderB -> folderC (confirmed orphan)
|
||||
failed_map = self._make_failed_map(
|
||||
{
|
||||
"retriever@example.com": {"folderC"},
|
||||
"admin@example.com": {"folderC"},
|
||||
}
|
||||
)
|
||||
|
||||
folder_a = {"id": "folderA", "name": "A", "parents": ["folderB"]}
|
||||
folder_b = {"id": "folderB", "name": "B", "parents": ["folderC"]}
|
||||
|
||||
def mock_get_folder(
|
||||
_service: MagicMock, folder_id: str, _field_type: DriveFileFieldType
|
||||
) -> dict | None:
|
||||
if folder_id == "folderA":
|
||||
return folder_a
|
||||
if folder_id == "folderB":
|
||||
return folder_b
|
||||
return None
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.connectors.google_drive.connector.get_drive_service",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"onyx.connectors.google_drive.connector.get_folder_metadata",
|
||||
side_effect=mock_get_folder,
|
||||
),
|
||||
):
|
||||
connector._get_new_ancestors_for_files(
|
||||
files=[self._make_file("folderA")],
|
||||
seen_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
failed_folder_ids_by_email=failed_map,
|
||||
)
|
||||
|
||||
# Both emails confirmed folderC as orphan, so both get the backfill
|
||||
for email in ("retriever@example.com", "admin@example.com"):
|
||||
cached = failed_map.get(email, ThreadSafeSet())
|
||||
assert "folderA" in cached
|
||||
assert "folderB" in cached
|
||||
assert "folderC" in cached
|
||||
|
||||
def test_backfills_only_for_confirming_email(self) -> None:
|
||||
"""Only the email that confirmed the orphan gets the path backfilled."""
|
||||
connector = _make_connector()
|
||||
|
||||
# Only retriever confirmed folderC as orphan; admin has no entry
|
||||
failed_map = self._make_failed_map({"retriever@example.com": {"folderC"}})
|
||||
|
||||
folder_a = {"id": "folderA", "name": "A", "parents": ["folderB"]}
|
||||
folder_b = {"id": "folderB", "name": "B", "parents": ["folderC"]}
|
||||
|
||||
def mock_get_folder(
|
||||
_service: MagicMock, folder_id: str, _field_type: DriveFileFieldType
|
||||
) -> dict | None:
|
||||
if folder_id == "folderA":
|
||||
return folder_a
|
||||
if folder_id == "folderB":
|
||||
return folder_b
|
||||
return None
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.connectors.google_drive.connector.get_drive_service",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"onyx.connectors.google_drive.connector.get_folder_metadata",
|
||||
side_effect=mock_get_folder,
|
||||
),
|
||||
):
|
||||
connector._get_new_ancestors_for_files(
|
||||
files=[self._make_file("folderA")],
|
||||
seen_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
failed_folder_ids_by_email=failed_map,
|
||||
)
|
||||
|
||||
retriever_cached = failed_map.get("retriever@example.com", ThreadSafeSet())
|
||||
assert "folderA" in retriever_cached
|
||||
assert "folderB" in retriever_cached
|
||||
|
||||
# admin did not confirm the orphan — must not get the backfill
|
||||
assert failed_map.get("admin@example.com") is None
|
||||
|
||||
def test_short_circuits_on_backfilled_intermediate(self) -> None:
|
||||
"""A second file whose parent is already in failed_folder_ids_by_email
|
||||
must not trigger any folder metadata API calls."""
|
||||
connector = _make_connector()
|
||||
|
||||
# folderA already in the failed map from a previous walk
|
||||
failed_map = self._make_failed_map(
|
||||
{
|
||||
"retriever@example.com": {"folderA"},
|
||||
"admin@example.com": {"folderA"},
|
||||
}
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.connectors.google_drive.connector.get_drive_service",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"onyx.connectors.google_drive.connector.get_folder_metadata"
|
||||
) as mock_api,
|
||||
):
|
||||
connector._get_new_ancestors_for_files(
|
||||
files=[self._make_file("folderA")],
|
||||
seen_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
failed_folder_ids_by_email=failed_map,
|
||||
)
|
||||
|
||||
mock_api.assert_not_called()
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
"""Unit tests for SharepointConnector.load_credentials sp_tenant_domain resolution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
|
||||
SITE_URL = "https://mytenant.sharepoint.com/sites/MySite"
|
||||
EXPECTED_TENANT_DOMAIN = "mytenant"
|
||||
|
||||
CLIENT_SECRET_CREDS = {
|
||||
"authentication_method": "client_secret",
|
||||
"sp_client_id": "fake-client-id",
|
||||
"sp_client_secret": "fake-client-secret",
|
||||
"sp_directory_id": "fake-directory-id",
|
||||
}
|
||||
|
||||
CERTIFICATE_CREDS = {
|
||||
"authentication_method": "certificate",
|
||||
"sp_client_id": "fake-client-id",
|
||||
"sp_directory_id": "fake-directory-id",
|
||||
"sp_private_key": base64.b64encode(b"fake-pfx-data").decode(),
|
||||
"sp_certificate_password": "fake-password",
|
||||
}
|
||||
|
||||
|
||||
def _make_mock_msal() -> MagicMock:
|
||||
mock_app = MagicMock()
|
||||
mock_app.acquire_token_for_client.return_value = {"access_token": "fake-token"}
|
||||
return mock_app
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
|
||||
@patch("onyx.connectors.sharepoint.connector.GraphClient")
|
||||
def test_client_secret_with_site_pages_sets_tenant_domain(
|
||||
_mock_graph_client: MagicMock,
|
||||
mock_msal_cls: MagicMock,
|
||||
) -> None:
|
||||
"""client_secret auth + include_site_pages=True must resolve sp_tenant_domain."""
|
||||
mock_msal_cls.return_value = _make_mock_msal()
|
||||
connector = SharepointConnector(sites=[SITE_URL], include_site_pages=True)
|
||||
|
||||
connector.load_credentials(CLIENT_SECRET_CREDS)
|
||||
|
||||
assert connector.sp_tenant_domain == EXPECTED_TENANT_DOMAIN
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
|
||||
@patch("onyx.connectors.sharepoint.connector.GraphClient")
|
||||
def test_client_secret_without_site_pages_still_sets_tenant_domain(
|
||||
_mock_graph_client: MagicMock,
|
||||
mock_msal_cls: MagicMock,
|
||||
) -> None:
|
||||
"""client_secret auth + include_site_pages=False must still resolve sp_tenant_domain
|
||||
because _create_rest_client_context is also called for drive items."""
|
||||
mock_msal_cls.return_value = _make_mock_msal()
|
||||
connector = SharepointConnector(sites=[SITE_URL], include_site_pages=False)
|
||||
|
||||
connector.load_credentials(CLIENT_SECRET_CREDS)
|
||||
|
||||
assert connector.sp_tenant_domain == EXPECTED_TENANT_DOMAIN
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.load_certificate_from_pfx")
|
||||
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
|
||||
@patch("onyx.connectors.sharepoint.connector.GraphClient")
|
||||
def test_certificate_with_site_pages_sets_tenant_domain(
|
||||
_mock_graph_client: MagicMock,
|
||||
mock_msal_cls: MagicMock,
|
||||
mock_load_cert: MagicMock,
|
||||
) -> None:
|
||||
"""certificate auth + include_site_pages=True must resolve sp_tenant_domain."""
|
||||
mock_msal_cls.return_value = _make_mock_msal()
|
||||
mock_load_cert.return_value = MagicMock()
|
||||
connector = SharepointConnector(sites=[SITE_URL], include_site_pages=True)
|
||||
|
||||
connector.load_credentials(CERTIFICATE_CREDS)
|
||||
|
||||
assert connector.sp_tenant_domain == EXPECTED_TENANT_DOMAIN
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.load_certificate_from_pfx")
|
||||
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
|
||||
@patch("onyx.connectors.sharepoint.connector.GraphClient")
|
||||
def test_certificate_without_site_pages_sets_tenant_domain(
|
||||
_mock_graph_client: MagicMock,
|
||||
mock_msal_cls: MagicMock,
|
||||
mock_load_cert: MagicMock,
|
||||
) -> None:
|
||||
"""certificate auth + include_site_pages=False must still resolve sp_tenant_domain
|
||||
because _create_rest_client_context is also called for drive items."""
|
||||
mock_msal_cls.return_value = _make_mock_msal()
|
||||
mock_load_cert.return_value = MagicMock()
|
||||
connector = SharepointConnector(sites=[SITE_URL], include_site_pages=False)
|
||||
|
||||
connector.load_credentials(CERTIFICATE_CREDS)
|
||||
|
||||
assert connector.sp_tenant_domain == EXPECTED_TENANT_DOMAIN
|
||||
@@ -1,194 +0,0 @@
|
||||
"""Unit tests for SharepointConnector site-page slim resilience and
|
||||
validate_connector_settings RoleAssignments permission probe."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
|
||||
SITE_URL = "https://tenant.sharepoint.com/sites/MySite"
|
||||
|
||||
|
||||
def _make_connector() -> SharepointConnector:
|
||||
connector = SharepointConnector(sites=[SITE_URL])
|
||||
connector.msal_app = MagicMock()
|
||||
connector.sp_tenant_domain = "tenant"
|
||||
connector._credential_json = {"sp_client_id": "x", "sp_directory_id": "y"}
|
||||
connector._graph_client = MagicMock()
|
||||
return connector
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _fetch_slim_documents_from_sharepoint — site page error resilience
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector._convert_sitepage_to_slim_document")
|
||||
@patch(
|
||||
"onyx.connectors.sharepoint.connector.SharepointConnector._create_rest_client_context"
|
||||
)
|
||||
@patch("onyx.connectors.sharepoint.connector.SharepointConnector._fetch_site_pages")
|
||||
@patch("onyx.connectors.sharepoint.connector.SharepointConnector._fetch_driveitems")
|
||||
@patch("onyx.connectors.sharepoint.connector.SharepointConnector.fetch_sites")
|
||||
def test_site_page_error_does_not_crash(
|
||||
mock_fetch_sites: MagicMock,
|
||||
mock_fetch_driveitems: MagicMock,
|
||||
mock_fetch_site_pages: MagicMock,
|
||||
_mock_create_ctx: MagicMock,
|
||||
mock_convert: MagicMock,
|
||||
) -> None:
|
||||
"""A 401 (or any exception) on a site page is caught; remaining pages are processed."""
|
||||
from onyx.connectors.models import SlimDocument
|
||||
|
||||
connector = _make_connector()
|
||||
connector.include_site_documents = False
|
||||
connector.include_site_pages = True
|
||||
|
||||
site = MagicMock()
|
||||
site.url = SITE_URL
|
||||
mock_fetch_sites.return_value = [site]
|
||||
mock_fetch_driveitems.return_value = iter([])
|
||||
|
||||
page_ok = {"id": "1", "webUrl": SITE_URL + "/SitePages/Good.aspx"}
|
||||
page_bad = {"id": "2", "webUrl": SITE_URL + "/SitePages/Bad.aspx"}
|
||||
mock_fetch_site_pages.return_value = [page_bad, page_ok]
|
||||
|
||||
good_slim = SlimDocument(id="1")
|
||||
|
||||
def _convert_side_effect(
|
||||
page: dict, *_args: object, **_kwargs: object
|
||||
) -> SlimDocument: # noqa: ANN001
|
||||
if page["id"] == "2":
|
||||
from office365.runtime.client_request import ClientRequestException
|
||||
|
||||
raise ClientRequestException(MagicMock(status_code=401), None)
|
||||
return good_slim
|
||||
|
||||
mock_convert.side_effect = _convert_side_effect
|
||||
|
||||
results = [
|
||||
doc
|
||||
for batch in connector._fetch_slim_documents_from_sharepoint()
|
||||
for doc in batch
|
||||
if isinstance(doc, SlimDocument)
|
||||
]
|
||||
|
||||
# Only the good page makes it through; bad page is skipped, no exception raised.
|
||||
assert any(d.id == "1" for d in results)
|
||||
assert not any(d.id == "2" for d in results)
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector._convert_sitepage_to_slim_document")
|
||||
@patch(
|
||||
"onyx.connectors.sharepoint.connector.SharepointConnector._create_rest_client_context"
|
||||
)
|
||||
@patch("onyx.connectors.sharepoint.connector.SharepointConnector._fetch_site_pages")
|
||||
@patch("onyx.connectors.sharepoint.connector.SharepointConnector._fetch_driveitems")
|
||||
@patch("onyx.connectors.sharepoint.connector.SharepointConnector.fetch_sites")
|
||||
def test_all_site_pages_fail_does_not_crash(
|
||||
mock_fetch_sites: MagicMock,
|
||||
mock_fetch_driveitems: MagicMock,
|
||||
mock_fetch_site_pages: MagicMock,
|
||||
_mock_create_ctx: MagicMock,
|
||||
mock_convert: MagicMock,
|
||||
) -> None:
|
||||
"""When every site page fails, the generator completes without raising."""
|
||||
connector = _make_connector()
|
||||
connector.include_site_documents = False
|
||||
connector.include_site_pages = True
|
||||
|
||||
site = MagicMock()
|
||||
site.url = SITE_URL
|
||||
mock_fetch_sites.return_value = [site]
|
||||
mock_fetch_driveitems.return_value = iter([])
|
||||
mock_fetch_site_pages.return_value = [
|
||||
{"id": "1", "webUrl": SITE_URL + "/SitePages/A.aspx"},
|
||||
{"id": "2", "webUrl": SITE_URL + "/SitePages/B.aspx"},
|
||||
]
|
||||
mock_convert.side_effect = RuntimeError("context error")
|
||||
|
||||
from onyx.connectors.models import SlimDocument
|
||||
|
||||
# Should not raise; no SlimDocuments in output (only hierarchy nodes).
|
||||
slim_results = [
|
||||
doc
|
||||
for batch in connector._fetch_slim_documents_from_sharepoint()
|
||||
for doc in batch
|
||||
if isinstance(doc, SlimDocument)
|
||||
]
|
||||
assert slim_results == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_connector_settings — RoleAssignments permission probe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status_code", [401, 403])
|
||||
@patch("onyx.connectors.sharepoint.connector.requests.get")
|
||||
@patch("onyx.connectors.sharepoint.connector.validate_outbound_http_url")
|
||||
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
|
||||
def test_validate_raises_on_401_or_403(
|
||||
mock_acquire: MagicMock,
|
||||
_mock_validate_url: MagicMock,
|
||||
mock_get: MagicMock,
|
||||
status_code: int,
|
||||
) -> None:
|
||||
"""validate_connector_settings raises ConnectorValidationError when probe returns 401 or 403."""
|
||||
mock_acquire.return_value = MagicMock(accessToken="tok")
|
||||
mock_get.return_value = MagicMock(status_code=status_code)
|
||||
|
||||
connector = _make_connector()
|
||||
|
||||
with pytest.raises(ConnectorValidationError, match="Sites.FullControl.All"):
|
||||
connector.validate_connector_settings()
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.requests.get")
|
||||
@patch("onyx.connectors.sharepoint.connector.validate_outbound_http_url")
|
||||
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
|
||||
def test_validate_passes_on_200(
|
||||
mock_acquire: MagicMock,
|
||||
_mock_validate_url: MagicMock,
|
||||
mock_get: MagicMock,
|
||||
) -> None:
|
||||
"""validate_connector_settings does not raise when probe returns 200."""
|
||||
mock_acquire.return_value = MagicMock(accessToken="tok")
|
||||
mock_get.return_value = MagicMock(status_code=200)
|
||||
|
||||
connector = _make_connector()
|
||||
connector.validate_connector_settings() # should not raise
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.requests.get")
|
||||
@patch("onyx.connectors.sharepoint.connector.validate_outbound_http_url")
|
||||
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
|
||||
def test_validate_passes_on_network_error(
|
||||
mock_acquire: MagicMock,
|
||||
_mock_validate_url: MagicMock,
|
||||
mock_get: MagicMock,
|
||||
) -> None:
|
||||
"""Network errors during the probe are non-blocking (logged as warning only)."""
|
||||
mock_acquire.return_value = MagicMock(accessToken="tok")
|
||||
mock_get.side_effect = Exception("timeout")
|
||||
|
||||
connector = _make_connector()
|
||||
connector.validate_connector_settings() # should not raise
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.validate_outbound_http_url")
|
||||
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
|
||||
def test_validate_skips_probe_without_credentials(
|
||||
mock_acquire: MagicMock,
|
||||
_mock_validate_url: MagicMock,
|
||||
) -> None:
|
||||
"""Probe is skipped when credentials have not been loaded."""
|
||||
connector = SharepointConnector(sites=[SITE_URL])
|
||||
# msal_app and sp_tenant_domain are None — probe must be skipped.
|
||||
connector.validate_connector_settings() # should not raise
|
||||
mock_acquire.assert_not_called()
|
||||
@@ -1,243 +0,0 @@
|
||||
"""Unit tests for WebConnector.retrieve_all_slim_docs (slim pruning path)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.web.connector import WEB_CONNECTOR_VALID_SETTINGS
|
||||
from onyx.connectors.web.connector import WebConnector
|
||||
|
||||
BASE_URL = "http://example.com"
|
||||
|
||||
SINGLE_PAGE_HTML = (
|
||||
"<html><body><p>Content that should not appear in slim output</p></body></html>"
|
||||
)
|
||||
|
||||
RECURSIVE_ROOT_HTML = """
|
||||
<html><body>
|
||||
<a href="/page2">Page 2</a>
|
||||
<a href="/page3">Page 3</a>
|
||||
</body></html>
|
||||
"""
|
||||
|
||||
PAGE2_HTML = "<html><body><p>page 2</p></body></html>"
|
||||
PAGE3_HTML = "<html><body><p>page 3</p></body></html>"
|
||||
|
||||
|
||||
def _make_playwright_context_mock(url_to_html: dict[str, str]) -> MagicMock:
|
||||
"""Return a BrowserContext mock whose pages respond based on goto URL."""
|
||||
context = MagicMock()
|
||||
|
||||
def _new_page() -> MagicMock:
|
||||
page = MagicMock()
|
||||
visited: list[str] = []
|
||||
|
||||
def _goto(url: str, **kwargs: Any) -> MagicMock: # noqa: ARG001
|
||||
visited.append(url)
|
||||
page.url = url
|
||||
response = MagicMock()
|
||||
response.status = 200
|
||||
response.header_value.return_value = None # no cf-ray
|
||||
return response
|
||||
|
||||
def _content() -> str:
|
||||
return url_to_html.get(
|
||||
visited[-1] if visited else "", "<html><body></body></html>"
|
||||
)
|
||||
|
||||
page.goto.side_effect = _goto
|
||||
page.content.side_effect = _content
|
||||
return page
|
||||
|
||||
context.new_page.side_effect = _new_page
|
||||
return context
|
||||
|
||||
|
||||
def _make_playwright_mock() -> MagicMock:
|
||||
playwright = MagicMock()
|
||||
playwright.stop = MagicMock()
|
||||
return playwright
|
||||
|
||||
|
||||
def _make_page_mock(
|
||||
html: str, cf_ray: str | None = None, status: int = 200
|
||||
) -> MagicMock:
|
||||
"""Return a Playwright page mock with configurable status and CF header."""
|
||||
page = MagicMock()
|
||||
page.url = BASE_URL + "/"
|
||||
response = MagicMock()
|
||||
response.status = status
|
||||
response.header_value.side_effect = lambda h: cf_ray if h == "cf-ray" else None
|
||||
page.goto.return_value = response
|
||||
page.content.return_value = html
|
||||
return page
|
||||
|
||||
|
||||
@patch("onyx.connectors.web.connector.check_internet_connection")
|
||||
@patch("onyx.connectors.web.connector.requests.head")
|
||||
@patch("onyx.connectors.web.connector.start_playwright")
|
||||
def test_slim_yields_slim_documents(
|
||||
mock_start_playwright: MagicMock,
|
||||
mock_head: MagicMock,
|
||||
_mock_check: MagicMock,
|
||||
) -> None:
|
||||
"""retrieve_all_slim_docs yields SlimDocuments with the correct URL as id."""
|
||||
context = _make_playwright_context_mock({BASE_URL + "/": SINGLE_PAGE_HTML})
|
||||
mock_start_playwright.return_value = (_make_playwright_mock(), context)
|
||||
mock_head.return_value.headers = {"content-type": "text/html"}
|
||||
|
||||
connector = WebConnector(
|
||||
base_url=BASE_URL + "/",
|
||||
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
|
||||
)
|
||||
|
||||
docs = [doc for batch in connector.retrieve_all_slim_docs() for doc in batch]
|
||||
|
||||
assert len(docs) == 1
|
||||
assert isinstance(docs[0], SlimDocument)
|
||||
assert docs[0].id == BASE_URL + "/"
|
||||
|
||||
|
||||
@patch("onyx.connectors.web.connector.check_internet_connection")
|
||||
@patch("onyx.connectors.web.connector.requests.head")
|
||||
@patch("onyx.connectors.web.connector.start_playwright")
|
||||
def test_slim_skips_content_extraction(
|
||||
mock_start_playwright: MagicMock,
|
||||
mock_head: MagicMock,
|
||||
_mock_check: MagicMock,
|
||||
) -> None:
|
||||
"""web_html_cleanup is never called in slim mode."""
|
||||
context = _make_playwright_context_mock({BASE_URL + "/": SINGLE_PAGE_HTML})
|
||||
mock_start_playwright.return_value = (_make_playwright_mock(), context)
|
||||
mock_head.return_value.headers = {"content-type": "text/html"}
|
||||
|
||||
connector = WebConnector(
|
||||
base_url=BASE_URL + "/",
|
||||
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
|
||||
)
|
||||
|
||||
with patch("onyx.connectors.web.connector.web_html_cleanup") as mock_cleanup:
|
||||
list(connector.retrieve_all_slim_docs())
|
||||
mock_cleanup.assert_not_called()
|
||||
|
||||
|
||||
@patch("onyx.connectors.web.connector.check_internet_connection")
|
||||
@patch("onyx.connectors.web.connector.requests.head")
|
||||
@patch("onyx.connectors.web.connector.start_playwright")
|
||||
def test_slim_discovers_links_recursively(
|
||||
mock_start_playwright: MagicMock,
|
||||
mock_head: MagicMock,
|
||||
_mock_check: MagicMock,
|
||||
) -> None:
|
||||
"""In RECURSIVE mode, internal <a href> links are followed and all URLs yielded."""
|
||||
url_to_html = {
|
||||
BASE_URL + "/": RECURSIVE_ROOT_HTML,
|
||||
BASE_URL + "/page2": PAGE2_HTML,
|
||||
BASE_URL + "/page3": PAGE3_HTML,
|
||||
}
|
||||
context = _make_playwright_context_mock(url_to_html)
|
||||
mock_start_playwright.return_value = (_make_playwright_mock(), context)
|
||||
mock_head.return_value.headers = {"content-type": "text/html"}
|
||||
|
||||
connector = WebConnector(
|
||||
base_url=BASE_URL + "/",
|
||||
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value,
|
||||
)
|
||||
|
||||
ids = {
|
||||
doc.id
|
||||
for batch in connector.retrieve_all_slim_docs()
|
||||
for doc in batch
|
||||
if isinstance(doc, SlimDocument)
|
||||
}
|
||||
|
||||
assert ids == {
|
||||
BASE_URL + "/",
|
||||
BASE_URL + "/page2",
|
||||
BASE_URL + "/page3",
|
||||
}
|
||||
|
||||
|
||||
@patch("onyx.connectors.web.connector.check_internet_connection")
|
||||
@patch("onyx.connectors.web.connector.requests.head")
|
||||
@patch("onyx.connectors.web.connector.start_playwright")
|
||||
def test_normal_200_skips_5s_wait(
|
||||
mock_start_playwright: MagicMock,
|
||||
mock_head: MagicMock,
|
||||
_mock_check: MagicMock,
|
||||
) -> None:
|
||||
"""Normal 200 responses without bot-detection signals skip the 5s render wait."""
|
||||
page = _make_page_mock(SINGLE_PAGE_HTML, cf_ray=None, status=200)
|
||||
context = MagicMock()
|
||||
context.new_page.return_value = page
|
||||
mock_start_playwright.return_value = (_make_playwright_mock(), context)
|
||||
mock_head.return_value.headers = {"content-type": "text/html"}
|
||||
|
||||
connector = WebConnector(
|
||||
base_url=BASE_URL + "/",
|
||||
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
|
||||
)
|
||||
|
||||
list(connector.retrieve_all_slim_docs())
|
||||
|
||||
page.wait_for_timeout.assert_not_called()
|
||||
|
||||
|
||||
@patch("onyx.connectors.web.connector.check_internet_connection")
|
||||
@patch("onyx.connectors.web.connector.requests.head")
|
||||
@patch("onyx.connectors.web.connector.start_playwright")
|
||||
def test_cloudflare_applies_5s_wait(
|
||||
mock_start_playwright: MagicMock,
|
||||
mock_head: MagicMock,
|
||||
_mock_check: MagicMock,
|
||||
) -> None:
|
||||
"""Pages with a cf-ray header trigger the 5s wait before networkidle."""
|
||||
page = _make_page_mock(SINGLE_PAGE_HTML, cf_ray="abc123-LAX")
|
||||
context = MagicMock()
|
||||
context.new_page.return_value = page
|
||||
mock_start_playwright.return_value = (_make_playwright_mock(), context)
|
||||
mock_head.return_value.headers = {"content-type": "text/html"}
|
||||
|
||||
connector = WebConnector(
|
||||
base_url=BASE_URL + "/",
|
||||
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
|
||||
)
|
||||
|
||||
list(connector.retrieve_all_slim_docs())
|
||||
|
||||
page.wait_for_timeout.assert_called_once_with(5000)
|
||||
|
||||
|
||||
@patch("onyx.connectors.web.connector.time")
|
||||
@patch("onyx.connectors.web.connector.check_internet_connection")
|
||||
@patch("onyx.connectors.web.connector.requests.head")
|
||||
@patch("onyx.connectors.web.connector.start_playwright")
|
||||
def test_403_applies_5s_wait(
|
||||
mock_start_playwright: MagicMock,
|
||||
mock_head: MagicMock,
|
||||
_mock_check: MagicMock,
|
||||
_mock_time: MagicMock,
|
||||
) -> None:
|
||||
"""A 403 response triggers the 5s wait (common bot-detection challenge entry point)."""
|
||||
page = _make_page_mock(SINGLE_PAGE_HTML, cf_ray=None, status=403)
|
||||
context = MagicMock()
|
||||
context.new_page.return_value = page
|
||||
mock_start_playwright.return_value = (_make_playwright_mock(), context)
|
||||
mock_head.return_value.headers = {"content-type": "text/html"}
|
||||
|
||||
connector = WebConnector(
|
||||
base_url=BASE_URL + "/",
|
||||
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
|
||||
)
|
||||
|
||||
# All retries return 403 so no docs are found — that's expected here.
|
||||
# We only care that the 5s wait fired.
|
||||
try:
|
||||
list(connector.retrieve_all_slim_docs())
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
page.wait_for_timeout.assert_called_with(5000)
|
||||
@@ -1,11 +1,12 @@
|
||||
import io
|
||||
from typing import cast
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
|
||||
from onyx.file_processing.extract_file_text import _sheet_to_csv
|
||||
from onyx.file_processing.extract_file_text import _clean_worksheet_matrix
|
||||
from onyx.file_processing.extract_file_text import _worksheet_to_matrix
|
||||
from onyx.file_processing.extract_file_text import xlsx_sheet_extraction
|
||||
from onyx.file_processing.extract_file_text import xlsx_to_text
|
||||
|
||||
@@ -201,179 +202,50 @@ class TestXlsxToText:
|
||||
assert "r3c1" in lines[2] and "r3c2" in lines[2]
|
||||
|
||||
|
||||
class TestSheetToCsvJaggedRows:
|
||||
"""openpyxl's read-only mode yields rows of differing widths when
|
||||
trailing cells are empty. These tests exercise ``_sheet_to_csv``
|
||||
directly because ``_make_xlsx`` (via ``ws.append``) normalizes row
|
||||
widths, so jagged input can only be produced in-memory."""
|
||||
class TestWorksheetToMatrixJaggedRows:
|
||||
"""openpyxl read_only mode can yield rows of differing widths when
|
||||
trailing cells are empty. The matrix must be padded to a rectangle
|
||||
so downstream column cleanup can index safely."""
|
||||
|
||||
def test_shorter_trailing_rows_padded_in_output(self) -> None:
|
||||
csv_text = _sheet_to_csv(
|
||||
iter(
|
||||
[
|
||||
("A", "B", "C"),
|
||||
("X", "Y"),
|
||||
("P",),
|
||||
]
|
||||
)
|
||||
def test_pads_shorter_trailing_rows(self) -> None:
|
||||
ws = MagicMock()
|
||||
ws.iter_rows.return_value = iter(
|
||||
[
|
||||
("A", "B", "C"),
|
||||
("X", "Y"),
|
||||
("P",),
|
||||
]
|
||||
)
|
||||
assert csv_text.split("\n") == ["A,B,C", "X,Y,", "P,,"]
|
||||
matrix = _worksheet_to_matrix(ws)
|
||||
assert matrix == [["A", "B", "C"], ["X", "Y", ""], ["P", "", ""]]
|
||||
|
||||
def test_shorter_leading_row_padded_in_output(self) -> None:
|
||||
csv_text = _sheet_to_csv(
|
||||
iter(
|
||||
[
|
||||
("A",),
|
||||
("X", "Y", "Z"),
|
||||
]
|
||||
)
|
||||
def test_pads_when_first_row_is_shorter(self) -> None:
|
||||
ws = MagicMock()
|
||||
ws.iter_rows.return_value = iter(
|
||||
[
|
||||
("A",),
|
||||
("X", "Y", "Z"),
|
||||
]
|
||||
)
|
||||
assert csv_text.split("\n") == ["A,,", "X,Y,Z"]
|
||||
matrix = _worksheet_to_matrix(ws)
|
||||
assert matrix == [["A", "", ""], ["X", "Y", "Z"]]
|
||||
|
||||
def test_no_index_error_on_jagged_rows(self) -> None:
|
||||
"""Regression: the original dense-matrix version raised IndexError
|
||||
when a later row was shorter than an earlier row whose out-of-range
|
||||
columns happened to be empty."""
|
||||
csv_text = _sheet_to_csv(
|
||||
iter(
|
||||
[
|
||||
("A", "", "", "B"),
|
||||
("X", "Y"),
|
||||
]
|
||||
)
|
||||
def test_clean_worksheet_matrix_no_index_error_on_jagged_rows(self) -> None:
|
||||
"""Regression: previously raised IndexError when a later row was
|
||||
shorter than the first row and the out-of-range column on the
|
||||
first row was empty (so the short-circuit in `all()` did not
|
||||
save us)."""
|
||||
ws = MagicMock()
|
||||
ws.iter_rows.return_value = iter(
|
||||
[
|
||||
("A", "", "", "B"),
|
||||
("X", "Y"),
|
||||
]
|
||||
)
|
||||
assert csv_text.split("\n") == ["A,,,B", "X,Y,,"]
|
||||
|
||||
|
||||
class TestSheetToCsvStreaming:
|
||||
"""Pin the memory-safe streaming contract: empty rows are skipped
|
||||
cheaply, empty-row/column runs are collapsed to at most 2, and sheets
|
||||
with no data return the empty string."""
|
||||
|
||||
def test_empty_rows_between_data_capped_at_two(self) -> None:
|
||||
csv_text = _sheet_to_csv(
|
||||
iter(
|
||||
[
|
||||
("A", "B"),
|
||||
(None, None),
|
||||
(None, None),
|
||||
(None, None),
|
||||
(None, None),
|
||||
(None, None),
|
||||
("C", "D"),
|
||||
]
|
||||
)
|
||||
)
|
||||
# 5 empty rows collapsed to 2
|
||||
assert csv_text.split("\n") == ["A,B", ",", ",", "C,D"]
|
||||
|
||||
def test_empty_rows_at_or_below_cap_preserved(self) -> None:
|
||||
csv_text = _sheet_to_csv(
|
||||
iter(
|
||||
[
|
||||
("A", "B"),
|
||||
(None, None),
|
||||
(None, None),
|
||||
("C", "D"),
|
||||
]
|
||||
)
|
||||
)
|
||||
assert csv_text.split("\n") == ["A,B", ",", ",", "C,D"]
|
||||
|
||||
def test_empty_column_run_capped_at_two(self) -> None:
|
||||
csv_text = _sheet_to_csv(
|
||||
iter(
|
||||
[
|
||||
("A", None, None, None, None, "B"),
|
||||
("C", None, None, None, None, "D"),
|
||||
]
|
||||
)
|
||||
)
|
||||
# 4 empty cols between A and B collapsed to 2
|
||||
assert csv_text.split("\n") == ["A,,,B", "C,,,D"]
|
||||
|
||||
def test_completely_empty_stream_returns_empty_string(self) -> None:
|
||||
assert _sheet_to_csv(iter([])) == ""
|
||||
|
||||
def test_all_rows_empty_returns_empty_string(self) -> None:
|
||||
csv_text = _sheet_to_csv(
|
||||
iter(
|
||||
[
|
||||
(None, None),
|
||||
("", ""),
|
||||
(None,),
|
||||
]
|
||||
)
|
||||
)
|
||||
assert csv_text == ""
|
||||
|
||||
def test_trailing_empty_rows_dropped(self) -> None:
|
||||
csv_text = _sheet_to_csv(
|
||||
iter(
|
||||
[
|
||||
("A",),
|
||||
("B",),
|
||||
(None,),
|
||||
(None,),
|
||||
(None,),
|
||||
]
|
||||
)
|
||||
)
|
||||
# Trailing empties are never emitted (no subsequent non-empty row
|
||||
# to flush them against).
|
||||
assert csv_text.split("\n") == ["A", "B"]
|
||||
|
||||
def test_leading_empty_rows_capped_at_two(self) -> None:
|
||||
csv_text = _sheet_to_csv(
|
||||
iter(
|
||||
[
|
||||
(None, None),
|
||||
(None, None),
|
||||
(None, None),
|
||||
(None, None),
|
||||
(None, None),
|
||||
("A", "B"),
|
||||
]
|
||||
)
|
||||
)
|
||||
# 5 leading empty rows collapsed to 2
|
||||
assert csv_text.split("\n") == [",", ",", "A,B"]
|
||||
|
||||
def test_cell_cap_truncates_and_appends_marker(self) -> None:
|
||||
"""When total non-empty cells exceeds the cap, scanning stops and
|
||||
a truncation marker row is appended so downstream indexing sees
|
||||
the sheet was cut off."""
|
||||
with patch(
|
||||
"onyx.file_processing.extract_file_text.MAX_XLSX_CELLS_PER_SHEET", 5
|
||||
):
|
||||
csv_text = _sheet_to_csv(
|
||||
iter(
|
||||
[
|
||||
("A", "B", "C"),
|
||||
("D", "E", "F"),
|
||||
("G", "H", "I"),
|
||||
("J", "K", "L"),
|
||||
]
|
||||
)
|
||||
)
|
||||
lines = csv_text.split("\n")
|
||||
assert lines[-1] == "[truncated: sheet exceeded cell limit]"
|
||||
# First two rows (6 cells) trip the cap=5 check after row 2; the
|
||||
# third and fourth rows are never scanned.
|
||||
assert "G" not in csv_text
|
||||
assert "J" not in csv_text
|
||||
|
||||
def test_cell_cap_not_hit_no_marker(self) -> None:
|
||||
"""Under the cap, no truncation marker is appended."""
|
||||
csv_text = _sheet_to_csv(
|
||||
iter(
|
||||
[
|
||||
("A", "B"),
|
||||
("C", "D"),
|
||||
]
|
||||
)
|
||||
)
|
||||
assert "[truncated" not in csv_text
|
||||
matrix = _worksheet_to_matrix(ws)
|
||||
# Must not raise.
|
||||
cleaned = _clean_worksheet_matrix(matrix)
|
||||
assert cleaned == [["A", "", "", "B"], ["X", "Y", "", ""]]
|
||||
|
||||
|
||||
class TestXlsxSheetExtraction:
|
||||
@@ -408,9 +280,10 @@ class TestXlsxSheetExtraction:
|
||||
assert "a" in csv_text
|
||||
assert "b" in csv_text
|
||||
|
||||
def test_empty_sheet_included_with_empty_csv(self) -> None:
|
||||
"""Every sheet in the workbook appears in the result; an empty
|
||||
sheet contributes an empty csv_text alongside its title."""
|
||||
def test_empty_sheet_is_skipped(self) -> None:
|
||||
"""A sheet whose CSV output is empty/whitespace-only should NOT
|
||||
appear in the result — the `if csv_text.strip():` guard filters
|
||||
it out."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Data": [["a", "b"]],
|
||||
@@ -418,17 +291,14 @@ class TestXlsxSheetExtraction:
|
||||
}
|
||||
)
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 2
|
||||
titles = [title for _csv, title in sheets]
|
||||
assert titles == ["Data", "Empty"]
|
||||
empty_csv = next(csv_text for csv_text, title in sheets if title == "Empty")
|
||||
assert empty_csv == ""
|
||||
assert len(sheets) == 1
|
||||
assert sheets[0][1] == "Data"
|
||||
|
||||
def test_empty_workbook_returns_one_tuple_per_sheet(self) -> None:
|
||||
"""All sheets empty → one empty-csv tuple per sheet."""
|
||||
def test_empty_workbook_returns_empty_list(self) -> None:
|
||||
"""All sheets empty → empty list (not a list of empty tuples)."""
|
||||
xlsx = _make_xlsx({"Sheet1": [], "Sheet2": []})
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert sheets == [("", "Sheet1"), ("", "Sheet2")]
|
||||
assert sheets == []
|
||||
|
||||
def test_single_sheet(self) -> None:
|
||||
xlsx = _make_xlsx({"Only": [["x", "y"], ["1", "2"]]})
|
||||
@@ -451,17 +321,6 @@ class TestXlsxSheetExtraction:
|
||||
sheets = xlsx_sheet_extraction(bad_file, file_name="~$temp.xlsx")
|
||||
assert sheets == []
|
||||
|
||||
def test_known_openpyxl_bug_max_value_returns_empty(self) -> None:
|
||||
"""openpyxl's strict descriptor validation rejects font family
|
||||
values >14 with 'Max value is 14'. Treat as a known openpyxl bug
|
||||
and skip the file rather than fail the whole connector batch."""
|
||||
with patch(
|
||||
"onyx.file_processing.extract_file_text.openpyxl.load_workbook",
|
||||
side_effect=ValueError("Max value is 14"),
|
||||
):
|
||||
sheets = xlsx_sheet_extraction(io.BytesIO(b""), file_name="bad_font.xlsx")
|
||||
assert sheets == []
|
||||
|
||||
def test_csv_content_matches_xlsx_to_text_per_sheet(self) -> None:
|
||||
"""For a single-sheet workbook, xlsx_to_text output should equal
|
||||
the csv_text from xlsx_sheet_extraction — they share the same
|
||||
|
||||
@@ -1,257 +0,0 @@
|
||||
"""Tests for embedding Prometheus metrics."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.server.metrics.embedding import _client_duration
|
||||
from onyx.server.metrics.embedding import _embedding_input_chars_total
|
||||
from onyx.server.metrics.embedding import _embedding_requests_total
|
||||
from onyx.server.metrics.embedding import _embedding_texts_total
|
||||
from onyx.server.metrics.embedding import _embeddings_in_progress
|
||||
from onyx.server.metrics.embedding import LOCAL_PROVIDER_LABEL
|
||||
from onyx.server.metrics.embedding import observe_embedding_client
|
||||
from onyx.server.metrics.embedding import provider_label
|
||||
from onyx.server.metrics.embedding import PROVIDER_LABEL_NAME
|
||||
from onyx.server.metrics.embedding import TEXT_TYPE_LABEL_NAME
|
||||
from onyx.server.metrics.embedding import track_embedding_in_progress
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
|
||||
|
||||
class TestProviderLabel:
|
||||
def test_none_maps_to_local(self) -> None:
|
||||
assert provider_label(None) == LOCAL_PROVIDER_LABEL
|
||||
|
||||
def test_enum_maps_to_value(self) -> None:
|
||||
assert provider_label(EmbeddingProvider.OPENAI) == "openai"
|
||||
assert provider_label(EmbeddingProvider.COHERE) == "cohere"
|
||||
|
||||
|
||||
class TestObserveEmbeddingClient:
|
||||
def test_success_records_all_counters(self) -> None:
|
||||
# Precondition.
|
||||
provider = EmbeddingProvider.OPENAI
|
||||
text_type = EmbedTextType.QUERY
|
||||
labels = {
|
||||
PROVIDER_LABEL_NAME: provider.value,
|
||||
TEXT_TYPE_LABEL_NAME: text_type.value,
|
||||
}
|
||||
|
||||
before_requests = _embedding_requests_total.labels(
|
||||
**labels, status="success"
|
||||
)._value.get()
|
||||
before_texts = _embedding_texts_total.labels(**labels)._value.get()
|
||||
before_chars = _embedding_input_chars_total.labels(**labels)._value.get()
|
||||
before_duration_sum = _client_duration.labels(**labels)._sum.get()
|
||||
|
||||
test_duration_s = 0.123
|
||||
test_num_texts = 4
|
||||
test_num_chars = 200
|
||||
|
||||
# Under test.
|
||||
observe_embedding_client(
|
||||
provider=provider,
|
||||
text_type=text_type,
|
||||
duration_s=test_duration_s,
|
||||
num_texts=test_num_texts,
|
||||
num_chars=test_num_chars,
|
||||
success=True,
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
assert (
|
||||
_embedding_requests_total.labels(**labels, status="success")._value.get()
|
||||
== before_requests + 1
|
||||
)
|
||||
assert (
|
||||
_embedding_texts_total.labels(**labels)._value.get()
|
||||
== before_texts + test_num_texts
|
||||
)
|
||||
assert (
|
||||
_embedding_input_chars_total.labels(**labels)._value.get()
|
||||
== before_chars + test_num_chars
|
||||
)
|
||||
assert (
|
||||
_client_duration.labels(**labels)._sum.get()
|
||||
== before_duration_sum + test_duration_s
|
||||
)
|
||||
|
||||
def test_failure_records_duration_and_failure_counter_only(self) -> None:
|
||||
# Precondition.
|
||||
provider = EmbeddingProvider.COHERE
|
||||
text_type = EmbedTextType.PASSAGE
|
||||
labels = {
|
||||
PROVIDER_LABEL_NAME: provider.value,
|
||||
TEXT_TYPE_LABEL_NAME: text_type.value,
|
||||
}
|
||||
|
||||
before_failure = _embedding_requests_total.labels(
|
||||
**labels, status="failure"
|
||||
)._value.get()
|
||||
before_texts = _embedding_texts_total.labels(**labels)._value.get()
|
||||
before_chars = _embedding_input_chars_total.labels(**labels)._value.get()
|
||||
before_duration_sum = _client_duration.labels(**labels)._sum.get()
|
||||
|
||||
test_duration_s = 0.5
|
||||
test_num_texts = 3
|
||||
test_num_chars = 150
|
||||
|
||||
# Under test.
|
||||
observe_embedding_client(
|
||||
provider=provider,
|
||||
text_type=text_type,
|
||||
duration_s=test_duration_s,
|
||||
num_texts=test_num_texts,
|
||||
num_chars=test_num_chars,
|
||||
success=False,
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
# Failure counter incremented.
|
||||
assert (
|
||||
_embedding_requests_total.labels(**labels, status="failure")._value.get()
|
||||
== before_failure + 1
|
||||
)
|
||||
# Duration still recorded.
|
||||
assert (
|
||||
_client_duration.labels(**labels)._sum.get()
|
||||
== before_duration_sum + test_duration_s
|
||||
)
|
||||
# Throughput counters NOT bumped on failure.
|
||||
assert _embedding_texts_total.labels(**labels)._value.get() == before_texts
|
||||
assert (
|
||||
_embedding_input_chars_total.labels(**labels)._value.get() == before_chars
|
||||
)
|
||||
|
||||
def test_local_provider_uses_local_label(self) -> None:
|
||||
# Precondition.
|
||||
text_type = EmbedTextType.QUERY
|
||||
labels = {
|
||||
PROVIDER_LABEL_NAME: LOCAL_PROVIDER_LABEL,
|
||||
TEXT_TYPE_LABEL_NAME: text_type.value,
|
||||
}
|
||||
before = _embedding_requests_total.labels(
|
||||
**labels, status="success"
|
||||
)._value.get()
|
||||
|
||||
test_duration_s = 0.05
|
||||
test_num_texts = 1
|
||||
test_num_chars = 10
|
||||
|
||||
# Under test.
|
||||
observe_embedding_client(
|
||||
provider=None,
|
||||
text_type=text_type,
|
||||
duration_s=test_duration_s,
|
||||
num_texts=test_num_texts,
|
||||
num_chars=test_num_chars,
|
||||
success=True,
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
assert (
|
||||
_embedding_requests_total.labels(**labels, status="success")._value.get()
|
||||
== before + 1
|
||||
)
|
||||
|
||||
def test_exceptions_do_not_propagate(self) -> None:
|
||||
with patch.object(
|
||||
_embedding_requests_total,
|
||||
"labels",
|
||||
side_effect=RuntimeError("boom"),
|
||||
):
|
||||
# Must not raise.
|
||||
observe_embedding_client(
|
||||
provider=EmbeddingProvider.OPENAI,
|
||||
text_type=EmbedTextType.QUERY,
|
||||
duration_s=0.1,
|
||||
num_texts=1,
|
||||
num_chars=10,
|
||||
success=True,
|
||||
)
|
||||
|
||||
|
||||
class TestTrackEmbeddingInProgress:
|
||||
def test_gauge_increments_and_decrements(self) -> None:
|
||||
# Precondition.
|
||||
provider = EmbeddingProvider.OPENAI
|
||||
text_type = EmbedTextType.QUERY
|
||||
labels = {
|
||||
PROVIDER_LABEL_NAME: provider.value,
|
||||
TEXT_TYPE_LABEL_NAME: text_type.value,
|
||||
}
|
||||
before = _embeddings_in_progress.labels(**labels)._value.get()
|
||||
|
||||
# Under test.
|
||||
with track_embedding_in_progress(provider, text_type):
|
||||
during = _embeddings_in_progress.labels(**labels)._value.get()
|
||||
assert during == before + 1
|
||||
|
||||
# Postcondition.
|
||||
after = _embeddings_in_progress.labels(**labels)._value.get()
|
||||
assert after == before
|
||||
|
||||
def test_gauge_decrements_on_exception(self) -> None:
|
||||
# Precondition.
|
||||
provider = EmbeddingProvider.COHERE
|
||||
text_type = EmbedTextType.PASSAGE
|
||||
labels = {
|
||||
PROVIDER_LABEL_NAME: provider.value,
|
||||
TEXT_TYPE_LABEL_NAME: text_type.value,
|
||||
}
|
||||
before = _embeddings_in_progress.labels(**labels)._value.get()
|
||||
|
||||
# Under test.
|
||||
raised = False
|
||||
try:
|
||||
with track_embedding_in_progress(provider, text_type):
|
||||
raise ValueError("simulated embedding failure")
|
||||
except ValueError:
|
||||
raised = True
|
||||
assert raised
|
||||
|
||||
# Postcondition.
|
||||
after = _embeddings_in_progress.labels(**labels)._value.get()
|
||||
assert after == before
|
||||
|
||||
def test_local_provider_uses_local_label(self) -> None:
|
||||
# Precondition.
|
||||
text_type = EmbedTextType.QUERY
|
||||
labels = {
|
||||
PROVIDER_LABEL_NAME: LOCAL_PROVIDER_LABEL,
|
||||
TEXT_TYPE_LABEL_NAME: text_type.value,
|
||||
}
|
||||
before = _embeddings_in_progress.labels(**labels)._value.get()
|
||||
|
||||
# Under test.
|
||||
with track_embedding_in_progress(None, text_type):
|
||||
during = _embeddings_in_progress.labels(**labels)._value.get()
|
||||
assert during == before + 1
|
||||
|
||||
# Postcondition.
|
||||
after = _embeddings_in_progress.labels(**labels)._value.get()
|
||||
assert after == before
|
||||
|
||||
def test_inc_exception_does_not_break_call(self) -> None:
|
||||
# Precondition.
|
||||
provider = EmbeddingProvider.VOYAGE
|
||||
text_type = EmbedTextType.QUERY
|
||||
labels = {
|
||||
PROVIDER_LABEL_NAME: provider.value,
|
||||
TEXT_TYPE_LABEL_NAME: text_type.value,
|
||||
}
|
||||
before = _embeddings_in_progress.labels(**labels)._value.get()
|
||||
|
||||
# Under test.
|
||||
with patch.object(
|
||||
_embeddings_in_progress.labels(**labels),
|
||||
"inc",
|
||||
side_effect=RuntimeError("boom"),
|
||||
):
|
||||
# Context manager should still yield without decrementing.
|
||||
with track_embedding_in_progress(provider, text_type):
|
||||
during = _embeddings_in_progress.labels(**labels)._value.get()
|
||||
assert during == before
|
||||
|
||||
# Postcondition.
|
||||
after = _embeddings_in_progress.labels(**labels)._value.get()
|
||||
assert after == before
|
||||
@@ -129,36 +129,12 @@ class TestWorkerHealthCollector:
|
||||
up = families[1]
|
||||
assert up.name == "onyx_celery_worker_up"
|
||||
assert len(up.samples) == 3
|
||||
label_pairs = {
|
||||
(s.labels["worker_type"], s.labels["hostname"]) for s in up.samples
|
||||
}
|
||||
assert label_pairs == {
|
||||
("primary", "host1"),
|
||||
("docfetching", "host1"),
|
||||
("monitoring", "host1"),
|
||||
}
|
||||
# Labels use short names (before @)
|
||||
labels = {s.labels["worker"] for s in up.samples}
|
||||
assert labels == {"primary", "docfetching", "monitoring"}
|
||||
for sample in up.samples:
|
||||
assert sample.value == 1
|
||||
|
||||
def test_replicas_of_same_worker_type_are_distinct(self) -> None:
|
||||
"""Regression: ``docprocessing@pod-1`` and ``docprocessing@pod-2`` must
|
||||
produce separate samples, not collapse into one duplicate-timestamp
|
||||
series.
|
||||
"""
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "docprocessing@pod-1"})
|
||||
monitor._on_heartbeat({"hostname": "docprocessing@pod-2"})
|
||||
monitor._on_heartbeat({"hostname": "docprocessing@pod-3"})
|
||||
|
||||
collector = WorkerHealthCollector(cache_ttl=0)
|
||||
collector.set_monitor(monitor)
|
||||
|
||||
up = collector.collect()[1]
|
||||
assert len(up.samples) == 3
|
||||
hostnames = {s.labels["hostname"] for s in up.samples}
|
||||
assert hostnames == {"pod-1", "pod-2", "pod-3"}
|
||||
assert all(s.labels["worker_type"] == "docprocessing" for s in up.samples)
|
||||
|
||||
def test_reports_dead_worker(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
@@ -175,9 +151,9 @@ class TestWorkerHealthCollector:
|
||||
assert active.samples[0].value == 1
|
||||
|
||||
up = families[1]
|
||||
samples_by_type = {s.labels["worker_type"]: s.value for s in up.samples}
|
||||
assert samples_by_type["primary"] == 1
|
||||
assert samples_by_type["monitoring"] == 0
|
||||
samples_by_name = {s.labels["worker"]: s.value for s in up.samples}
|
||||
assert samples_by_name["primary"] == 1
|
||||
assert samples_by_name["monitoring"] == 0
|
||||
|
||||
def test_empty_monitor_returns_zero(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
|
||||
@@ -58,7 +58,8 @@ SERVICE_ORDER=(
|
||||
validate_template() {
|
||||
local template_file=$1
|
||||
echo "Validating template: $template_file..."
|
||||
if ! aws cloudformation validate-template --template-body file://"$template_file" --region "$AWS_REGION" > /dev/null; then
|
||||
aws cloudformation validate-template --template-body file://"$template_file" --region "$AWS_REGION" > /dev/null
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Validation failed for $template_file. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
@@ -107,15 +108,13 @@ deploy_stack() {
|
||||
fi
|
||||
|
||||
# Create temporary parameters file for this template
|
||||
local temp_params_file
|
||||
temp_params_file=$(create_parameters_from_json "$template_file")
|
||||
local temp_params_file=$(create_parameters_from_json "$template_file")
|
||||
|
||||
# Special handling for SubnetIDs parameter if needed
|
||||
if grep -q "SubnetIDs" "$template_file"; then
|
||||
echo "Template uses SubnetIDs parameter, ensuring it's properly formatted..."
|
||||
# Make sure we're passing SubnetIDs as a comma-separated list
|
||||
local subnet_ids
|
||||
subnet_ids=$(remove_comments "$CONFIG_FILE" | jq -r '.SubnetIDs // empty')
|
||||
local subnet_ids=$(remove_comments "$CONFIG_FILE" | jq -r '.SubnetIDs // empty')
|
||||
if [ -n "$subnet_ids" ]; then
|
||||
echo "Using SubnetIDs from config: $subnet_ids"
|
||||
else
|
||||
@@ -124,13 +123,15 @@ deploy_stack() {
|
||||
fi
|
||||
|
||||
echo "Deploying stack: $stack_name with template: $template_file and generated config from: $CONFIG_FILE..."
|
||||
if ! aws cloudformation deploy \
|
||||
aws cloudformation deploy \
|
||||
--stack-name "$stack_name" \
|
||||
--template-file "$template_file" \
|
||||
--parameter-overrides file://"$temp_params_file" \
|
||||
--capabilities CAPABILITY_IAM CAPABILITY_NAMED_IAM CAPABILITY_AUTO_EXPAND \
|
||||
--region "$AWS_REGION" \
|
||||
--no-cli-auto-prompt > /dev/null; then
|
||||
--no-cli-auto-prompt > /dev/null
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Deployment failed for $stack_name. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -52,9 +52,11 @@ delete_stack() {
|
||||
--region "$AWS_REGION"
|
||||
|
||||
echo "Waiting for stack $stack_name to be deleted..."
|
||||
if aws cloudformation wait stack-delete-complete \
|
||||
aws cloudformation wait stack-delete-complete \
|
||||
--stack-name "$stack_name" \
|
||||
--region "$AWS_REGION"; then
|
||||
--region "$AWS_REGION"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Stack $stack_name deleted successfully."
|
||||
sleep 10
|
||||
else
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user