Compare commits

..

1 Commits

Author SHA1 Message Date
Jamison Lahman
1e69f66705 fix(chat): improve LLM_SOCKET_READ_TIMEOUT user experience 2026-04-17 12:35:02 -07:00
186 changed files with 936 additions and 5196 deletions

View File

@@ -16,13 +16,9 @@
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
],
"containerEnv": {
"MODEL_SERVER_HOST": "inference_model_server",
"OPENSEARCH_HOST": "opensearch",
"POSTGRES_HOST": "relational_db",
"REDIS_HOST": "cache",
"S3_ENDPOINT_URL": "http://minio:9000",
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock",
"VESPA_HOST": "index"
"POSTGRES_HOST": "relational_db",
"REDIS_HOST": "cache"
},
"remoteUser": "${localEnv:DEVCONTAINER_REMOTE_USER:dev}",
"updateRemoteUserUID": false,

View File

@@ -45,7 +45,7 @@ if [ "$ACTIVE_HOME" != "$MOUNT_HOME" ]; then
[ -d "$MOUNT_HOME/$item" ] || continue
if [ -e "$ACTIVE_HOME/$item" ] && [ ! -L "$ACTIVE_HOME/$item" ]; then
echo "warning: replacing $ACTIVE_HOME/$item with symlink to $MOUNT_HOME/$item" >&2
rm -rf "${ACTIVE_HOME:?}/$item"
rm -rf "$ACTIVE_HOME/$item"
fi
ln -sfn "$MOUNT_HOME/$item" "$ACTIVE_HOME/$item"
done

View File

@@ -40,7 +40,6 @@ ALLOWED_DOMAINS=(
"api.anthropic.com"
"api-staging.anthropic.com"
"files.anthropic.com"
"huggingface.co"
"sentry.io"
"update.code.visualstudio.com"
"pypi.org"

View File

@@ -403,7 +403,7 @@ jobs:
echo "CERT_ID=$CERT_ID" >> $GITHUB_ENV
echo "Certificate imported."
- uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5 # ratchet:tauri-apps/tauri-action@action-v0.6.2
- uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
APPLE_ID: ${{ env.APPLE_ID }}

View File

@@ -42,7 +42,7 @@ jobs:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # zizmor: ignore[cache-poisoning]
- uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning]
with:
go-version: ${{ env.GO_VERSION }}
cache-dependency-path: "**/go.sum"

View File

@@ -39,8 +39,6 @@ jobs:
working-directory: ./web
run: npm ci
- uses: j178/prek-action@cbc2f23eb5539cf20d82d1aabd0d0ecbcc56f4e3
env:
SKIP: ty
with:
prek-version: '0.3.4'
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}

View File

@@ -68,7 +68,6 @@ repos:
pass_filenames: true
files: ^backend/(?!\.venv/|scripts/).*\.py$
- id: uv-run
alias: ty
name: ty
args: ["ty", "check"]
pass_filenames: true
@@ -86,17 +85,6 @@ repos:
hooks:
- id: actionlint
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: 745eface02aef23e168a8afb6b5737818efbea95 # frozen: v0.11.0.1
hooks:
- id: shellcheck
exclude: >-
(?x)^(
backend/scripts/setup_craft_templates\.sh|
deployment/docker_compose/init-letsencrypt\.sh|
deployment/docker_compose/install\.sh
)$
- repo: https://github.com/psf/black
rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0
hooks:

View File

@@ -1,27 +0,0 @@
"""Add file_id to documents
Revision ID: 91d150c361f6
Revises: a6fcd3d631f9
Create Date: 2026-04-16 15:43:30.314823
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "91d150c361f6"
down_revision = "a6fcd3d631f9"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"document",
sa.Column("file_id", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("document", "file_id")

View File

@@ -1,48 +0,0 @@
"""replace document sync index with partial index
Replaces the composite index ix_document_sync_status (last_modified, last_synced)
with a partial index ix_document_needs_sync that only indexes rows where
last_modified > last_synced OR last_synced IS NULL.
The old index was never used by the query planner (0 scans in pg_stat_user_indexes)
because Postgres cannot use a B-tree composite index to evaluate a comparison
between two columns in the same row combined with an OR/IS NULL condition.
The partial index makes count_documents_by_needs_sync ~4000x faster for tenants
with no stale documents (161ms -> 0.04ms on a 929K row table) and ~17x faster
for tenants with large backlogs (846ms -> 50ms on a 164K row table).
Revision ID: a6fcd3d631f9
Revises: d129f37b3d87
Create Date: 2026-04-17 16:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a6fcd3d631f9"
down_revision = "d129f37b3d87"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_index(
"ix_document_needs_sync",
"document",
["id"],
postgresql_where=sa.text("last_modified > last_synced OR last_synced IS NULL"),
)
op.drop_index("ix_document_sync_status", table_name="document")
def downgrade() -> None:
op.create_index(
"ix_document_sync_status",
"document",
["last_modified", "last_synced"],
)
op.drop_index("ix_document_needs_sync", table_name="document")

View File

@@ -1,10 +1,8 @@
import time
from typing import cast
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.client import Redis
from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.product_gating import get_gated_tenants
@@ -18,56 +16,9 @@ from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_tenant_work_gating import cleanup_expired
from onyx.redis.redis_tenant_work_gating import get_active_tenants
from onyx.redis.redis_tenant_work_gating import observe_active_set_size
from onyx.redis.redis_tenant_work_gating import record_full_fanout_cycle
from onyx.redis.redis_tenant_work_gating import record_gate_decision
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
_FULL_FANOUT_TIMESTAMP_KEY_PREFIX = "tenant_work_gating_last_full_fanout_ms"
def _should_bypass_gate_for_full_fanout(
redis_client: Redis, task_name: str, interval_seconds: int
) -> bool:
"""True if at least `interval_seconds` have elapsed since the last
full-fanout bypass for this task. On True, updates the stored timestamp
atomically-enough (it's a best-effort counter, not a lock)."""
key = f"{_FULL_FANOUT_TIMESTAMP_KEY_PREFIX}:{task_name}"
now_ms = int(time.time() * 1000)
threshold_ms = now_ms - (interval_seconds * 1000)
try:
raw = cast(bytes | None, redis_client.get(key))
except Exception:
task_logger.exception(f"full-fanout timestamp read failed: task={task_name}")
# Fail open: treat as "interval elapsed" so we don't skip every
# tenant during a Redis hiccup.
return True
if raw is None:
# First invocation — bypass so the set seeds cleanly.
elapsed = True
else:
try:
last_ms = int(raw.decode())
elapsed = last_ms <= threshold_ms
except ValueError:
elapsed = True
if elapsed:
try:
redis_client.set(key, str(now_ms))
except Exception:
task_logger.exception(
f"full-fanout timestamp write failed: task={task_name}"
)
return elapsed
@shared_task(
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
ignore_result=True,
@@ -81,7 +32,6 @@ def cloud_beat_task_generator(
priority: int = OnyxCeleryPriority.MEDIUM,
expires: int = BEAT_EXPIRES_DEFAULT,
skip_gated: bool = True,
work_gated: bool = False,
) -> bool | None:
"""a lightweight task used to kick off individual beat tasks per tenant."""
time_start = time.monotonic()
@@ -101,56 +51,8 @@ def cloud_beat_task_generator(
tenant_ids: list[str] = []
num_processed_tenants = 0
num_skipped_gated = 0
num_would_skip_work_gate = 0
num_skipped_work_gate = 0
# Tenant-work-gating read path. Resolve once per invocation.
gate_enabled = False
gate_enforce = False
full_fanout_cycle = False
active_tenants: set[str] | None = None
try:
# Gating setup is inside the try block so any exception still
# reaches the finally that releases the beat lock.
if work_gated:
try:
gate_enabled = OnyxRuntime.get_tenant_work_gating_enabled()
gate_enforce = OnyxRuntime.get_tenant_work_gating_enforce()
except Exception:
task_logger.exception("tenant work gating: runtime flag read failed")
gate_enabled = False
if gate_enabled:
redis_failed = False
interval_s = (
OnyxRuntime.get_tenant_work_gating_full_fanout_interval_seconds()
)
full_fanout_cycle = _should_bypass_gate_for_full_fanout(
redis_client, task_name, interval_s
)
if full_fanout_cycle:
record_full_fanout_cycle(task_name)
try:
ttl_s = OnyxRuntime.get_tenant_work_gating_ttl_seconds()
cleanup_expired(ttl_s)
except Exception:
task_logger.exception(
"tenant work gating: cleanup_expired failed"
)
else:
ttl_s = OnyxRuntime.get_tenant_work_gating_ttl_seconds()
active_tenants = get_active_tenants(ttl_s)
if active_tenants is None:
full_fanout_cycle = True
record_full_fanout_cycle(task_name)
redis_failed = True
# Only refresh the gauge when Redis is known-reachable —
# skip the ZCARD if we just failed open due to a Redis error.
if not redis_failed:
observe_active_set_size()
tenant_ids = get_all_tenant_ids()
# Per-task control over whether gated tenants are included. Most periodic tasks
@@ -174,21 +76,6 @@ def cloud_beat_task_generator(
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
continue
# Tenant work gate: if the feature is on, check membership. Skip
# unmarked tenants when enforce=True AND we're not in a full-
# fanout cycle. Always log/emit the shadow counter.
if work_gated and gate_enabled and not full_fanout_cycle:
would_skip = (
active_tenants is not None and tenant_id not in active_tenants
)
if would_skip:
num_would_skip_work_gate += 1
if gate_enforce:
num_skipped_work_gate += 1
record_gate_decision(task_name, skipped=True)
continue
record_gate_decision(task_name, skipped=False)
self.app.send_task(
task_name,
kwargs=dict(
@@ -222,12 +109,6 @@ def cloud_beat_task_generator(
f"task={task_name} "
f"num_processed_tenants={num_processed_tenants} "
f"num_skipped_gated={num_skipped_gated} "
f"num_would_skip_work_gate={num_would_skip_work_gate} "
f"num_skipped_work_gate={num_skipped_work_gate} "
f"full_fanout_cycle={full_fanout_cycle} "
f"work_gated={work_gated} "
f"gate_enabled={gate_enabled} "
f"gate_enforce={gate_enforce} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)

View File

@@ -212,7 +212,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
# Tenant-work-gating hook: refresh this tenant's active-set membership
# whenever doc-permission sync has any due cc_pairs to dispatch.
if cc_pair_ids_to_sync:
maybe_mark_tenant_active(tenant_id, caller="doc_permission_sync")
maybe_mark_tenant_active(tenant_id)
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:

View File

@@ -206,7 +206,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# Tenant-work-gating hook: refresh this tenant's active-set membership
# whenever external-group sync has any due cc_pairs to dispatch.
if cc_pair_ids_to_sync:
maybe_mark_tenant_active(tenant_id, caller="external_group_sync")
maybe_mark_tenant_active(tenant_id)
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
@@ -506,18 +506,6 @@ def _perform_external_group_sync(
ext_group_sync_func = sync_config.group_sync_config.group_sync_func
# Clean up stale rows from previous cycle BEFORE marking new ones.
# This ensures cleanup always runs regardless of whether the current
# sync succeeds — previously, cleanup only ran at the END of the sync,
# so if the sync failed (e.g. DB connection killed by
# idle_in_transaction_session_timeout during long API calls), stale
# rows would accumulate indefinitely.
logger.info(
f"Removing stale external groups from prior cycle for {source_type} "
f"for cc_pair: {cc_pair_id}"
)
remove_stale_external_groups(db_session, cc_pair_id)
logger.info(
f"Marking old external groups as stale for {source_type} for cc_pair: {cc_pair_id}"
)

View File

@@ -5,7 +5,6 @@ from pydantic import BaseModel
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
from onyx.access.utils import build_ext_group_name_for_onyx
@@ -78,15 +77,6 @@ def mark_old_external_groups_as_stale(
.where(PublicExternalUserGroup.cc_pair_id == cc_pair_id)
.values(stale=True)
)
# Commit immediately so the transaction closes before potentially long
# external API calls (e.g. Google Drive folder iteration). Without this,
# the DB connection sits idle-in-transaction during API calls and gets
# killed by idle_in_transaction_session_timeout, causing the entire sync
# to fail and stale cleanup to never run.
db_session.commit()
_UPSERT_BATCH_SIZE = 5000
def upsert_external_groups(
@@ -96,102 +86,91 @@ def upsert_external_groups(
source: DocumentSource,
) -> None:
"""
Batch upsert external user groups using INSERT ... ON CONFLICT DO UPDATE.
- For existing rows (same user_id, external_user_group_id, cc_pair_id),
sets stale=False
- For new rows, inserts with stale=False
- Same logic for PublicExternalUserGroup
Performs a true upsert operation for external user groups:
- For existing groups (same user_id, external_user_group_id, cc_pair_id), updates the stale flag to False
- For new groups, inserts them with stale=False
- For public groups, uses upsert logic as well
"""
# If there are no groups to add, return early
if not external_groups:
return
# Collect all emails from all groups to batch-add users at once
all_group_member_emails: set[str] = set()
# collect all emails from all groups to batch add all users at once for efficiency
all_group_member_emails = set()
for external_group in external_groups:
all_group_member_emails.update(external_group.user_emails)
for user_email in external_group.user_emails:
all_group_member_emails.add(user_email)
# Batch add users if they don't exist and get their ids
# batch add users if they don't exist and get their ids
all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
# NOTE: this function handles case sensitivity for emails
emails=list(all_group_member_emails),
)
# map emails to ids
email_id_map = {user.email.lower(): user.id for user in all_group_members}
# Build all user-group mappings and public-group mappings
user_group_mappings: list[dict] = []
public_group_mappings: list[dict] = []
# Process each external group
for external_group in external_groups:
external_group_id = build_ext_group_name_for_onyx(
ext_group_name=external_group.id,
source=source,
)
# Handle user-group mappings
for user_email in external_group.user_emails:
user_id = email_id_map.get(user_email.lower())
if user_id is None:
logger.warning(
f"User in group {external_group.id}"
f" with email {user_email} not found"
f"User in group {external_group.id} with email {user_email} not found"
)
continue
user_group_mappings.append(
{
"user_id": user_id,
"external_user_group_id": external_group_id,
"cc_pair_id": cc_pair_id,
"stale": False,
}
# Check if the user-group mapping already exists
existing_user_group = db_session.scalar(
select(User__ExternalUserGroupId).where(
User__ExternalUserGroupId.user_id == user_id,
User__ExternalUserGroupId.external_user_group_id
== external_group_id,
User__ExternalUserGroupId.cc_pair_id == cc_pair_id,
)
)
if existing_user_group:
# Update existing record
existing_user_group.stale = False
else:
# Insert new record
new_user_group = User__ExternalUserGroupId(
user_id=user_id,
external_user_group_id=external_group_id,
cc_pair_id=cc_pair_id,
stale=False,
)
db_session.add(new_user_group)
# Handle public group if needed
if external_group.gives_anyone_access:
public_group_mappings.append(
{
"external_user_group_id": external_group_id,
"cc_pair_id": cc_pair_id,
"stale": False,
}
# Check if the public group already exists
existing_public_group = db_session.scalar(
select(PublicExternalUserGroup).where(
PublicExternalUserGroup.external_user_group_id == external_group_id,
PublicExternalUserGroup.cc_pair_id == cc_pair_id,
)
)
# Deduplicate to avoid "ON CONFLICT DO UPDATE command cannot affect row
# a second time" when duplicate emails or overlapping groups produce
# identical (user_id, external_user_group_id, cc_pair_id) tuples.
user_group_mappings_deduped = list(
{
(m["user_id"], m["external_user_group_id"], m["cc_pair_id"]): m
for m in user_group_mappings
}.values()
)
# Batch upsert user-group mappings
for i in range(0, len(user_group_mappings_deduped), _UPSERT_BATCH_SIZE):
chunk = user_group_mappings_deduped[i : i + _UPSERT_BATCH_SIZE]
stmt = pg_insert(User__ExternalUserGroupId).values(chunk)
stmt = stmt.on_conflict_do_update(
index_elements=["user_id", "external_user_group_id", "cc_pair_id"],
set_={"stale": False},
)
db_session.execute(stmt)
# Deduplicate public group mappings as well
public_group_mappings_deduped = list(
{
(m["external_user_group_id"], m["cc_pair_id"]): m
for m in public_group_mappings
}.values()
)
# Batch upsert public group mappings
for i in range(0, len(public_group_mappings_deduped), _UPSERT_BATCH_SIZE):
chunk = public_group_mappings_deduped[i : i + _UPSERT_BATCH_SIZE]
stmt = pg_insert(PublicExternalUserGroup).values(chunk)
stmt = stmt.on_conflict_do_update(
index_elements=["external_user_group_id", "cc_pair_id"],
set_={"stale": False},
)
db_session.execute(stmt)
if existing_public_group:
# Update existing record
existing_public_group.stale = False
else:
# Insert new record
new_public_group = PublicExternalUserGroup(
external_user_group_id=external_group_id,
cc_pair_id=cc_pair_id,
stale=False,
)
db_session.add(new_public_group)
db_session.commit()

View File

@@ -27,7 +27,6 @@ from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN
from shared_configs.configs import SENTRY_TRACES_SAMPLE_RATE
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
@@ -102,7 +101,7 @@ def get_model_app() -> FastAPI:
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=SENTRY_TRACES_SAMPLE_RATE,
traces_sample_rate=0.1,
release=__version__,
before_send=_add_instance_tags,
)

View File

@@ -76,34 +76,24 @@ async def verify_captcha_token(
f"Captcha verification failed: {', '.join(error_codes)}"
)
# Require v3 score. Google's public test secret returns no score
# — that path must not be active in prod since it skips the only
# human-vs-bot signal. A missing score here means captcha is
# misconfigured (test secret in prod, or a v2 response slipped in
# via an action mismatch).
if result.score is None:
logger.warning(
"Captcha verification failed: siteverify returned no score (likely test secret in prod)"
)
raise CaptchaVerificationError(
"Captcha verification failed: missing score"
)
# For reCAPTCHA v3, also check the score
if result.score is not None:
if result.score < RECAPTCHA_SCORE_THRESHOLD:
logger.warning(
f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
)
raise CaptchaVerificationError(
"Captcha verification failed: suspicious activity detected"
)
if result.score < RECAPTCHA_SCORE_THRESHOLD:
logger.warning(
f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
)
raise CaptchaVerificationError(
"Captcha verification failed: suspicious activity detected"
)
if result.action and result.action != expected_action:
logger.warning(
f"Captcha action mismatch: {result.action} != {expected_action}"
)
raise CaptchaVerificationError(
"Captcha verification failed: action mismatch"
)
# Optionally verify the action matches
if result.action and result.action != expected_action:
logger.warning(
f"Captcha action mismatch: {result.action} != {expected_action}"
)
raise CaptchaVerificationError(
"Captcha verification failed: action mismatch"
)
logger.debug(
f"Captcha verification passed: score={result.score}, action={result.action}"

View File

@@ -30,7 +30,6 @@ from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFI
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
@@ -55,7 +54,6 @@ from onyx.utils.logger import setup_logger
from shared_configs.configs import DEV_LOGGING_ENABLED
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SENTRY_CELERY_TRACES_SAMPLE_RATE
from shared_configs.configs import SENTRY_DSN
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -70,7 +68,7 @@ if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=SENTRY_CELERY_TRACES_SAMPLE_RATE,
traces_sample_rate=0.1,
release=__version__,
before_send=_add_instance_tags,
)
@@ -533,26 +531,23 @@ def reset_tenant_id(
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
def wait_for_document_index_or_shutdown() -> None:
"""
Waits for all configured document indices to become ready subject to a
timeout.
def wait_for_vespa_or_shutdown(
sender: Any, # noqa: ARG001
**kwargs: Any, # noqa: ARG001
) -> None: # noqa: ARG001
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""
Raises WorkerShutdown if the timeout is reached.
"""
if DISABLE_VECTOR_DB:
logger.info(
"DISABLE_VECTOR_DB is set — skipping Vespa/OpenSearch readiness check."
)
return
if not ONYX_DISABLE_VESPA:
if not wait_for_vespa_with_timeout():
msg = (
"[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
if not wait_for_vespa_with_timeout():
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
if not wait_for_opensearch_with_timeout():

View File

@@ -105,7 +105,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_document_index_or_shutdown()
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -111,7 +111,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_document_index_or_shutdown()
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -97,7 +97,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_document_index_or_shutdown()
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -118,7 +118,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_document_index_or_shutdown()
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -124,7 +124,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_document_index_or_shutdown()
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
logger.info(f"Running as the primary celery worker: pid={os.getpid()}")

View File

@@ -71,7 +71,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_document_index_or_shutdown()
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -10,7 +10,6 @@ from onyx.configs.app_configs import DISABLE_OPENSEARCH_MIGRATION_TASK
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
@@ -68,7 +67,6 @@ beat_task_templates: list[dict] = [
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"work_gated": True,
},
},
{
@@ -102,7 +100,6 @@ beat_task_templates: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
# Gated tenants may still have connectors awaiting deletion.
"skip_gated": False,
"work_gated": True,
},
},
{
@@ -112,7 +109,6 @@ beat_task_templates: list[dict] = [
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"work_gated": True,
},
},
{
@@ -122,7 +118,6 @@ beat_task_templates: list[dict] = [
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"work_gated": True,
},
},
{
@@ -160,7 +155,6 @@ beat_task_templates: list[dict] = [
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.SANDBOX,
"work_gated": True,
},
},
{
@@ -185,7 +179,6 @@ if ENTERPRISE_EDITION_ENABLED:
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"work_gated": True,
},
},
{
@@ -195,7 +188,6 @@ if ENTERPRISE_EDITION_ENABLED:
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"work_gated": True,
},
},
]
@@ -235,11 +227,7 @@ if SCHEDULED_EVAL_DATASET_NAMES:
)
# Add OpenSearch migration task if enabled.
if (
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
and not DISABLE_OPENSEARCH_MIGRATION_TASK
and not ONYX_DISABLE_VESPA
):
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX and not DISABLE_OPENSEARCH_MIGRATION_TASK:
beat_task_templates.append(
{
"name": "migrate-chunks-from-vespa-to-opensearch",
@@ -292,7 +280,7 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
cloud_task["kwargs"] = {}
cloud_task["kwargs"]["task_name"] = task["task"]
optional_fields = ["queue", "priority", "expires", "skip_gated", "work_gated"]
optional_fields = ["queue", "priority", "expires", "skip_gated"]
for field in optional_fields:
if field in task["options"]:
cloud_task["kwargs"][field] = task["options"][field]
@@ -385,14 +373,12 @@ if not MULTI_TENANT:
]
)
# `skip_gated` and `work_gated` are cloud-only hints consumed by
# `cloud_beat_task_generator`. Strip them before extending the self-hosted
# schedule so they don't leak into apply_async as unrecognised options on
# every fired task message.
# `skip_gated` is a cloud-only hint consumed by `cloud_beat_task_generator`. Strip
# it before extending the self-hosted schedule so it doesn't leak into apply_async
# as an unrecognised option on every fired task message.
for _template in beat_task_templates:
_self_hosted_template = copy.deepcopy(_template)
_self_hosted_template["options"].pop("skip_gated", None)
_self_hosted_template["options"].pop("work_gated", None)
tasks_to_schedule.append(_self_hosted_template)

View File

@@ -181,7 +181,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
# nearly every tenant in the active set since most have cc_pairs
# but almost none are actively being deleted on any given cycle.
if has_deleting_cc_pair:
maybe_mark_tenant_active(tenant_id, caller="connector_deletion")
maybe_mark_tenant_active(tenant_id)
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:

View File

@@ -37,7 +37,6 @@ from onyx.redis.redis_connector import RedisConnector
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import SENTRY_CELERY_TRACES_SAMPLE_RATE
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
@@ -141,7 +140,7 @@ def _docfetching_task(
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=SENTRY_CELERY_TRACES_SAMPLE_RATE,
traces_sample_rate=0.1,
release=__version__,
before_send=_add_instance_tags,
)

View File

@@ -1020,7 +1020,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# `tasks_created > 0` here gives us a "real work was done" signal
# rather than just "tenant has a cc_pair somewhere."
if tasks_created > 0:
maybe_mark_tenant_active(tenant_id, caller="check_for_indexing")
maybe_mark_tenant_active(tenant_id)
# 2/3: VALIDATE
# Check for inconsistent index attempts - active attempts without task IDs

View File

@@ -263,7 +263,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
# since most tenants have cc_pairs but almost none are due on
# any given cycle.
if prune_dispatched:
maybe_mark_tenant_active(tenant_id, caller="check_for_pruning")
maybe_mark_tenant_active(tenant_id)
r.set(OnyxRedisSignals.BLOCK_PRUNING, 1, ex=_get_pruning_block_expiration())
# we want to run this less frequently than the overall task

View File

@@ -153,7 +153,7 @@ def try_generate_stale_document_sync_tasks(
# Tenant-work-gating hook: refresh this tenant's active-set membership
# whenever vespa sync actually has stale docs to dispatch.
maybe_mark_tenant_active(tenant_id, caller="vespa_sync")
maybe_mark_tenant_active(tenant_id)
logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks in one batch."

View File

@@ -58,8 +58,6 @@ from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.file_store.staging import build_raw_file_callback
from onyx.file_store.staging import RawFileCallback
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
@@ -92,7 +90,6 @@ def _get_connector_runner(
end_time: datetime,
include_permissions: bool,
leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE,
raw_file_callback: RawFileCallback | None = None,
) -> ConnectorRunner:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
@@ -111,7 +108,6 @@ def _get_connector_runner(
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
raw_file_callback=raw_file_callback,
)
# validate the connector settings
@@ -279,12 +275,6 @@ def run_docfetching_entrypoint(
f"credentials='{credential_id}'"
)
raw_file_callback = build_raw_file_callback(
index_attempt_id=index_attempt_id,
cc_pair_id=connector_credential_pair_id,
tenant_id=tenant_id,
)
connector_document_extraction(
app,
index_attempt_id,
@@ -292,7 +282,6 @@ def run_docfetching_entrypoint(
attempt.search_settings_id,
tenant_id,
callback,
raw_file_callback=raw_file_callback,
)
logger.info(
@@ -312,7 +301,6 @@ def connector_document_extraction(
search_settings_id: int,
tenant_id: str,
callback: IndexingHeartbeatInterface | None = None,
raw_file_callback: RawFileCallback | None = None,
) -> None:
"""Extract documents from connector and queue them for indexing pipeline processing.
@@ -463,7 +451,6 @@ def connector_document_extraction(
start_time=window_start,
end_time=window_end,
include_permissions=should_fetch_permissions_during_indexing,
raw_file_callback=raw_file_callback,
)
# don't use a checkpoint if we're explicitly indexing from

View File

@@ -60,9 +60,7 @@ from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import SearchDoc
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
@@ -76,17 +74,12 @@ from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.projects import get_user_files_from_project
from onyx.db.search_settings import get_active_search_settings
from onyx.db.tools import get_tools
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_in_memory_chat_files
@@ -100,6 +93,7 @@ from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.multi_llm import LLMTimeoutError
from onyx.llm.override_models import LLMOverride
from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
@@ -129,7 +123,6 @@ from onyx.tools.tool_constructor import SearchToolConfig
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -252,106 +245,22 @@ def _empty_extracted_context_files() -> ExtractedContextFiles:
)
def _fetch_cached_image_captions(
user_file: UserFile | None,
document_index: DocumentIndex | None,
) -> list[str]:
"""Read image-caption chunks for a user file from the document index.
During indexing, embedded images are summarized via a vision LLM and
those summaries are stored as chunks whose `image_file_id` is set. Reading
them back at chat time avoids re-running vision-LLM calls per turn.
Returns an empty list if the index has no chunks yet (e.g. indexing is
still in flight) or on any fetch failure.
"""
if user_file is None or document_index is None:
return []
try:
chunks = document_index.id_based_retrieval(
chunk_requests=[VespaChunkRequest(document_id=str(user_file.id))],
filters=IndexFilters(
access_control_list=None,
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
),
)
except Exception:
logger.warning(
f"Failed to fetch cached captions for user_file {user_file.id}",
exc_info=True,
)
return []
# An image can be spread across multiple chunks; combine by image_file_id
# so a single caption appears once in the context.
combined: dict[str, list[str]] = {}
for chunk in chunks:
if chunk.image_file_id and chunk.content:
combined.setdefault(chunk.image_file_id, []).append(chunk.content)
return [
f"[Image — {image_file_id}]\n" + "\n".join(contents)
for image_file_id, contents in combined.items()
]
def _extract_text_from_in_memory_file(
f: InMemoryChatFile,
user_file: UserFile | None = None,
document_index: DocumentIndex | None = None,
) -> str | None:
def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None:
"""Extract text content from an InMemoryChatFile.
PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during
ingestion — decode directly.
DOC / CSV / other text types: the content is the original file bytes —
use extract_file_text which handles encoding detection and format parsing.
When image extraction is enabled and the file has embedded images, cached
captions are pulled from the document index and appended to the text.
The index fetch is skipped for files with no embedded images. We do not
re-summarize images inline here — this path is hot and the indexing
pipeline writes chunks atomically, so a missed caption means the file
is mid-indexing and will be picked up on the next turn.
"""
try:
if f.file_type == ChatFileType.PLAIN_TEXT:
return f.content.decode("utf-8", errors="ignore").replace("\x00", "")
filename = f.filename or ""
if not get_image_extraction_and_analysis_enabled():
return extract_file_text(
file=io.BytesIO(f.content),
file_name=filename,
break_on_unprocessable=False,
)
extraction = extract_text_and_images(
return extract_file_text(
file=io.BytesIO(f.content),
file_name=filename,
file_name=f.filename or "",
break_on_unprocessable=False,
)
text = extraction.text_content
has_text = bool(text.strip())
has_images = bool(extraction.embedded_images)
if not has_text and not has_images:
# extract_text_and_images has no is_text_file() fallback for
# unknown extensions (.py/.rs/.md without a dedicated handler).
# Defer to the legacy path so those files remain readable.
return extract_file_text(
file=io.BytesIO(f.content),
file_name=filename,
break_on_unprocessable=False,
)
if not has_images:
return text if has_text else None
cached_captions = _fetch_cached_image_captions(user_file, document_index)
parts: list[str] = []
if has_text:
parts.append(text)
parts.extend(cached_captions)
return "\n\n".join(parts).strip() or None
except Exception:
logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True)
return None
@@ -433,23 +342,6 @@ def extract_context_files(
db_session=db_session,
)
# The document index is used at chat time to read cached image captions
# (produced during indexing) so vision-LLM calls don't re-run per turn.
document_index: DocumentIndex | None = None
if not DISABLE_VECTOR_DB and get_image_extraction_and_analysis_enabled():
try:
active_search_settings = get_active_search_settings(db_session)
document_index = get_default_document_index(
search_settings=active_search_settings.primary,
secondary_search_settings=None,
db_session=db_session,
)
except Exception:
logger.warning(
"Failed to construct document index for caption lookup",
exc_info=True,
)
file_texts: list[str] = []
image_files: list[ChatLoadedFile] = []
file_metadata: list[ContextFileMetadata] = []
@@ -470,9 +362,7 @@ def extract_context_files(
continue
tool_metadata.append(_build_tool_metadata(uf))
elif f.file_type.is_text_file():
text_content = _extract_text_from_in_memory_file(
f, user_file=uf, document_index=document_index
)
text_content = _extract_text_from_in_memory_file(f)
if not text_content:
continue
if not uf:
@@ -1277,6 +1167,32 @@ def _run_models(
else:
if item is _MODEL_DONE:
models_remaining -= 1
elif isinstance(item, LLMTimeoutError):
model_llm = setup.llms[model_idx]
error_msg = (
"The LLM took too long to respond. "
"If you're running a local model, try increasing the "
"LLM_SOCKET_READ_TIMEOUT environment variable "
"(current default: 120 seconds)."
)
stack_trace = "".join(
traceback.format_exception(type(item), item, item.__traceback__)
)
if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
stack_trace = stack_trace.replace(
model_llm.config.api_key, "[REDACTED_API_KEY]"
)
yield StreamingError(
error=error_msg,
stack_trace=stack_trace,
error_code="CONNECTION_ERROR",
is_retryable=True,
details={
"model": model_llm.config.model_name,
"provider": model_llm.config.model_provider,
"model_index": model_idx,
},
)
elif isinstance(item, Exception):
# Yield a tagged error for this model but keep the other models running.
# Do NOT decrement models_remaining — _run_model's finally always posts

View File

@@ -282,7 +282,6 @@ OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
)
OPENSEARCH_USE_SSL = os.environ.get("OPENSEARCH_USE_SSL", "true").lower() == "true"
USING_AWS_MANAGED_OPENSEARCH = (
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
@@ -328,7 +327,6 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
DISABLE_OPENSEARCH_MIGRATION_TASK = (
os.environ.get("DISABLE_OPENSEARCH_MIGRATION_TASK", "").lower() == "true"
)
ONYX_DISABLE_VESPA = os.environ.get("ONYX_DISABLE_VESPA", "").lower() == "true"
# Whether we should check for and create an index if necessary every time we
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
@@ -868,15 +866,6 @@ MAX_EMBEDDED_IMAGES_PER_UPLOAD = max(
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_UPLOAD") or 1000)
)
# Maximum non-empty cells to extract from a single xlsx worksheet. Protects
# from OOM on honestly-huge spreadsheets: memory cost in the extractor is
# roughly proportional to this count. Once exceeded, the scan stops and a
# truncation marker row is appended to the sheet's CSV.
# Peak Memory ~= 100 B * MAX_CELLS
MAX_XLSX_CELLS_PER_SHEET = max(
0, int(os.environ.get("MAX_XLSX_CELLS_PER_SHEET") or 10_000_000)
)
# Use document summary for contextual rag
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
# Use chunk summary for contextual rag

View File

@@ -372,7 +372,6 @@ class FileOrigin(str, Enum):
CONNECTOR_METADATA = "connector_metadata"
GENERATED_REPORT = "generated_report"
INDEXING_CHECKPOINT = "indexing_checkpoint"
INDEXING_STAGING = "indexing_staging"
PLAINTEXT_CACHE = "plaintext_cache"
OTHER = "other"
QUERY_HISTORY_CSV = "query_history_csv"

View File

@@ -3,14 +3,7 @@ from onyx.server.settings.store import load_settings
def get_image_extraction_and_analysis_enabled() -> bool:
"""Return the workspace setting for image extraction/analysis.
The pydantic `Settings` model defaults this field to True, so production
tenants get the feature on by default on first read. The fallback here
stays False so environments where settings cannot be loaded at all
(e.g. unit tests with no DB/Redis) don't trigger downstream vision-LLM
code paths that assume the DB is reachable.
"""
"""Get image extraction and analysis enabled setting from workspace settings or fallback to False"""
try:
settings = load_settings()
if settings.image_extraction_and_analysis_enabled is not None:

View File

@@ -22,7 +22,6 @@ from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.enums import AccessType
from onyx.db.models import Credential
from onyx.file_store.staging import RawFileCallback
from shared_configs.contextvars import get_current_tenant_id
@@ -108,7 +107,6 @@ def instantiate_connector(
input_type: InputType,
connector_specific_config: dict[str, Any],
credential: Credential,
raw_file_callback: RawFileCallback | None = None,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)
@@ -132,9 +130,6 @@ def instantiate_connector(
connector.set_allow_images(get_image_extraction_and_analysis_enabled())
if raw_file_callback is not None:
connector.set_raw_file_callback(raw_file_callback)
return connector

View File

@@ -40,22 +40,6 @@ class GongConnectorCheckpoint(ConnectorCheckpoint):
cursor: str | None = None
# Cached time range — computed once, reused across checkpoint calls
time_range: tuple[str, str] | None = None
# Transcripts whose call details were not yet available from /v2/calls/extensive
# (Gong has a known race where transcript call IDs take time to propagate).
# Keyed by call_id. Retried on subsequent checkpoint invocations.
#
# Invariant: all entries share one resolution session — they're stashed
# together from a single page and share the attempt counter and retry
# deadline. load_from_checkpoint only fetches a new page when this dict
# is empty, so entries from different pages can't mix.
pending_transcripts: dict[str, dict[str, Any]] = {}
# Number of resolution attempts made for pending_transcripts so far.
pending_call_details_attempts: int = 0
# Unix timestamp before which we should not retry pending_transcripts.
# Enforces exponential backoff independent of worker cadence — Gong's
# transcript-ID propagation race can take tens of seconds to minutes,
# longer than typical worker reinvocation intervals.
pending_retry_after: float | None = None
class _TranscriptPage(BaseModel):
@@ -78,15 +62,8 @@ class _CursorExpiredError(Exception):
class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
BASE_URL = "https://api.gong.io"
# Max number of attempts to resolve missing call details across checkpoint
# invocations before giving up and emitting ConnectorFailure.
MAX_CALL_DETAILS_ATTEMPTS = 6
# Base delay for exponential backoff between pending-transcript retry
# attempts. Delay before attempt N (N >= 2) is CALL_DETAILS_DELAY * 2^(N-2)
# seconds (30, 60, 120, 240, 480 = ~15.5min total) — matching the original
# blocking-retry schedule, but enforced via checkpoint deadline rather
# than in-call time.sleep.
CALL_DETAILS_DELAY = 30
CALL_DETAILS_DELAY = 30 # in seconds
# Gong API limit is 3 calls/sec — stay safely under it
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
@@ -210,6 +187,50 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
return call_to_metadata
def _fetch_call_details_with_retry(self, call_ids: list[str]) -> dict[str, Any]:
"""Fetch call details with retry for the Gong API race condition.
The Gong API has a known race where transcript call IDs don't immediately
appear in /v2/calls/extensive. Retries with exponential backoff, only
re-requesting the missing IDs on each attempt.
"""
call_details_map = self._get_call_details_by_ids(call_ids)
if set(call_ids) == set(call_details_map.keys()):
return call_details_map
for attempt in range(2, self.MAX_CALL_DETAILS_ATTEMPTS + 1):
missing_ids = list(set(call_ids) - set(call_details_map.keys()))
logger.warning(
f"_get_call_details_by_ids is missing call id's: current_attempt={attempt - 1} missing_call_ids={missing_ids}"
)
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, attempt - 2)
logger.warning(
f"_get_call_details_by_ids waiting to retry: "
f"wait={wait_seconds}s "
f"current_attempt={attempt - 1} "
f"next_attempt={attempt} "
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
)
time.sleep(wait_seconds)
# Only re-fetch the missing IDs, merge into existing results
new_details = self._get_call_details_by_ids(missing_ids)
call_details_map.update(new_details)
if set(call_ids) == set(call_details_map.keys()):
return call_details_map
missing_ids = list(set(call_ids) - set(call_details_map.keys()))
logger.error(
f"Giving up on missing call id's after "
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
f"missing_call_ids={missing_ids}"
f"proceeding with {len(call_details_map)} of "
f"{len(call_ids)} calls"
)
return call_details_map
@staticmethod
def _parse_parties(parties: list[dict]) -> dict[str, str]:
id_mapping = {}
@@ -292,119 +313,87 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
return start_time, end_time
def _build_document(
self,
transcript: dict[str, Any],
call_details: dict[str, Any],
) -> Document:
"""Build a single Document from a transcript and its resolved call details."""
call_id = transcript["callId"]
call_metadata = call_details["metaData"]
call_time_str = call_metadata["started"]
call_title = call_metadata["title"]
logger.info(
f"Indexing Gong call id {call_id} from {call_time_str.split('T', 1)[0]}: {call_title}"
)
call_parties = cast(list[dict] | None, call_details.get("parties"))
if call_parties is None:
logger.error(f"Couldn't get parties for Call ID: {call_id}")
call_parties = []
id_to_name_map = self._parse_parties(call_parties)
speaker_to_name: dict[str, str] = {}
transcript_text = ""
call_purpose = call_metadata["purpose"]
if call_purpose:
transcript_text += f"Call Description: {call_purpose}\n\n"
contents = transcript["transcript"]
for segment in contents:
speaker_id = segment.get("speakerId", "")
if speaker_id not in speaker_to_name:
if self.hide_user_info:
speaker_to_name[speaker_id] = f"User {len(speaker_to_name) + 1}"
else:
speaker_to_name[speaker_id] = id_to_name_map.get(
speaker_id, "Unknown"
)
speaker_name = speaker_to_name[speaker_id]
sentences = segment.get("sentences", {})
monolog = " ".join([sentence.get("text", "") for sentence in sentences])
transcript_text += f"{speaker_name}: {monolog}\n\n"
return Document(
id=call_id,
sections=[TextSection(link=call_metadata["url"], text=transcript_text)],
source=DocumentSource.GONG,
semantic_identifier=call_title or "Untitled",
doc_updated_at=datetime.fromisoformat(call_time_str).astimezone(
timezone.utc
),
metadata={"client": call_metadata.get("system")},
)
def _process_transcripts(
self,
transcripts: list[dict[str, Any]],
checkpoint: GongConnectorCheckpoint,
) -> Generator[Document | ConnectorFailure, None, None]:
"""Fetch call details for a page of transcripts and yield resulting
Documents. Transcripts whose call details are missing (Gong race
condition) are stashed into `checkpoint.pending_transcripts` for retry
on a future checkpoint invocation rather than blocking here.
"""
"""Process a batch of transcripts into Documents or ConnectorFailures."""
transcript_call_ids = cast(
list[str],
[t.get("callId") for t in transcripts if t.get("callId")],
)
call_details_map = (
self._get_call_details_by_ids(transcript_call_ids)
if transcript_call_ids
else {}
)
newly_stashed: list[str] = []
call_details_map = self._fetch_call_details_with_retry(transcript_call_ids)
for transcript in transcripts:
call_id = transcript.get("callId")
if not call_id:
logger.error(
"Couldn't get call information for transcript missing callId"
)
if not call_id or call_id not in call_details_map:
logger.error(f"Couldn't get call information for Call ID: {call_id}")
if call_id:
logger.error(
f"Call debug info: call_id={call_id} "
f"call_ids={transcript_call_ids} "
f"call_details_map={call_details_map.keys()}"
)
yield ConnectorFailure(
failed_document=DocumentFailure(document_id="unknown"),
failure_message="Transcript missing callId",
failed_document=DocumentFailure(
document_id=call_id or "unknown",
),
failure_message=f"Couldn't get call information for Call ID: {call_id}",
)
continue
if call_id in call_details_map:
yield self._build_document(transcript, call_details_map[call_id])
continue
call_details = call_details_map[call_id]
call_metadata = call_details["metaData"]
# Details not available yet — stash for retry on next invocation.
checkpoint.pending_transcripts[call_id] = transcript
newly_stashed.append(call_id)
if newly_stashed:
logger.warning(
f"Gong call details not yet available (race condition); "
f"deferring to next checkpoint invocation: "
f"call_ids={newly_stashed}"
call_time_str = call_metadata["started"]
call_title = call_metadata["title"]
logger.info(
f"Indexing Gong call id {call_id} from {call_time_str.split('T', 1)[0]}: {call_title}"
)
call_parties = cast(list[dict] | None, call_details.get("parties"))
if call_parties is None:
logger.error(f"Couldn't get parties for Call ID: {call_id}")
call_parties = []
id_to_name_map = self._parse_parties(call_parties)
speaker_to_name: dict[str, str] = {}
transcript_text = ""
call_purpose = call_metadata["purpose"]
if call_purpose:
transcript_text += f"Call Description: {call_purpose}\n\n"
contents = transcript["transcript"]
for segment in contents:
speaker_id = segment.get("speakerId", "")
if speaker_id not in speaker_to_name:
if self.hide_user_info:
speaker_to_name[speaker_id] = f"User {len(speaker_to_name) + 1}"
else:
speaker_to_name[speaker_id] = id_to_name_map.get(
speaker_id, "Unknown"
)
speaker_name = speaker_to_name[speaker_id]
sentences = segment.get("sentences", {})
monolog = " ".join([sentence.get("text", "") for sentence in sentences])
transcript_text += f"{speaker_name}: {monolog}\n\n"
yield Document(
id=call_id,
sections=[TextSection(link=call_metadata["url"], text=transcript_text)],
source=DocumentSource.GONG,
semantic_identifier=call_title or "Untitled",
doc_updated_at=datetime.fromisoformat(call_time_str).astimezone(
timezone.utc
),
metadata={"client": call_metadata.get("system")},
)
# First attempt on any newly-stashed transcripts counts as attempt #1.
# pending_call_details_attempts is guaranteed 0 here because
# load_from_checkpoint only reaches _process_transcripts when
# pending_transcripts was empty at entry (see early-return above).
checkpoint.pending_call_details_attempts = 1
checkpoint.pending_retry_after = time.time() + self._next_retry_delay(1)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
combined = (
@@ -443,18 +432,6 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
checkpoint.has_more = True
return checkpoint
# Step 2: Resolve any transcripts stashed by a prior invocation whose
# call details were missing due to Gong's propagation race. Worker
# cadence between checkpoint calls provides the spacing between retry
# attempts — no in-call sleep needed.
if checkpoint.pending_transcripts:
yield from self._resolve_pending_transcripts(checkpoint)
# If pending still exists and we haven't exhausted attempts, defer
# the rest of this invocation — _resolve_pending_transcripts set
# has_more=True for us.
if checkpoint.pending_transcripts:
return checkpoint
workspace_ids = checkpoint.workspace_ids
# If we've exhausted all workspaces, we're done
@@ -473,7 +450,7 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
workspace_id = workspace_ids[checkpoint.workspace_index]
# Step 3: Fetch one page of transcripts
# Step 2: Fetch one page of transcripts
try:
page = self._fetch_transcript_page(
start_datetime=start_time,
@@ -496,102 +473,23 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
checkpoint.has_more = True
return checkpoint
# Step 4: Process transcripts into documents. Missing-details
# transcripts get stashed into checkpoint.pending_transcripts.
# Step 3: Process transcripts into documents
if page.transcripts:
yield from self._process_transcripts(page.transcripts, checkpoint)
yield from self._process_transcripts(page.transcripts)
# Step 5: Update cursor/workspace state
# Step 4: Update checkpoint state
if page.next_cursor:
# More pages in this workspace
checkpoint.cursor = page.next_cursor
checkpoint.has_more = True
else:
# This workspace is exhausted — advance to next
checkpoint.workspace_index += 1
checkpoint.cursor = None
checkpoint.has_more = checkpoint.workspace_index < len(workspace_ids)
# If pending transcripts were stashed this invocation, we still have
# work to do on a future invocation even if pagination is exhausted.
if checkpoint.pending_transcripts:
checkpoint.has_more = True
return checkpoint
def _next_retry_delay(self, attempts_done: int) -> float:
"""Seconds to wait before attempt #(attempts_done + 1).
Matches the original exponential backoff: 30, 60, 120, 240, 480.
"""
return self.CALL_DETAILS_DELAY * pow(2, attempts_done - 1)
def _resolve_pending_transcripts(
self,
checkpoint: GongConnectorCheckpoint,
) -> Generator[Document | ConnectorFailure, None, None]:
"""Attempt to resolve transcripts whose call details were unavailable
in a prior invocation. Mutates checkpoint in place: resolved transcripts
are removed from pending_transcripts; on attempt exhaustion, emits
ConnectorFailure for each unresolved call_id and clears pending state.
If the backoff deadline hasn't elapsed yet, returns without issuing
any API call so the next invocation can try again later.
"""
if (
checkpoint.pending_retry_after is not None
and time.time() < checkpoint.pending_retry_after
):
# Backoff still in effect — defer to a later invocation without
# burning an attempt or an API call.
checkpoint.has_more = True
return
pending_call_ids = list(checkpoint.pending_transcripts.keys())
resolved = self._get_call_details_by_ids(pending_call_ids)
for call_id, details in resolved.items():
transcript = checkpoint.pending_transcripts.pop(call_id, None)
if transcript is None:
continue
yield self._build_document(transcript, details)
if not checkpoint.pending_transcripts:
checkpoint.pending_call_details_attempts = 0
checkpoint.pending_retry_after = None
return
checkpoint.pending_call_details_attempts += 1
logger.warning(
f"Gong call details still missing after "
f"{checkpoint.pending_call_details_attempts}/"
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
f"missing_call_ids={list(checkpoint.pending_transcripts.keys())}"
)
if checkpoint.pending_call_details_attempts >= self.MAX_CALL_DETAILS_ATTEMPTS:
logger.error(
f"Giving up on missing Gong call details after "
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
f"missing_call_ids={list(checkpoint.pending_transcripts.keys())}"
)
for call_id in list(checkpoint.pending_transcripts.keys()):
yield ConnectorFailure(
failed_document=DocumentFailure(document_id=call_id),
failure_message=(
f"Couldn't get call details after {self.MAX_CALL_DETAILS_ATTEMPTS} attempts for Call ID: {call_id}"
),
)
checkpoint.pending_transcripts = {}
checkpoint.pending_call_details_attempts = 0
checkpoint.pending_retry_after = None
# has_more is recomputed by the workspace iteration that follows;
# reset to False here so a stale True from a prior invocation
# can't leak out via any future early-return path.
checkpoint.has_more = False
else:
checkpoint.pending_retry_after = time.time() + self._next_retry_delay(
checkpoint.pending_call_details_attempts
)
checkpoint.has_more = True
if __name__ == "__main__":
import os

View File

@@ -578,18 +578,8 @@ class GoogleDriveConnector(
current_id, file.user_email, field_type, failed_folder_ids_by_email
)
if not folder:
# Can't access this folder - stop climbing.
# If the terminal node is a confirmed orphan, backfill all
# intermediate folders into failed_folder_ids_by_email so
# future files short-circuit via _get_folder_metadata's
# cache check instead of re-climbing the whole chain.
if failed_folder_ids_by_email is not None:
for email in {file.user_email, self.primary_admin_email}:
email_failed_ids = failed_folder_ids_by_email.get(email)
if email_failed_ids and current_id in email_failed_ids:
failed_folder_ids_by_email.setdefault(
email, ThreadSafeSet()
).update(set(node_ids_in_walk))
# Can't access this folder - stop climbing
# Don't mark as fully walked since we didn't reach root
break
folder_parent_id = _get_parent_id_from_file(folder)

View File

@@ -379,20 +379,10 @@ def _download_and_extract_sections_basic(
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
or is_tabular_file(file_name)
):
# Google Drive doesn't enforce file extensions, so the filename may not
# end in .xlsx even when the mime type says it's one. Synthesize the
# extension so tabular_file_to_sections dispatches correctly.
tabular_file_name = file_name
if (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
and not is_tabular_file(file_name)
):
tabular_file_name = f"{file_name}.xlsx"
return list(
tabular_file_to_sections(
io.BytesIO(response_call()),
file_name=tabular_file_name,
file_name=file_name,
link=link,
)
)

View File

@@ -167,12 +167,9 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
default_factory=ThreadSafeSet
)
# Maps email → set of folder IDs that email should skip when walking the
# parent chain. Covers two cases:
# 1. Folders where that email confirmed no accessible parent (true orphans).
# 2. Intermediate folders on a path that dead-ended at a confirmed orphan —
# backfilled so future walks short-circuit earlier in the chain.
# In both cases _get_folder_metadata skips the API call and returns None.
# Maps email → set of IDs of folders where that email confirmed no accessible parent.
# Avoids redundant API calls when the same (folder, email) pair is
# encountered again within the same retrieval run.
failed_folder_ids_by_email: ThreadSafeDict[str, ThreadSafeSet[str]] = Field(
default_factory=ThreadSafeDict
)

View File

@@ -15,7 +15,6 @@ from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.file_store.staging import RawFileCallback
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -43,9 +42,6 @@ class NormalizationResult(BaseModel):
class BaseConnector(abc.ABC, Generic[CT]):
REDIS_KEY_PREFIX = "da_connector_data:"
# Optional raw-file persistence hook to save original file
raw_file_callback: RawFileCallback | None = None
@abc.abstractmethod
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError
@@ -92,15 +88,6 @@ class BaseConnector(abc.ABC, Generic[CT]):
"""Implement if the underlying connector wants to skip/allow image downloading
based on the application level image analysis setting."""
def set_raw_file_callback(self, callback: RawFileCallback) -> None:
"""Inject the per-attempt raw-file persistence callback.
Wired up by the docfetching entrypoint via `instantiate_connector`.
Connectors that don't care about persisting raw bytes can ignore this
— `raw_file_callback` simply stays `None`.
"""
self.raw_file_callback = callback
@classmethod
def normalize_url(cls, url: str) -> "NormalizationResult": # noqa: ARG003
"""Normalize a URL to match the canonical Document.id format used during ingestion.

View File

@@ -62,19 +62,17 @@ def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will e.g. skip lists!
"""
texts: list[str] = []
def _extract(node: dict) -> None:
if node.get("type") == "text":
text = node.get("text", "")
if text:
texts.append(text)
for child in node.get("content", []):
_extract(child)
if adf is not None:
_extract(adf)
# TODO: complete this function
texts = []
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
texts.append(item["text"])
return " ".join(texts)

View File

@@ -231,8 +231,6 @@ class DocumentBase(BaseModel):
# Set during docfetching after hierarchy nodes are cached
parent_hierarchy_node_id: int | None = None
file_id: str | None = None
def get_title_for_document_index(
self,
) -> str | None:
@@ -372,7 +370,6 @@ class Document(DocumentBase):
secondary_owners=base.secondary_owners,
title=base.title,
from_ingestion_api=base.from_ingestion_api,
file_id=base.file_id,
)
def __sizeof__(self) -> int:

View File

@@ -75,8 +75,6 @@ from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
from onyx.utils.url import SSRFException
from onyx.utils.url import validate_outbound_http_url
logger = setup_logger()
SLIM_BATCH_SIZE = 1000
@@ -983,42 +981,6 @@ class SharepointConnector(
raise ConnectorValidationError(
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
)
try:
validate_outbound_http_url(site_url, https_only=True)
except (SSRFException, ValueError) as e:
raise ConnectorValidationError(
f"Invalid site URL '{site_url}': {e}"
) from e
# Probe RoleAssignments permission — required for permission sync.
# Only runs when credentials have been loaded.
if self.msal_app and self.sp_tenant_domain and self.sites:
try:
token_response = acquire_token_for_rest(
self.msal_app,
self.sp_tenant_domain,
self.sharepoint_domain_suffix,
)
probe_url = (
f"{self.sites[0].rstrip('/')}/_api/web/roleassignments?$top=1"
)
resp = requests.get(
probe_url,
headers={"Authorization": f"Bearer {token_response.accessToken}"},
timeout=10,
)
if resp.status_code in (401, 403):
raise ConnectorValidationError(
"The Azure AD app registration is missing the required SharePoint permission "
"to read role assignments. Please grant 'Sites.FullControl.All' "
"(application permission) in the Azure portal and re-run admin consent."
)
except ConnectorValidationError:
raise
except Exception as e:
logger.warning(
f"RoleAssignments permission probe failed (non-blocking): {e}"
)
def _extract_tenant_domain_from_sites(self) -> str | None:
"""Extract the tenant domain from configured site URLs.
@@ -1914,22 +1876,16 @@ class SharepointConnector(
logger.debug(
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
)
try:
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_sitepage_to_slim_document(
site_page,
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=site_descriptor.url,
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
)
)
except Exception as e:
logger.warning(
f"Failed to process site page "
f"{site_page.get('webUrl', site_page.get('name', 'Unknown'))}: {e}"
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_sitepage_to_slim_document(
site_page,
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=site_descriptor.url,
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
)
)
if len(doc_batch) >= SLIM_BATCH_SIZE:
yield doc_batch
doc_batch = []
@@ -2002,7 +1958,8 @@ class SharepointConnector(
self._graph_client = GraphClient(
_acquire_token_for_graph, environment=self._azure_environment
)
self.sp_tenant_domain = self._resolve_tenant_domain()
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
self.sp_tenant_domain = self._resolve_tenant_domain()
return None
def _get_drive_names_for_site(self, site_url: str) -> list[str]:

View File

@@ -19,7 +19,6 @@ from playwright.sync_api import Playwright
from playwright.sync_api import sync_playwright
from playwright.sync_api import TimeoutError
from requests_oauthlib import OAuth2Session
from typing_extensions import override
from urllib3.exceptions import MaxRetryError
from onyx.configs.app_configs import INDEX_BATCH_SIZE
@@ -33,16 +32,11 @@ from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.sitemap import list_pages_for_site
from onyx.utils.web_content import extract_pdf_text
@@ -61,6 +55,8 @@ class ScrapeSessionContext:
self.visited_links: set[str] = set()
self.content_hashes: set[int] = set()
self.doc_batch: list[Document | HierarchyNode] = []
self.at_least_one_doc: bool = False
self.last_error: str | None = None
self.needs_retry: bool = False
@@ -442,7 +438,7 @@ def _handle_cookies(context: BrowserContext, url: str) -> None:
)
class WebConnector(LoadConnector, SlimConnector):
class WebConnector(LoadConnector):
MAX_RETRIES = 3
def __init__(
@@ -497,14 +493,8 @@ class WebConnector(LoadConnector, SlimConnector):
index: int,
initial_url: str,
session_ctx: ScrapeSessionContext,
slim: bool = False,
) -> ScrapeResult:
"""Returns a ScrapeResult object with a doc and retry flag.
When slim=True, skips scroll, PDF content download, and content extraction.
The bot-detection render wait (5s) fires on CF/403 responses regardless of slim.
networkidle is always awaited so JS-rendered links are discovered correctly.
"""
"""Returns a ScrapeResult object with a doc and retry flag."""
if session_ctx.playwright is None:
raise RuntimeError("scrape_context.playwright is None")
@@ -525,16 +515,7 @@ class WebConnector(LoadConnector, SlimConnector):
is_pdf = is_pdf_resource(initial_url, content_type)
if is_pdf:
if slim:
result.doc = Document(
id=initial_url,
sections=[],
source=DocumentSource.WEB,
semantic_identifier=initial_url,
metadata={},
)
return result
# PDF files are not checked for links
response = requests.get(initial_url, headers=DEFAULT_HEADERS)
page_text, metadata = extract_pdf_text(response.content)
last_modified = response.headers.get("Last-Modified")
@@ -565,20 +546,14 @@ class WebConnector(LoadConnector, SlimConnector):
timeout=30000, # 30 seconds
wait_until="commit", # Wait for navigation to commit
)
# Give the page a moment to start rendering after navigation commits.
# Allows CloudFlare and other bot-detection challenges to complete.
page.wait_for_timeout(PAGE_RENDER_TIMEOUT_MS)
# Bot-detection JS challenges (CloudFlare, Imperva, etc.) need a moment
# to start network activity after commit before networkidle is meaningful.
# We detect this via the cf-ray header (CloudFlare) or a 403 response,
# which is the common entry point for JS-challenge-based bot detection.
is_bot_challenge = page_response is not None and (
page_response.header_value("cf-ray") is not None
or page_response.status == 403
)
if is_bot_challenge:
page.wait_for_timeout(PAGE_RENDER_TIMEOUT_MS)
# Wait for network activity to settle (handles SPAs, CF challenges, etc.)
# Wait for network activity to settle so SPAs that fetch content
# asynchronously after the initial JS bundle have time to render.
try:
# A bit of extra time to account for long-polling, websockets, etc.
page.wait_for_load_state("networkidle", timeout=PAGE_RENDER_TIMEOUT_MS)
except TimeoutError:
pass
@@ -601,7 +576,7 @@ class WebConnector(LoadConnector, SlimConnector):
session_ctx.visited_links.add(initial_url)
# If we got here, the request was successful
if not slim and self.scroll_before_scraping:
if self.scroll_before_scraping:
scroll_attempts = 0
previous_height = page.evaluate("document.body.scrollHeight")
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
@@ -640,16 +615,6 @@ class WebConnector(LoadConnector, SlimConnector):
result.retry = True
return result
if slim:
result.doc = Document(
id=initial_url,
sections=[],
source=DocumentSource.WEB,
semantic_identifier=initial_url,
metadata={},
)
return result
# after this point, we don't need the caller to retry
parsed_html = web_html_cleanup(soup, self.mintlify_cleanup)
@@ -701,13 +666,9 @@ class WebConnector(LoadConnector, SlimConnector):
return result
def load_from_state(self, slim: bool = False) -> GenerateDocumentsOutput:
"""Traverses through all pages found on the website and converts them into
documents.
When slim=True, yields SlimDocument objects (URL id only, no content).
Playwright is used in all modes — slim skips content extraction only.
"""
def load_from_state(self) -> GenerateDocumentsOutput:
"""Traverses through all pages found on the website
and converts them into documents"""
if not self.to_visit_list:
raise ValueError("No URLs to visit")
@@ -718,8 +679,6 @@ class WebConnector(LoadConnector, SlimConnector):
session_ctx = ScrapeSessionContext(base_url, self.to_visit_list)
session_ctx.initialize()
batch: list[Document | SlimDocument | HierarchyNode] = []
while session_ctx.to_visit:
initial_url = session_ctx.to_visit.pop()
if initial_url in session_ctx.visited_links:
@@ -734,9 +693,7 @@ class WebConnector(LoadConnector, SlimConnector):
continue
index = len(session_ctx.visited_links)
logger.info(
f"{index}: {'Slim-visiting' if slim else 'Visiting'} {initial_url}"
)
logger.info(f"{index}: Visiting {initial_url}")
# Add retry mechanism with exponential backoff
retry_count = 0
@@ -751,14 +708,12 @@ class WebConnector(LoadConnector, SlimConnector):
time.sleep(delay)
try:
result = self._do_scrape(index, initial_url, session_ctx, slim=slim)
result = self._do_scrape(index, initial_url, session_ctx)
if result.retry:
continue
if result.doc:
batch.append(
SlimDocument(id=result.doc.id) if slim else result.doc
)
session_ctx.doc_batch.append(result.doc)
except Exception as e:
session_ctx.last_error = f"Failed to fetch '{initial_url}': {e}"
logger.exception(session_ctx.last_error)
@@ -769,16 +724,16 @@ class WebConnector(LoadConnector, SlimConnector):
break # success / don't retry
if len(batch) >= self.batch_size:
if len(session_ctx.doc_batch) >= self.batch_size:
session_ctx.initialize()
session_ctx.at_least_one_doc = True
yield batch # ty: ignore[invalid-yield]
batch = []
yield session_ctx.doc_batch
session_ctx.doc_batch = []
if batch:
if session_ctx.doc_batch:
session_ctx.stop()
session_ctx.at_least_one_doc = True
yield batch # ty: ignore[invalid-yield]
yield session_ctx.doc_batch
if not session_ctx.at_least_one_doc:
if session_ctx.last_error:
@@ -787,22 +742,6 @@ class WebConnector(LoadConnector, SlimConnector):
session_ctx.stop()
@override
def retrieve_all_slim_docs(
    self,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
    """Yield SlimDocument batches for every page reachable from the configured URLs.

    This is a thin wrapper that delegates to ``load_from_state(slim=True)``,
    i.e. the regular crawl run in slim mode.

    Args:
        start: Not used by this method.
        end: Not used by this method.
        callback: Not used by this method.
    """
    # The slim crawl is implemented entirely inside load_from_state.
    yield from self.load_from_state(slim=True)  # ty: ignore[invalid-yield]
def validate_connector_settings(self) -> None:
# Make sure we have at least one valid URL to check
if not self.to_visit_list:

View File

@@ -244,21 +244,13 @@ def fetch_latest_index_attempts_by_status(
return query.all()
_INTERNAL_ONLY_SOURCES = {
# Used by the ingestion API, not a user-created connector.
DocumentSource.INGESTION_API,
# Backs the user library / build feature, not a connector users filter by.
DocumentSource.CRAFT_FILE,
}
def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
distinct_sources = db_session.query(Connector.source).distinct().all()
sources = [
source[0]
for source in distinct_sources
if source[0] not in _INTERNAL_ONLY_SOURCES
if source[0] != DocumentSource.INGESTION_API
]
return sources

View File

@@ -696,7 +696,6 @@ def upsert_documents(
else {}
),
doc_metadata=doc.doc_metadata,
file_id=doc.file_id,
)
)
for doc in seen_documents.values()
@@ -713,7 +712,6 @@ def upsert_documents(
"secondary_owners": insert_stmt.excluded.secondary_owners,
"doc_metadata": insert_stmt.excluded.doc_metadata,
"parent_hierarchy_node_id": insert_stmt.excluded.parent_hierarchy_node_id,
"file_id": insert_stmt.excluded.file_id,
}
if includes_permissions:
# Use COALESCE to preserve existing permissions when new values are NULL.

View File

@@ -62,21 +62,6 @@ def delete_filerecord_by_file_id(
db_session.query(FileRecord).filter_by(file_id=file_id).delete()
def update_filerecord_origin(
    file_id: str,
    from_origin: FileOrigin,
    to_origin: FileOrigin,
    db_session: Session,
) -> None:
    """Move a file record from one `file_origin` to another.

    The UPDATE only matches rows whose origin is still `from_origin`, so
    calling this again after a successful update matches nothing and is a
    no-op. This function does not commit; the caller owns the transaction.

    Args:
        file_id: Identifier of the file record to update.
        from_origin: Origin the record must currently have to be updated.
        to_origin: Origin to assign to the matching record.
        db_session: Session the UPDATE is issued on.
    """
    # Select only the record that is still in the expected starting origin.
    records_in_old_origin = db_session.query(FileRecord).filter(
        FileRecord.file_id == file_id,
        FileRecord.file_origin == from_origin,
    )
    records_in_old_origin.update({FileRecord.file_origin: to_origin})
def upsert_filerecord(
file_id: str,
display_name: str,

View File

@@ -952,7 +952,6 @@ class Document(Base):
semantic_id: Mapped[str] = mapped_column(NullFilteredString)
# First Section's link
link: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
file_id: Mapped[str | None] = mapped_column(String, nullable=True)
# The updated time is also used as a measure of the last successful state of the doc
# pulled from the source (to help skip reindexing already updated docs in case of
@@ -1055,9 +1054,9 @@ class Document(Base):
__table_args__ = (
Index(
"ix_document_needs_sync",
"id",
postgresql_where=text("last_modified > last_synced OR last_synced IS NULL"),
"ix_document_sync_status",
last_modified,
last_synced,
),
)

View File

@@ -20,7 +20,6 @@ from onyx.background.celery.tasks.opensearch_migration.constants import (
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
)
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
from onyx.db.enums import OpenSearchDocumentMigrationStatus
from onyx.db.models import Document
from onyx.db.models import OpenSearchDocumentMigrationRecord
@@ -413,11 +412,7 @@ def get_opensearch_retrieval_state(
If the tenant migration record is not found, defaults to
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX.
If ONYX_DISABLE_VESPA is True, always returns True.
"""
if ONYX_DISABLE_VESPA:
return True
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
return ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX

View File

@@ -129,13 +129,9 @@ def update_sync_record_status(
"""
sync_record = fetch_latest_sync_record(db_session, entity_id, sync_type)
if sync_record is None:
logger.warning(
f"No sync record found for entity_id={entity_id} "
f"sync_type={sync_type} — skipping status update. "
f"This typically means the record was never created "
f"(insert_sync_record failed silently) or was cleaned up."
raise ValueError(
f"No sync record found for entity_id={entity_id} sync_type={sync_type}"
)
return
sync_record.sync_status = sync_status
if num_docs_synced is not None:

View File

@@ -2,17 +2,11 @@ import datetime
from uuid import UUID
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import ChatSessionSharedStatus
from onyx.db.models import FileRecord
from onyx.db.models import Persona
from onyx.db.models import Project__UserFile
from onyx.db.models import UserFile
@@ -114,73 +108,13 @@ def update_last_accessed_at_for_user_files(
db_session.commit()
def get_file_id_by_user_file_id(
user_file_id: str, user_id: UUID, db_session: Session
) -> str | None:
user_file = (
db_session.query(UserFile)
.filter(UserFile.id == user_file_id, UserFile.user_id == user_id)
.first()
)
def get_file_id_by_user_file_id(user_file_id: str, db_session: Session) -> str | None:
user_file = db_session.query(UserFile).filter(UserFile.id == user_file_id).first()
if user_file:
return user_file.file_id
return None
def user_can_access_chat_file(file_id: str, user_id: UUID, db_session: Session) -> bool:
    """Decide whether `user_id` may read the raw `file_id` served by
    `GET /chat/file/{file_id}`.

    The checks run in order and the first match grants access:

    1. `file_id` is the storage id of a `UserFile` owned by the user.
    2. `file_id` is a persona avatar (`Persona.uploaded_image_id`) whose file
       record has origin `CHAT_UPLOAD`; this check does not filter on the user.
    3. `file_id` appears in a `ChatMessage.files` descriptor of a chat session
       that the user owns or that is shared via
       `ChatSessionSharedStatus.PUBLIC`.

    Returns:
        True if any of the checks above matches, False otherwise.
    """
    # 1) Direct ownership of the user file.
    owned_user_file = (
        select(UserFile.id)
        .where(UserFile.file_id == file_id, UserFile.user_id == user_id)
        .exists()
    )
    if db_session.query(owned_user_file).scalar():
        return True

    # 2) Persona avatar.
    # TODO: move persona avatars to a dedicated endpoint (e.g.
    # /assistants/{id}/avatar) so this branch can be removed. /chat/file is
    # currently overloaded with multiple asset classes (user files, chat
    # attachments, tool outputs, avatars), forcing this access-check fan-out.
    #
    # Restrict the avatar path to CHAT_UPLOAD-origin files so an attacker
    # cannot bind another user's USER_FILE (or any other origin) to their
    # own persona and read it through this check.
    chat_upload_avatar = (
        select(Persona.id)
        .join(FileRecord, FileRecord.file_id == Persona.uploaded_image_id)
        .where(
            Persona.uploaded_image_id == file_id,
            FileRecord.file_origin == FileOrigin.CHAT_UPLOAD,
        )
        .exists()
    )
    if db_session.query(chat_upload_avatar).scalar():
        return True

    # 3) Attachment on a message in a session the user owns, or a public one.
    session_is_visible = or_(
        ChatSession.user_id == user_id,
        ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC,
    )
    referencing_message_stmt = (
        select(ChatMessage.id)
        .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
        # JSONB containment: the files array holds a descriptor with this id.
        .where(ChatMessage.files.op("@>")([{"id": file_id}]))
        .where(session_is_visible)
        .limit(1)
    )
    first_match = db_session.execute(referencing_message_stmt).first()
    return first_match is not None
def get_file_ids_by_user_file_ids(
user_file_ids: list[UUID], db_session: Session
) -> list[str]:

View File

@@ -3,7 +3,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
from onyx.db.models import SearchSettings
from onyx.db.opensearch_migration import get_opensearch_retrieval_state
from onyx.document_index.disabled import DisabledDocumentIndex
@@ -49,11 +48,6 @@ def get_default_document_index(
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
if ONYX_DISABLE_VESPA:
if not opensearch_retrieval_enabled:
raise ValueError(
"BUG: ONYX_DISABLE_VESPA is set but opensearch_retrieval_enabled is not set."
)
if opensearch_retrieval_enabled:
indexing_setting = IndexingSetting.from_db_model(search_settings)
secondary_indexing_setting = (
@@ -125,32 +119,21 @@ def get_all_document_indices(
)
]
result: list[DocumentIndex] = []
if ONYX_DISABLE_VESPA:
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
raise ValueError(
"ONYX_DISABLE_VESPA is set but ENABLE_OPENSEARCH_INDEXING_FOR_ONYX is not set."
)
else:
vespa_document_index = VespaIndex(
index_name=search_settings.index_name,
secondary_index_name=(
secondary_search_settings.index_name
if secondary_search_settings
else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=(
secondary_search_settings.large_chunks_enabled
if secondary_search_settings
else None
),
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)
result.append(vespa_document_index)
vespa_document_index = VespaIndex(
index_name=search_settings.index_name,
secondary_index_name=(
secondary_search_settings.index_name if secondary_search_settings else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=(
secondary_search_settings.large_chunks_enabled
if secondary_search_settings
else None
),
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
indexing_setting = IndexingSetting.from_db_model(search_settings)
secondary_indexing_setting = (
@@ -186,6 +169,7 @@ def get_all_document_indices(
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)
result: list[DocumentIndex] = [vespa_document_index]
if opensearch_document_index:
result.append(opensearch_document_index)
return result

View File

@@ -98,9 +98,6 @@ class DocumentMetadata:
# The resolved database ID of the parent hierarchy node (folder/container)
parent_hierarchy_node_id: int | None = None
# Opt-in pointer to the persisted raw file for this document (file_store id).
file_id: str | None = None
@dataclass
class VespaDocumentFields:

View File

@@ -17,7 +17,6 @@ from onyx.configs.app_configs import OPENSEARCH_ADMIN_PASSWORD
from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME
from onyx.configs.app_configs import OPENSEARCH_HOST
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
from onyx.configs.app_configs import OPENSEARCH_USE_SSL
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import OpenSearchSearchType
from onyx.document_index.opensearch.schema import DocumentChunk
@@ -133,7 +132,7 @@ class OpenSearchClient(AbstractContextManager):
host: str = OPENSEARCH_HOST,
port: int = OPENSEARCH_REST_API_PORT,
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
use_ssl: bool = OPENSEARCH_USE_SSL,
use_ssl: bool = True,
verify_certs: bool = False,
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
@@ -303,7 +302,7 @@ class OpenSearchIndexClient(OpenSearchClient):
host: str = OPENSEARCH_HOST,
port: int = OPENSEARCH_REST_API_PORT,
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
use_ssl: bool = OPENSEARCH_USE_SSL,
use_ssl: bool = True,
verify_certs: bool = False,
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
@@ -508,55 +507,8 @@ class OpenSearchIndexClient(OpenSearchClient):
Raises:
Exception: There was an error updating the settings of the index.
"""
logger.debug(f"Updating settings of index {self._index_name} with {settings}.")
response = self._client.indices.put_settings(
index=self._index_name, body=settings
)
if not response.get("acknowledged", False):
raise RuntimeError(
f"Failed to update settings of index {self._index_name}."
)
logger.debug(f"Settings of index {self._index_name} updated successfully.")
@log_function_time(print_only=True, debug_only=True)
def get_settings(self) -> dict[str, Any]:
    """Fetch the settings of this client's index.

    Returns:
        The "settings" entry of the get-settings response for this index.

    Raises:
        Exception: Whatever the underlying client raises if the request
            fails, or a lookup error if the response lacks this index.
    """
    logger.debug(f"Getting settings of index {self._index_name}.")
    settings_by_index = self._client.indices.get_settings(index=self._index_name)
    # The response is keyed by index name; unwrap this index's settings.
    this_index = settings_by_index[self._index_name]
    return this_index["settings"]
@log_function_time(print_only=True, debug_only=True)
def open_index(self) -> None:
    """Open this client's index.

    Raises:
        RuntimeError: The open request was not acknowledged.
        Exception: Whatever the underlying client raises if the request fails.
    """
    logger.debug(f"Opening index {self._index_name}.")
    open_response = self._client.indices.open(index=self._index_name)
    acknowledged = open_response.get("acknowledged", False)
    if acknowledged:
        logger.debug(f"Index {self._index_name} opened successfully.")
        return
    # A missing or falsy "acknowledged" flag is treated as a failure.
    raise RuntimeError(f"Failed to open index {self._index_name}.")
@log_function_time(print_only=True, debug_only=True)
def close_index(self) -> None:
    """Close this client's index.

    Raises:
        RuntimeError: The close request was not acknowledged.
        Exception: Whatever the underlying client raises if the request fails.
    """
    logger.debug(f"Closing index {self._index_name}.")
    close_response = self._client.indices.close(index=self._index_name)
    acknowledged = close_response.get("acknowledged", False)
    if acknowledged:
        logger.debug(f"Index {self._index_name} closed successfully.")
        return
    # A missing or falsy "acknowledged" flag is treated as a failure.
    raise RuntimeError(f"Failed to close index {self._index_name}.")
# TODO(andrei): Implement this.
raise NotImplementedError
@log_function_time(
print_only=True,

View File

@@ -12,7 +12,6 @@ from email.parser import Parser as EmailParser
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import cast
from typing import IO
from typing import NamedTuple
from typing import Optional
@@ -21,11 +20,10 @@ from zipfile import BadZipFile
import chardet
import openpyxl
from openpyxl.worksheet._read_only import ReadOnlyWorksheet
from openpyxl.worksheet.worksheet import Worksheet
from PIL import Image
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.app_configs import MAX_XLSX_CELLS_PER_SHEET
from onyx.configs.constants import ONYX_METADATA_FILENAME
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.file_processing.file_types import OnyxFileExtensions
@@ -50,8 +48,6 @@ KNOWN_OPENPYXL_BUGS = [
"File contains no valid workbook part",
"Unable to read workbook: could not read stylesheet from None",
"Colors must be aRGB hex values",
"Max value is",
"There is no item named",
]
@@ -368,40 +364,6 @@ def extract_docx_images(docx_bytes: IO[Any]) -> Iterator[tuple[bytes, str]]:
logger.exception("Failed to extract all docx images")
def count_docx_embedded_images(file: IO[Any], cap: int) -> int:
    """Return the number of embedded images in a docx, short-circuiting at cap+1.

    Intended to mirror count_pdf_embedded_images so upload validation can
    apply the same per-file/per-batch caps. Images are the zip members under
    ``word/media/``; explicit directory entries (names ending in ``/``) are
    skipped so an archive that stores the ``word/media/`` folder as its own
    entry does not get an inflated count.

    Args:
        file: Binary stream containing the docx (a zip archive).
        cap: Maximum count of interest. Once the running count exceeds
            ``cap`` the function returns immediately with ``cap + 1`` instead
            of walking every remaining media entry.

    Returns:
        The number of embedded images, ``cap + 1`` if there are more than
        ``cap``, or 0 if the archive could not be read (a warning is logged).

    The stream position is restored before returning whenever the initial
    position could be determined.
    """
    try:
        start_pos = file.tell()
    except Exception:
        # Stream cannot report its position: read from wherever it is and
        # skip the restore step below.
        start_pos = None
    try:
        if start_pos is not None:
            file.seek(0)
        count = 0
        with zipfile.ZipFile(file) as z:
            for name in z.namelist():
                if not name.startswith("word/media/"):
                    continue
                if name.endswith("/"):
                    # Directory entry, not an image file.
                    continue
                count += 1
                if count > cap:
                    return count
        return count
    except Exception:
        # Best effort: an unreadable docx is reported as having no images.
        logger.warning("Failed to count embedded images in docx", exc_info=True)
        return 0
    finally:
        if start_pos is not None:
            try:
                file.seek(start_pos)
            except Exception:
                pass
def read_docx_file(
file: IO[Any],
file_name: str = "",
@@ -484,83 +446,104 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
return presentation.markdown
def _columns_to_keep(col_has_data: bytearray, max_empty: int) -> list[int]:
"""Keep non-empty columns, plus runs of up to ``max_empty`` empty columns
between them. Trailing empty columns are dropped."""
def _worksheet_to_matrix(
    worksheet: Worksheet,
) -> list[list[str]]:
    """Read one worksheet into a rectangular matrix of strings.

    Every cell value is stringified, with ``None`` becoming ``""``. Rows
    shorter than the widest row are right-padded with ``""``. In openpyxl's
    read_only mode ``iter_rows`` can yield rows of differing lengths (trailing
    empty cells are sometimes omitted), and the downstream column cleanup
    assumes every row has the same width.
    """
    matrix: list[list[str]] = [
        ["" if value is None else str(value) for value in raw_row]
        for raw_row in worksheet.iter_rows(min_row=1, values_only=True)
    ]

    # Width of the widest row; 0 for an empty sheet.
    width = max((len(cells) for cells in matrix), default=0)

    for cells in matrix:
        shortfall = width - len(cells)
        if shortfall > 0:
            cells.extend([""] * shortfall)
    return matrix
def _clean_worksheet_matrix(matrix: list[list[str]]) -> list[list[str]]:
    """Trim long runs of empty rows and empty columns from a worksheet matrix.

    Rows are cleaned first via ``_remove_empty_runs``; columns are then
    selected via ``_columns_to_keep`` and the matrix is projected onto the
    kept column indices. Both helpers are called with a limit of 2.

    Returns:
        The cleaned matrix. If no rows survive row cleanup, that empty result
        is returned without any column processing.
    """
    max_empty_rows = 2  # Passed to _remove_empty_runs as its max_empty limit
    max_empty_cols = 2

    # Row cleanup
    trimmed = _remove_empty_runs(matrix, max_empty=max_empty_rows)
    if not trimmed:
        return trimmed

    # Column cleanup: work on column indices so no transpose is needed.
    width = len(trimmed[0])
    kept_indices = _columns_to_keep(trimmed, width, max_empty=max_empty_cols)
    if len(kept_indices) >= width:
        # Nothing was dropped; return the row-cleaned matrix as is.
        return trimmed
    return [[cells[idx] for idx in kept_indices] for cells in trimmed]
def _columns_to_keep(
matrix: list[list[str]], num_cols: int, max_empty: int
) -> list[int]:
"""Return the indices of columns to keep after removing empty-column runs.
Uses the same logic as ``_remove_empty_runs`` but operates on column
indices so no transpose is needed.
"""
kept: list[int] = []
empty_buffer: list[int] = []
for c, has in enumerate(col_has_data):
if has:
kept.extend(empty_buffer[:max_empty])
kept.append(c)
empty_buffer = []
for col_idx in range(num_cols):
col_is_empty = all(not row[col_idx] for row in matrix)
if col_is_empty:
empty_buffer.append(col_idx)
else:
empty_buffer.append(c)
kept.extend(empty_buffer[:max_empty])
kept.append(col_idx)
empty_buffer = []
return kept
def _sheet_to_csv(rows: Iterator[tuple[Any, ...]]) -> str:
"""Stream worksheet rows into CSV text without materializing a dense matrix.
def _remove_empty_runs(
rows: list[list[str]],
max_empty: int,
) -> list[list[str]]:
"""Removes entire runs of empty rows when the run length exceeds max_empty.
Empty rows are never stored. Column occupancy is tracked as a ``bytearray``
bitmap so column trimming needs no transpose or copy. Runs of empty
rows/columns longer than 2 are collapsed; shorter runs are preserved.
Scanning stops once ``MAX_XLSX_CELLS_PER_SHEET`` non-empty cells have been
seen; the output gets a truncation marker row appended so downstream
indexing sees that the sheet was cut off.
Leading empty runs are capped to max_empty, just like interior runs.
Trailing empty rows are always dropped since there is no subsequent
non-empty row to flush them.
"""
MAX_EMPTY_ROWS_IN_OUTPUT = 2
MAX_EMPTY_COLS_IN_OUTPUT = 2
TRUNCATION_MARKER = "[truncated: sheet exceeded cell limit]"
result: list[list[str]] = []
empty_buffer: list[list[str]] = []
non_empty_rows: list[tuple[int, list[str]]] = []
col_has_data = bytearray()
total_non_empty = 0
truncated = False
for row in rows:
# Check if empty
if not any(row):
if len(empty_buffer) < max_empty:
empty_buffer.append(row)
else:
# Add upto max empty rows onto the result - that's what we allow
result.extend(empty_buffer[:max_empty])
# Add the new non-empty row
result.append(row)
empty_buffer = []
for row_idx, row_vals in enumerate(rows):
# Fast-reject empty rows before allocating a list of "".
if not any(v is not None and v != "" for v in row_vals):
continue
cells = ["" if v is None else str(v) for v in row_vals]
non_empty_rows.append((row_idx, cells))
if len(cells) > len(col_has_data):
col_has_data.extend(b"\x00" * (len(cells) - len(col_has_data)))
for i, v in enumerate(cells):
if v:
col_has_data[i] = 1
total_non_empty += 1
if total_non_empty > MAX_XLSX_CELLS_PER_SHEET:
truncated = True
break
if not non_empty_rows:
return ""
keep_cols = _columns_to_keep(col_has_data, MAX_EMPTY_COLS_IN_OUTPUT)
if not keep_cols:
return ""
buf = io.StringIO()
writer = csv.writer(buf, lineterminator="\n")
blank_row = [""] * len(keep_cols)
last_idx = -1
for row_idx, cells in non_empty_rows:
gap = row_idx - last_idx - 1
if gap > 0:
for _ in range(min(gap, MAX_EMPTY_ROWS_IN_OUTPUT)):
writer.writerow(blank_row)
writer.writerow([cells[c] if c < len(cells) else "" for c in keep_cols])
last_idx = row_idx
if truncated:
writer.writerow([TRUNCATION_MARKER])
return buf.getvalue().rstrip("\n")
return result
def xlsx_sheet_extraction(file: IO[Any], file_name: str = "") -> list[tuple[str, str]]:
@@ -588,24 +571,20 @@ def xlsx_sheet_extraction(file: IO[Any], file_name: str = "") -> list[tuple[str,
raise
sheets: list[tuple[str, str]] = []
try:
for sheet in workbook.worksheets:
# Declared dimensions can be different to what is actually there
ro_sheet = cast(ReadOnlyWorksheet, sheet)
ro_sheet.reset_dimensions()
csv_text = _sheet_to_csv(ro_sheet.iter_rows(values_only=True))
sheets.append((csv_text.strip(), ro_sheet.title))
finally:
workbook.close()
for sheet in workbook.worksheets:
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
buf = io.StringIO()
writer = csv.writer(buf, lineterminator="\n")
writer.writerows(sheet_matrix)
csv_text = buf.getvalue().rstrip("\n")
if csv_text.strip():
sheets.append((csv_text, sheet.title))
return sheets
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
sheets = xlsx_sheet_extraction(file, file_name)
return TEXT_SECTION_SEPARATOR.join(
csv_text for csv_text, _title in sheets if csv_text
)
return TEXT_SECTION_SEPARATOR.join(csv_text for csv_text, _title in sheets)
def eml_to_text(file: IO[Any]) -> str:

View File

@@ -1,76 +0,0 @@
from collections.abc import Callable
from typing import Any
from typing import IO
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.db.file_record import update_filerecord_origin
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
logger = setup_logger()
# (content, content_type) -> file_id
RawFileCallback = Callable[[IO[bytes], str], str]
def stage_raw_file(
    content: IO,
    content_type: str,
    *,
    metadata: dict[str, Any],
) -> str:
    """Write `content` to the default file store as a staging-origin file.

    The file is saved with `FileOrigin.INDEXING_STAGING`, no display name,
    `content_type` as its file type and `metadata` as its file metadata.
    The metadata is what lets later steps (promotion in docprocessing, the
    TTL janitor that reaps orphans) find the file by the context that
    produced it.

    Args:
        content: Readable stream holding the raw bytes to persist.
        content_type: Stored as the file record's `file_type`.
        metadata: Stored as the file record's `file_metadata`.

    Returns:
        The file id returned by the file store's `save_file`.
    """
    return get_default_file_store().save_file(
        content=content,
        display_name=None,
        file_origin=FileOrigin.INDEXING_STAGING,
        file_type=content_type,
        file_metadata=metadata,
    )
def build_raw_file_callback(
    *,
    index_attempt_id: int,
    cc_pair_id: int,
    tenant_id: str,
) -> RawFileCallback:
    """Return a callback that stages raw files tagged with this attempt.

    The three identifiers are captured once in a metadata dict. Each call to
    the returned callback forwards its `(content, content_type)` arguments to
    `stage_raw_file` together with that same dict, so a connector only has to
    supply per-file data and receives a file_id it can attach to its Document.

    Args:
        index_attempt_id: Stored under the "index_attempt_id" metadata key.
        cc_pair_id: Stored under the "cc_pair_id" metadata key.
        tenant_id: Stored under the "tenant_id" metadata key.

    Returns:
        A `(content, content_type) -> file_id` callable.
    """
    attempt_context: dict[str, Any] = dict(
        index_attempt_id=index_attempt_id,
        cc_pair_id=cc_pair_id,
        tenant_id=tenant_id,
    )

    def _stage_for_attempt(content: IO[bytes], content_type: str) -> str:
        # The same dict object is passed on every call (not copied per call).
        return stage_raw_file(
            content=content,
            content_type=content_type,
            metadata=attempt_context,
        )

    return _stage_for_attempt
def promote_staged_file(db_session: Session, file_id: str) -> None:
    """Flip a staged file's origin from INDEXING_STAGING to CONNECTOR.

    Delegates to `update_filerecord_origin` using the supplied session.

    Args:
        db_session: Session passed through to `update_filerecord_origin`.
        file_id: Id of the file record whose origin should change.
    """
    update_filerecord_origin(
        db_session=db_session,
        file_id=file_id,
        from_origin=FileOrigin.INDEXING_STAGING,
        to_origin=FileOrigin.CONNECTOR,
    )

View File

@@ -30,7 +30,6 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.connectors.models import SectionType
from onyx.connectors.models import TextSection
from onyx.db.document import get_documents_by_ids
from onyx.db.document import upsert_document_by_connector_credential_pair
@@ -50,7 +49,6 @@ from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.staging import promote_staged_file
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
@@ -156,7 +154,6 @@ def _upsert_documents_in_db(
doc_metadata=doc.doc_metadata,
# parent_hierarchy_node_id is resolved in docfetching using Redis cache
parent_hierarchy_node_id=doc.parent_hierarchy_node_id,
file_id=doc.file_id,
)
document_metadata_list.append(db_doc_metadata)
@@ -367,45 +364,6 @@ def index_doc_batch_with_handler(
return index_pipeline_result
def _promote_new_staged_files(
    documents: list[Document],
    previous_file_ids: dict[str, str],
    db_session: Session,
) -> None:
    """Promote every file_id in the batch that differs from the stored one.

    For each document whose `file_id` is set and is not equal to the entry
    recorded for that document in `previous_file_ids`, `promote_staged_file`
    is called (STAGING -> CONNECTOR). This function issues no commit itself;
    it is meant to run right before `_upsert_documents_in_db` so that the
    caller's next commit covers both the origin flip and the
    `Document.file_id` write.

    Args:
        documents: Documents being upserted in this batch.
        previous_file_ids: Map of document id -> file_id known before the batch.
        db_session: Session handed to `promote_staged_file`.
    """
    for document in documents:
        staged_id = document.file_id
        if staged_id is None:
            # Document carries no raw file; nothing to promote.
            continue
        if previous_file_ids.get(document.id) != staged_id:
            promote_staged_file(db_session=db_session, file_id=staged_id)
def _delete_replaced_files(
    documents: list[Document],
    previous_file_ids: dict[str, str],
) -> None:
    """Delete, best-effort, the blobs whose file_id was replaced in this batch.

    A document's previous file is deleted only when a previous file_id exists
    and differs from the document's current `file_id`. Any exception from the
    file store is logged and swallowed so one failed delete does not stop the
    rest. Must be called only AFTER the new `Document.file_id` values have
    been committed.

    Args:
        documents: Documents that were just upserted.
        previous_file_ids: Map of document id -> file_id captured before the
            upsert.
    """
    store = get_default_file_store()
    for document in documents:
        old_file_id = previous_file_ids.get(document.id)
        if old_file_id is not None and old_file_id != document.file_id:
            try:
                store.delete_file(old_file_id, error_on_missing=False)
            except Exception:
                logger.exception(f"Failed to delete replaced file_id={old_file_id}.")
def index_doc_batch_prepare(
documents: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
@@ -424,11 +382,6 @@ def index_doc_batch_prepare(
document_ids=document_ids,
)
# Capture previous file_ids BEFORE any writes so we know what to reap.
previous_file_ids: dict[str, str] = {
db_doc.id: db_doc.file_id for db_doc in db_docs if db_doc.file_id is not None
}
updatable_docs = (
get_doc_ids_to_update(documents=documents, db_docs=db_docs)
if not ignore_time_skip
@@ -446,24 +399,11 @@ def index_doc_batch_prepare(
# for all updatable docs, upsert into the DB
# Does not include doc_updated_at which is also used to indicate a successful update
if updatable_docs:
# Queue the STAGING → CONNECTOR origin flips BEFORE the Document upsert
# so `upsert_documents`' commit flushes Document.file_id and the origin
# flip atomically
_promote_new_staged_files(
documents=updatable_docs,
previous_file_ids=previous_file_ids,
db_session=db_session,
)
_upsert_documents_in_db(
documents=updatable_docs,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
)
# Blob deletes run only after Document.file_id is durable.
_delete_replaced_files(
documents=updatable_docs,
previous_file_ids=previous_file_ids,
)
logger.info(
f"Upserted {len(updatable_docs)} changed docs out of {len(documents)} total docs into the DB"
@@ -590,15 +530,8 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
Returns:
List of IndexingDocument objects with processed_sections as list[Section]
"""
# Check if image extraction and analysis is enabled before trying to get a vision LLM.
# Use section.type rather than isinstance because sections can round-trip
# through pydantic as base Section instances (not the concrete subclass).
has_image_section = any(
section.type == SectionType.IMAGE
for document in documents
for section in document.sections
)
if not get_image_extraction_and_analysis_enabled() or not has_image_section:
# Check if image extraction and analysis is enabled before trying to get a vision LLM
if not get_image_extraction_and_analysis_enabled():
llm = None
else:
# Only get the vision LLM if image processing is enabled

View File

@@ -290,7 +290,11 @@ def litellm_exception_to_error_msg(
error_code = "BUDGET_EXCEEDED"
is_retryable = False
elif isinstance(core_exception, Timeout):
error_msg = "Request timed out: The operation took too long to complete. Please try again."
error_msg = (
"The LLM took too long to respond. "
"If you're running a local model, try increasing the "
"LLM_SOCKET_READ_TIMEOUT environment variable (current default: 120 seconds)."
)
error_code = "CONNECTION_ERROR"
is_retryable = True
elif isinstance(core_exception, APIError):

View File

@@ -166,7 +166,6 @@ from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SENTRY_DSN
from shared_configs.configs import SENTRY_TRACES_SAMPLE_RATE
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
warnings.filterwarnings(
@@ -440,7 +439,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=SENTRY_TRACES_SAMPLE_RATE,
traces_sample_rate=0.1,
release=__version__,
before_send=_add_instance_tags,
)

View File

@@ -19,14 +19,9 @@ from onyx.configs.app_configs import MCP_SERVER_CORS_ORIGINS
from onyx.mcp_server.auth import OnyxTokenVerifier
from onyx.mcp_server.utils import shutdown_http_client
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
logger = setup_logger()
# Initialize EE flag at module import so it's set regardless of the entry point
# (python -m onyx.mcp_server_main, uvicorn onyx.mcp_server.api:mcp_app, etc.).
set_is_ee_based_on_env_variable()
logger.info("Creating Onyx MCP Server...")
mcp_server = FastMCP(

View File

@@ -1,5 +1,4 @@
"""Resource registrations for the Onyx MCP server."""
# Import resource modules so decorators execute when the package loads.
from onyx.mcp_server.resources import document_sets # noqa: F401
from onyx.mcp_server.resources import indexed_sources # noqa: F401

View File

@@ -1,41 +0,0 @@
"""Resource exposing document sets available to the current user."""
from __future__ import annotations
import json
from onyx.mcp_server.api import mcp_server
from onyx.mcp_server.utils import get_accessible_document_sets
from onyx.mcp_server.utils import require_access_token
from onyx.utils.logger import setup_logger
logger = setup_logger()
@mcp_server.resource(
    "resource://document_sets",
    name="document_sets",
    description=(
        "Enumerate the Document Sets accessible to the current user. Use the "
        "returned `name` values with the `document_set_names` filter of the "
        "`search_indexed_documents` tool to scope searches to a specific set."
    ),
    mime_type="application/json",
)
async def document_sets_resource() -> str:
    """Return the user's accessible document sets as a JSON array string.

    Entries are fetched with the caller's access token, ordered by `name`,
    and each one is dumped in JSON mode before serialization.
    """
    token = require_access_token()
    fetched = await get_accessible_document_sets(token)
    ordered = sorted(fetched, key=lambda item: item.name)

    logger.info(
        "Onyx MCP Server: document_sets resource returning %s entries",
        len(ordered),
    )

    # FastMCP 3.2+ requires str/bytes/list[ResourceContent] and no longer
    # auto-serializes, so the JSON encoding is done here.
    serializable = []
    for item in ordered:
        serializable.append(item.model_dump(mode="json"))
    return json.dumps(serializable)

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
import json
from typing import Any
from onyx.mcp_server.api import mcp_server
from onyx.mcp_server.utils import get_indexed_sources
@@ -21,7 +21,7 @@ logger = setup_logger()
),
mime_type="application/json",
)
async def indexed_sources_resource() -> str:
async def indexed_sources_resource() -> dict[str, Any]:
"""Return the list of indexed source types for search filtering."""
access_token = require_access_token()
@@ -33,6 +33,6 @@ async def indexed_sources_resource() -> str:
len(sources),
)
# FastMCP 3.2+ requires str/bytes/list[ResourceContent] — it no longer
# auto-serializes; serialize to JSON ourselves.
return json.dumps(sorted(sources))
return {
"indexed_sources": sorted(sources),
}

View File

@@ -4,23 +4,12 @@ from datetime import datetime
from typing import Any
import httpx
from fastmcp.server.auth.auth import AccessToken
from pydantic import BaseModel
from onyx.chat.models import ChatFullResponse
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import SearchDoc
from onyx.mcp_server.api import mcp_server
from onyx.mcp_server.utils import get_http_client
from onyx.mcp_server.utils import get_indexed_sources
from onyx.mcp_server.utils import require_access_token
from onyx.server.features.web_search.models import OpenUrlsToolRequest
from onyx.server.features.web_search.models import OpenUrlsToolResponse
from onyx.server.features.web_search.models import WebSearchToolRequest
from onyx.server.features.web_search.models import WebSearchToolResponse
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
from onyx.utils.variable_functionality import global_version
@@ -28,43 +17,6 @@ from onyx.utils.variable_functionality import global_version
logger = setup_logger()
# CE search falls through to the chat endpoint, which invokes an LLM — the
# default 60s client timeout is not enough for a real RAG-backed response.
_CE_SEARCH_TIMEOUT_SECONDS = 300.0
async def _post_model(
    url: str,
    body: BaseModel,
    access_token: AccessToken,
    timeout: float | None = None,
) -> httpx.Response:
    """POST a Pydantic model to `url` as a JSON body with bearer auth.

    Args:
        url: Full endpoint URL.
        body: Model to send; serialized with `exclude_unset=True`, so only
            explicitly-set fields appear in the payload.
        access_token: Its `token` is sent in the `Authorization` header.
        timeout: Per-request timeout in seconds. `None` means "use the shared
            client's default" (`httpx.USE_CLIENT_DEFAULT`).

    Returns:
        The raw `httpx.Response`; status is not checked here.
    """
    client = get_http_client()
    payload = body.model_dump_json(exclude_unset=True)
    request_headers = {
        "Authorization": f"Bearer {access_token.token}",
        "Content-Type": "application/json",
    }
    if timeout is None:
        effective_timeout = httpx.USE_CLIENT_DEFAULT
    else:
        effective_timeout = timeout
    return await client.post(
        url,
        content=payload,
        headers=request_headers,
        timeout=effective_timeout,
    )
def _project_doc(doc: SearchDoc, content: str | None) -> dict[str, Any]:
    """Build the MCP wire-format dict for one backend search document.

    Works for SearchDocWithContent (EE) as well, since it extends SearchDoc.

    Args:
        doc: Source document; `source_type` is emitted as its enum `.value`.
        content: Text placed under the "content" key (may be None).

    Returns:
        A dict with keys semantic_identifier, content, source_type, link, score.
    """
    projected: dict[str, Any] = {}
    projected["semantic_identifier"] = doc.semantic_identifier
    projected["content"] = content
    projected["source_type"] = doc.source_type.value
    projected["link"] = doc.link
    projected["score"] = doc.score
    return projected
def _extract_error_detail(response: httpx.Response) -> str:
"""Extract a human-readable error message from a failed backend response.
@@ -84,7 +36,6 @@ def _extract_error_detail(response: httpx.Response) -> str:
async def search_indexed_documents(
query: str,
source_types: list[str] | None = None,
document_set_names: list[str] | None = None,
time_cutoff: str | None = None,
limit: int = 10,
) -> dict[str, Any]:
@@ -102,10 +53,6 @@ async def search_indexed_documents(
In EE mode, the dedicated search endpoint is used instead.
To find a list of available sources, use the `indexed_sources` resource.
`document_set_names` restricts results to documents belonging to the named
Document Sets — useful for scoping queries to a curated subset of the
knowledge base (e.g. to isolate knowledge between agents). Use the
`document_sets` resource to discover accessible set names.
Returns chunks of text as search results with snippets, scores, and metadata.
Example usage:
@@ -113,23 +60,15 @@ async def search_indexed_documents(
{
"query": "What is the latest status of PROJ-1234 and what is the next development item?",
"source_types": ["jira", "google_drive", "github"],
"document_set_names": ["Engineering Wiki"],
"time_cutoff": "2025-11-24T00:00:00Z",
"limit": 10,
}
```
"""
logger.info(
f"Onyx MCP Server: document search: query='{query}', sources={source_types}, "
f"document_sets={document_set_names}, limit={limit}"
f"Onyx MCP Server: document search: query='{query}', sources={source_types}, limit={limit}"
)
# Normalize empty list inputs to None so downstream filter construction is
# consistent — BaseFilters treats [] as "match zero" which differs from
# "no filter" (None).
source_types = source_types or None
document_set_names = document_set_names or None
# Parse time_cutoff string to datetime if provided
time_cutoff_dt: datetime | None = None
if time_cutoff:
@@ -142,6 +81,9 @@ async def search_indexed_documents(
# Continue with no time_cutoff instead of returning an error
time_cutoff_dt = None
# Initialize source_type_enums early to avoid UnboundLocalError
source_type_enums: list[DocumentSource] | None = None
# Get authenticated user from FastMCP's access token
access_token = require_access_token()
@@ -175,7 +117,6 @@ async def search_indexed_documents(
# Convert source_types strings to DocumentSource enums if provided
# Invalid values will be handled by the API server
source_type_enums: list[DocumentSource] | None = None
if source_types is not None:
source_type_enums = []
for src in source_types:
@@ -186,83 +127,83 @@ async def search_indexed_documents(
f"Onyx MCP Server: Invalid source type '{src}' - will be ignored by server"
)
filters: BaseFilters | None = None
if source_type_enums or document_set_names or time_cutoff_dt:
filters = BaseFilters(
source_type=source_type_enums,
document_set=document_set_names,
time_cutoff=time_cutoff_dt,
)
# Build filters dict only with non-None values
filters: dict[str, Any] | None = None
if source_type_enums or time_cutoff_dt:
filters = {}
if source_type_enums:
filters["source_type"] = [src.value for src in source_type_enums]
if time_cutoff_dt:
filters["time_cutoff"] = time_cutoff_dt.isoformat()
base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True)
is_ee = global_version.is_ee_version()
base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True)
auth_headers = {"Authorization": f"Bearer {access_token.token}"}
request: BaseModel
search_request: dict[str, Any]
if is_ee:
# EE: use the dedicated search endpoint (no LLM invocation).
# Lazy import so CE deployments that strip ee/ never load this module.
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
request = SendSearchQueryRequest(
search_query=query,
filters=filters,
num_docs_fed_to_llm_selection=limit,
run_query_expansion=False,
include_content=True,
stream=False,
)
# EE: use the dedicated search endpoint (no LLM invocation)
search_request = {
"search_query": query,
"filters": filters,
"num_docs_fed_to_llm_selection": limit,
"run_query_expansion": False,
"include_content": True,
"stream": False,
}
endpoint = f"{base_url}/search/send-search-message"
error_key = "error"
docs_key = "search_docs"
content_field = "content"
else:
# CE: fall back to the chat endpoint (invokes LLM, consumes tokens)
request = SendMessageRequest(
message=query,
stream=False,
chat_session_info=ChatSessionCreationRequest(),
internal_search_filters=filters,
)
search_request = {
"message": query,
"stream": False,
"chat_session_info": {},
}
if filters:
search_request["internal_search_filters"] = filters
endpoint = f"{base_url}/chat/send-chat-message"
error_key = "error_msg"
docs_key = "top_documents"
content_field = "blurb"
try:
response = await _post_model(
response = await get_http_client().post(
endpoint,
request,
access_token,
timeout=None if is_ee else _CE_SEARCH_TIMEOUT_SECONDS,
json=search_request,
headers=auth_headers,
)
if not response.is_success:
error_detail = _extract_error_detail(response)
return {
"documents": [],
"total_results": 0,
"query": query,
"error": _extract_error_detail(response),
"error": error_detail,
}
result = response.json()
# Check for error in response
if result.get(error_key):
return {
"documents": [],
"total_results": 0,
"query": query,
"error": result.get(error_key),
}
if is_ee:
from ee.onyx.server.query_and_chat.models import SearchFullResponse
ee_payload = SearchFullResponse.model_validate_json(response.content)
if ee_payload.error:
return {
"documents": [],
"total_results": 0,
"query": query,
"error": ee_payload.error,
}
documents = [
_project_doc(doc, doc.content) for doc in ee_payload.search_docs
]
else:
ce_payload = ChatFullResponse.model_validate_json(response.content)
if ce_payload.error_msg:
return {
"documents": [],
"total_results": 0,
"query": query,
"error": ce_payload.error_msg,
}
documents = [
_project_doc(doc, doc.blurb) for doc in ce_payload.top_documents
]
documents = [
{
"semantic_identifier": doc.get("semantic_identifier"),
"content": doc.get(content_field),
"source_type": doc.get("source_type"),
"link": doc.get("link"),
"score": doc.get("score"),
}
for doc in result.get(docs_key, [])
]
# NOTE: search depth is controlled by the backend persona defaults, not `limit`.
# `limit` only caps the returned list; fewer results may be returned if the
@@ -311,20 +252,23 @@ async def search_web(
access_token = require_access_token()
try:
response = await _post_model(
request_payload = {"queries": [query], "max_results": limit}
response = await get_http_client().post(
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/web-search/search-lite",
WebSearchToolRequest(queries=[query], max_results=limit),
access_token,
json=request_payload,
headers={"Authorization": f"Bearer {access_token.token}"},
)
if not response.is_success:
error_detail = _extract_error_detail(response)
return {
"error": _extract_error_detail(response),
"error": error_detail,
"results": [],
"query": query,
}
payload = WebSearchToolResponse.model_validate_json(response.content)
response_payload = response.json()
results = response_payload.get("results", [])
return {
"results": [result.model_dump(mode="json") for result in payload.results],
"results": results,
"query": query,
}
except Exception as e:
@@ -361,19 +305,21 @@ async def open_urls(
access_token = require_access_token()
try:
response = await _post_model(
response = await get_http_client().post(
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/web-search/open-urls",
OpenUrlsToolRequest(urls=urls),
access_token,
json={"urls": urls},
headers={"Authorization": f"Bearer {access_token.token}"},
)
if not response.is_success:
error_detail = _extract_error_detail(response)
return {
"error": _extract_error_detail(response),
"error": error_detail,
"results": [],
}
payload = OpenUrlsToolResponse.model_validate_json(response.content)
response_payload = response.json()
results = response_payload.get("results", [])
return {
"results": [result.model_dump(mode="json") for result in payload.results],
"results": results,
}
except Exception as e:
logger.error(f"Onyx MCP Server: URL fetch error: {e}", exc_info=True)

View File

@@ -5,24 +5,10 @@ from __future__ import annotations
import httpx
from fastmcp.server.auth.auth import AccessToken
from fastmcp.server.dependencies import get_access_token
from pydantic import BaseModel
from pydantic import TypeAdapter
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
class DocumentSetEntry(BaseModel):
    """Trimmed-down view of a document set as exposed to MCP clients.

    Only `name` and `description` are kept. This is a projection of the
    backend's DocumentSetSummary so that MCP is not coupled to admin-only
    fields (cc-pair summaries, federated connectors, etc.).
    """

    # Required document set name.
    name: str
    # Optional free-text description; defaults to None when absent.
    description: str | None = None
logger = setup_logger()
# Shared HTTP client reused across requests
@@ -98,32 +84,3 @@ async def get_indexed_sources(
exc_info=True,
)
raise RuntimeError(f"Failed to fetch indexed sources: {exc}") from exc
_DOCUMENT_SET_ENTRIES_ADAPTER = TypeAdapter(list[DocumentSetEntry])
async def get_accessible_document_sets(
    access_token: AccessToken,
) -> list[DocumentSetEntry]:
    """Fetch the document sets visible to the token's user.

    Issues a GET to `<api server>/manage/document-set` with bearer auth and
    validates the response body into `DocumentSetEntry` objects.

    Args:
        access_token: Its `token` is sent in the `Authorization` header.

    Returns:
        The parsed list of document set entries.

    Raises:
        httpx.HTTPStatusError, httpx.RequestError, ValueError: logged and
            re-raised unchanged.
        RuntimeError: wraps any other exception (original kept as cause).
    """
    auth_headers = {"Authorization": f"Bearer {access_token.token}"}
    try:
        client = get_http_client()
        api_base = build_api_server_url_for_http_requests(
            respect_env_override_if_set=True
        )
        reply = await client.get(
            f"{api_base}/manage/document-set",
            headers=auth_headers,
        )
        reply.raise_for_status()
        entries = _DOCUMENT_SET_ENTRIES_ADAPTER.validate_json(reply.content)
    except (httpx.HTTPStatusError, httpx.RequestError, ValueError):
        logger.error(
            "Onyx MCP Server: Failed to fetch document sets",
            exc_info=True,
        )
        raise
    except Exception as exc:
        logger.error(
            "Onyx MCP Server: Unexpected error fetching document sets",
            exc_info=True,
        )
        raise RuntimeError(f"Failed to fetch document sets: {exc}") from exc
    else:
        return entries

View File

@@ -5,7 +5,6 @@ import uvicorn
from onyx.configs.app_configs import MCP_SERVER_ENABLED
from onyx.configs.app_configs import MCP_SERVER_HOST
from onyx.configs.app_configs import MCP_SERVER_PORT
from onyx.tracing.setup import setup_tracing
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
@@ -19,7 +18,6 @@ def main() -> None:
return
set_is_ee_based_on_env_variable()
setup_tracing()
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
from onyx.mcp_server.api import mcp_app

View File

@@ -11,8 +11,6 @@ All public functions no-op in single-tenant mode (`MULTI_TENANT=False`).
import time
from typing import cast
from prometheus_client import Counter
from prometheus_client import Gauge
from redis.client import Redis
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
@@ -28,40 +26,6 @@ logger = setup_logger()
_SET_KEY = "active_tenants"
# --- Prometheus metrics ---
# Gauge set by observe_active_set_size() to the ZCARD of the active_tenants
# sorted set.
_active_set_size = Gauge(
"onyx_tenant_work_gating_active_set_size",
"Current cardinality of the active_tenants sorted set (updated once per "
"generator invocation when the gate reads it).",
)
# Incremented by maybe_mark_tenant_active() after it marks a tenant; the
# `caller` label carries the string supplied by the call site.
_marked_total = Counter(
"onyx_tenant_work_gating_marked_total",
"Writes into active_tenants, labelled by caller.",
["caller"],
)
# Incremented by record_gate_decision() only when it is called with
# skipped=True.
_skipped_total = Counter(
"onyx_tenant_work_gating_skipped_total",
"Per-tenant fanouts skipped by the gate (enforce mode only), by task.",
["task"],
)
# Incremented on every record_gate_decision() call, regardless of `skipped`.
_would_skip_total = Counter(
"onyx_tenant_work_gating_would_skip_total",
"Per-tenant fanouts that would have been skipped if enforce were on "
"(shadow counter), by task.",
["task"],
)
# Incremented by record_full_fanout_cycle(), labelled with the task name.
_full_fanout_total = Counter(
"onyx_tenant_work_gating_full_fanout_total",
"Generator invocations that bypassed the gate for a full fanout cycle, by task.",
["task"],
)
def _now_ms() -> int:
return int(time.time() * 1000)
@@ -90,14 +54,10 @@ def mark_tenant_active(tenant_id: str) -> None:
logger.exception(f"mark_tenant_active failed: tenant_id={tenant_id}")
def maybe_mark_tenant_active(tenant_id: str, caller: str = "unknown") -> None:
def maybe_mark_tenant_active(tenant_id: str) -> None:
"""Convenience wrapper for writer call sites: records the tenant only
when the feature flag is on. Fully defensive — never raises, so a Redis
outage or flag-read failure can't abort the calling task.
`caller` labels the Prometheus counter so a dashboard can show which
consumer is firing the hook most.
"""
outage or flag-read failure can't abort the calling task."""
try:
# Local import to avoid a module-load cycle: OnyxRuntime imports
# onyx.redis.redis_pool, so a top-level import here would wedge on
@@ -107,44 +67,10 @@ def maybe_mark_tenant_active(tenant_id: str, caller: str = "unknown") -> None:
if not OnyxRuntime.get_tenant_work_gating_enabled():
return
mark_tenant_active(tenant_id)
_marked_total.labels(caller=caller).inc()
except Exception:
logger.exception(f"maybe_mark_tenant_active failed: tenant_id={tenant_id}")
def observe_active_set_size() -> int | None:
    """Read `ZCARD active_tenants`, publish it to the gauge, and return it.

    Intended to be called once per gate-generator invocation so dashboards
    get a live reading.

    Returns:
        The set cardinality, or `None` in single-tenant mode or when any
        exception is raised while reading Redis or updating the gauge (the
        exception is logged, not propagated).
    """
    if MULTI_TENANT:
        try:
            cardinality = cast(int, _client().zcard(_SET_KEY))
            _active_set_size.set(cardinality)
        except Exception:
            logger.exception("observe_active_set_size failed")
        else:
            return cardinality
    return None
def record_gate_decision(task_name: str, skipped: bool) -> None:
    """Count one tenant that the gate would skip for `task_name`.

    The shadow ("would skip") counter is always incremented; the enforced
    ("skipped") counter is incremented as well only when `skipped` is true.

    Args:
        task_name: Value for the counters' `task` label.
        skipped: Whether the skip was actually enforced.
    """
    _would_skip_total.labels(task=task_name).inc()
    if not skipped:
        return
    _skipped_total.labels(task=task_name).inc()
def record_full_fanout_cycle(task_name: str) -> None:
    """Count one generator invocation that bypassed the gate.

    Meant to be called once per invocation that does a full fanout, whether
    because the interval elapsed or because the gate failed open.

    Args:
        task_name: Value for the counter's `task` label.
    """
    task_counter = _full_fanout_total.labels(task=task_name)
    task_counter.inc()
def get_active_tenants(ttl_seconds: int) -> set[str] | None:
"""Return tenants whose last-seen timestamp is within `ttl_seconds` of
now.

View File

@@ -584,7 +584,7 @@ def associate_credential_to_connector(
# Tenant-work-gating lifecycle hook: keep new-tenant latency to
# seconds instead of one full-fanout interval.
maybe_mark_tenant_active(tenant_id, caller="cc_pair_lifecycle")
maybe_mark_tenant_active(tenant_id)
# trigger indexing immediately
client_app.send_task(

View File

@@ -1643,7 +1643,7 @@ def create_connector_with_mock_credential(
# Tenant-work-gating lifecycle hook: keep new-tenant latency to
# seconds instead of one full-fanout interval.
maybe_mark_tenant_active(tenant_id, caller="cc_pair_lifecycle")
maybe_mark_tenant_active(tenant_id)
# trigger indexing immediately
client_app.send_task(

View File

@@ -113,7 +113,7 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
# Tenant-work-gating hook: refresh this tenant's active-set
# membership whenever sandbox cleanup has work to do.
maybe_mark_tenant_active(tenant_id, caller="sandbox_cleanup")
maybe_mark_tenant_active(tenant_id)
task_logger.info(
f"Found {len(idle_sandboxes)} idle sandboxes to put to sleep"

View File

@@ -11,9 +11,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import count_docx_embedded_images
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
@@ -200,9 +198,6 @@ def categorize_uploaded_files(
# rejected even if they'd individually fit under MAX_EMBEDDED_IMAGES_PER_FILE.
batch_image_total = 0
# Hoisted out of the loop to avoid a KV-store lookup per file.
image_extraction_enabled = get_image_extraction_and_analysis_enabled()
for upload in files:
try:
filename = get_safe_filename(upload)
@@ -265,33 +260,28 @@ def categorize_uploaded_files(
)
continue
# Reject documents with an unreasonable number of embedded
# images (either per-file or accumulated across this upload
# batch). A file with thousands of embedded images can OOM the
# Reject PDFs with an unreasonable number of embedded images
# (either per-file or accumulated across this upload batch).
# A PDF with thousands of embedded images can OOM the
# user-file-processing celery worker because every image is
# decoded with PIL and then sent to the vision LLM.
count: int = 0
image_bearing_ext = extension in (".pdf", ".docx")
if image_bearing_ext:
if extension == ".pdf":
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
# Use the larger of the two caps as the short-circuit
# threshold so we get a useful count for both checks.
# These helpers restore the stream position.
counter = (
count_pdf_embedded_images
if extension == ".pdf"
else count_docx_embedded_images
# count_pdf_embedded_images restores the stream position.
count = count_pdf_embedded_images(
upload.file, max(file_cap, batch_cap)
)
count = counter(upload.file, max(file_cap, batch_cap))
if count > file_cap:
results.rejected.append(
RejectedFile(
filename=filename,
reason=(
f"Document contains too many embedded "
f"images (more than {file_cap}). Try "
f"splitting it into smaller files."
f"PDF contains too many embedded images "
f"(more than {file_cap}). Try splitting "
f"the document into smaller files."
),
)
)
@@ -318,21 +308,6 @@ def categorize_uploaded_files(
extension=extension,
)
if not text_content:
# Documents with embedded images (e.g. scans) have no
# extractable text but can still be indexed via the
# vision-LLM captioning path when image analysis is
# enabled.
if image_bearing_ext and count > 0 and image_extraction_enabled:
results.acceptable.append(upload)
results.acceptable_file_to_token_count[filename] = 0
try:
upload.file.seek(0)
except Exception as e:
logger.warning(
f"Failed to reset file pointer for '{filename}': {str(e)}"
)
continue
logger.warning(f"No text content extracted from '{filename}'")
results.rejected.append(
RejectedFile(

View File

@@ -3,7 +3,6 @@ from fastapi import Depends
from sqlalchemy.orm import Session
from onyx.auth.permissions import require_permission
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
@@ -50,7 +49,6 @@ def get_opensearch_retrieval_status(
enable_opensearch_retrieval = get_opensearch_retrieval_state(db_session)
return OpenSearchRetrievalStatusResponse(
enable_opensearch_retrieval=enable_opensearch_retrieval,
toggling_retrieval_is_disabled=ONYX_DISABLE_VESPA,
)
@@ -65,5 +63,4 @@ def set_opensearch_retrieval_status(
)
return OpenSearchRetrievalStatusResponse(
enable_opensearch_retrieval=request.enable_opensearch_retrieval,
toggling_retrieval_is_disabled=ONYX_DISABLE_VESPA,
)

View File

@@ -19,4 +19,3 @@ class OpenSearchRetrievalStatusRequest(BaseModel):
class OpenSearchRetrievalStatusResponse(BaseModel):
model_config = {"frozen": True}
enable_opensearch_retrieval: bool
toggling_retrieval_is_disabled: bool = False

View File

@@ -1,151 +0,0 @@
"""Prometheus metrics for embedding generation latency and throughput.
Tracks client-side round-trip latency (as seen by callers of
``EmbeddingModel.encode``) and server-side execution time (as measured inside
the model server for the local-model path). Both API-provider and local-model
paths flow through the client-side metric; only the local path populates the
server-side metric.
"""
import logging
from collections.abc import Generator
from contextlib import contextmanager
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
logger = logging.getLogger(__name__)
# Label value reported for the provider when no EmbeddingProvider is supplied
# (returned by provider_label for a ``None`` provider).
LOCAL_PROVIDER_LABEL = "local"
# Upper bounds of the latency histogram buckets used by the
# ``*_duration_seconds`` histogram below.
_EMBEDDING_LATENCY_BUCKETS = (
0.005,
0.01,
0.025,
0.05,
0.1,
0.25,
0.5,
1.0,
2.5,
5.0,
10.0,
25.0,
)
# Label names shared by the metrics below. They must stay in sync with the
# keyword names passed to ``.labels(provider=..., text_type=..., status=...)``
# elsewhere in this module.
PROVIDER_LABEL_NAME = "provider"
TEXT_TYPE_LABEL_NAME = "text_type"
STATUS_LABEL_NAME = "status"
# Observed by observe_embedding_client with the caller-measured duration.
_client_duration = Histogram(
"onyx_embedding_client_duration_seconds",
"Client-side end-to-end latency of an embedding batch as seen by the caller.",
[PROVIDER_LABEL_NAME, TEXT_TYPE_LABEL_NAME],
buckets=_EMBEDDING_LATENCY_BUCKETS,
)
# Incremented by observe_embedding_client; the status label is either
# "success" or "failure".
_embedding_requests_total = Counter(
"onyx_embedding_requests_total",
"Total embedding batch requests, labeled by outcome.",
[PROVIDER_LABEL_NAME, TEXT_TYPE_LABEL_NAME, STATUS_LABEL_NAME],
)
# Incremented by the batch's ``num_texts`` in observe_embedding_client.
_embedding_texts_total = Counter(
"onyx_embedding_texts_total",
"Total number of individual texts submitted for embedding.",
[PROVIDER_LABEL_NAME, TEXT_TYPE_LABEL_NAME],
)
# Incremented by the batch's ``num_chars`` in observe_embedding_client.
_embedding_input_chars_total = Counter(
"onyx_embedding_input_chars_total",
"Total number of input characters submitted for embedding.",
[PROVIDER_LABEL_NAME, TEXT_TYPE_LABEL_NAME],
)
# Incremented on entry and decremented on exit by track_embedding_in_progress.
_embeddings_in_progress = Gauge(
"onyx_embeddings_in_progress",
"Number of embedding batches currently in-flight.",
[PROVIDER_LABEL_NAME, TEXT_TYPE_LABEL_NAME],
)
def provider_label(provider: EmbeddingProvider | None) -> str:
    """Return the metric label value for ``provider``.

    ``None`` maps to ``LOCAL_PROVIDER_LABEL``; any other provider maps to its
    enum ``value``.
    """
    return LOCAL_PROVIDER_LABEL if provider is None else provider.value
def observe_embedding_client(
    provider: EmbeddingProvider | None,
    text_type: EmbedTextType,
    duration_s: float,
    num_texts: int,
    num_chars: int,
    success: bool,
) -> None:
    """Record the outcome of one finished embedding batch.

    The request counter and the latency histogram are updated for every
    batch; the text and character counters are only advanced when the batch
    succeeded. Any error raised while recording is logged and suppressed so
    that metrics can never break the embedding call itself.

    Args:
        provider: The embedding provider, or ``None`` for the local model path.
        text_type: Whether this was a query- or passage-style embedding.
        duration_s: Wall-clock duration measured on the client side, in seconds.
        num_texts: Number of texts in the batch.
        num_chars: Total number of input characters in the batch.
        success: Whether the embedding call succeeded.
    """
    try:
        # Labels shared by every metric touched below.
        shared_labels = {
            "provider": provider_label(provider),
            "text_type": text_type.value,
        }
        outcome = "failure" if not success else "success"

        _embedding_requests_total.labels(status=outcome, **shared_labels).inc()
        _client_duration.labels(**shared_labels).observe(duration_s)

        if not success:
            # Failed batches contribute no text/character volume.
            return

        _embedding_texts_total.labels(**shared_labels).inc(num_texts)
        _embedding_input_chars_total.labels(**shared_labels).inc(num_chars)
    except Exception:
        logger.warning("Failed to record embedding client metrics.", exc_info=True)
@contextmanager
def track_embedding_in_progress(
    provider: EmbeddingProvider | None,
    text_type: EmbedTextType,
) -> Generator[None, None, None]:
    """Context manager that counts an in-flight embedding batch on a Gauge.

    The gauge is incremented on entry and decremented on exit. Failures to
    touch the gauge are logged and ignored; the decrement is skipped when the
    increment did not go through so the gauge is never pushed below its
    starting value by this context manager.
    """
    gauge_labels = {
        "provider": provider_label(provider),
        "text_type": text_type.value,
    }

    try:
        _embeddings_in_progress.labels(**gauge_labels).inc()
    except Exception:
        logger.warning(
            "Failed to increment in-progress embedding gauge.", exc_info=True
        )
        gauge_bumped = False
    else:
        gauge_bumped = True

    try:
        yield
    finally:
        if gauge_bumped:
            try:
                _embeddings_in_progress.labels(**gauge_labels).dec()
            except Exception:
                logger.warning(
                    "Failed to decrement in-progress embedding gauge.", exc_info=True
                )

View File

@@ -395,15 +395,6 @@ class WorkerHealthCollector(_CachedCollector):
Reads worker status from ``WorkerHeartbeatMonitor`` which listens
to the Celery event stream via a single persistent connection.
TODO: every monitoring pod subscribes to the cluster-wide Celery event
stream, so each replica reports health for *all* workers in the cluster,
not just itself. Prometheus distinguishes the replicas via the ``instance``
label, so this doesn't break scraping, but it means N monitoring replicas
do N× the work and may emit slightly inconsistent snapshots of the same
cluster. The proper fix is to have each worker expose its own health (or
to elect a single monitoring replica as the reporter) rather than
broadcasting the full cluster view from every monitoring pod.
"""
def __init__(self, cache_ttl: float = 30.0) -> None:
@@ -422,16 +413,10 @@ class WorkerHealthCollector(_CachedCollector):
"onyx_celery_active_worker_count",
"Number of active Celery workers with recent heartbeats",
)
# Celery hostnames are ``{worker_type}@{nodename}`` (see supervisord.conf).
# Emitting only the worker_type as a label causes N replicas of the same
# type to collapse into identical timeseries within a single scrape,
# which Prometheus rejects as "duplicate sample for timestamp". Split
# the pieces into separate labels so each replica is distinct; callers
# can still ``sum by (worker_type)`` to recover the old aggregated view.
worker_up = GaugeMetricFamily(
"onyx_celery_worker_up",
"Whether a specific Celery worker is alive (1=up, 0=down)",
labels=["worker_type", "hostname"],
labels=["worker"],
)
try:
@@ -439,15 +424,11 @@ class WorkerHealthCollector(_CachedCollector):
alive_count = sum(1 for alive in status.values() if alive)
active_workers.add_metric([], alive_count)
for full_hostname in sorted(status):
worker_type, sep, host = full_hostname.partition("@")
if not sep:
# Hostname didn't contain "@" — fall back to using the
# whole string as the hostname with an empty type.
worker_type, host = "", full_hostname
worker_up.add_metric(
[worker_type, host], 1 if status[full_hostname] else 0
)
for hostname in sorted(status):
# Use short name (before @) for single-host deployments,
# full hostname when multiple hosts share a worker type.
label = hostname.split("@")[0]
worker_up.add_metric([label], 1 if status[hostname] else 0)
except Exception:
logger.debug("Failed to collect worker health metrics", exc_info=True)

View File

@@ -63,7 +63,6 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.db.user_file import user_can_access_chat_file
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_store.file_store import get_default_file_store
@@ -867,18 +866,14 @@ def seed_chat_from_slack(
def fetch_chat_file(
file_id: str,
request: Request,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
db_session: Session = Depends(get_session),
) -> Response:
# For user files, we need to get the file id from the user file id
file_id_from_user_file = get_file_id_by_user_file_id(file_id, user.id, db_session)
file_id_from_user_file = get_file_id_by_user_file_id(file_id, db_session)
if file_id_from_user_file:
file_id = file_id_from_user_file
elif not user_can_access_chat_file(file_id, user.id, db_session):
# Return 404 (rather than 403) so callers cannot probe for file
# existence across ownership boundaries.
raise OnyxError(OnyxErrorCode.NOT_FOUND, "File not found")
file_store = get_default_file_store()
file_record = file_store.read_file_record(file_id)

View File

@@ -80,7 +80,7 @@ class Settings(BaseModel):
query_history_type: QueryHistoryType | None = None
# Image processing settings
image_extraction_and_analysis_enabled: bool | None = True
image_extraction_and_analysis_enabled: bool | None = False
search_time_image_analysis_enabled: bool | None = False
image_analysis_max_size_mb: int | None = 20

View File

@@ -7,7 +7,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.configs.embedding_configs import SUPPORTED_EMBEDDING_MODELS
@@ -127,11 +126,10 @@ def setup_onyx(
"DISABLE_VECTOR_DB is set — skipping document index setup and embedding model warm-up."
)
else:
# Ensure the document indices are setup correctly. This step is
# relatively near the end because Vespa takes a bit of time to start up.
# Ensure Vespa is setup correctly, this step is relatively near the end
# because Vespa takes a bit of time to start up
logger.notice("Verifying Document Index(s) is/are available.")
# This flow is for setting up the document index so we get all indices
# here.
# This flow is for setting up the document index so we get all indices here.
document_indices = get_all_document_indices(
search_settings,
secondary_search_settings,
@@ -337,7 +335,7 @@ def setup_multitenant_onyx() -> None:
# For Managed Vespa, the schema is sent over via the Vespa Console manually.
# NOTE: Pretty sure this code is never hit in any production environment.
if not MANAGED_VESPA and not ONYX_DISABLE_VESPA:
if not MANAGED_VESPA:
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)

View File

@@ -46,7 +46,7 @@ stop_and_remove_containers
# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
if [[ -n "$POSTGRES_VOLUME" ]]; then
docker run -p 5432:5432 --name onyx_postgres -e POSTGRES_PASSWORD=password -d -v "$POSTGRES_VOLUME":/var/lib/postgresql/data postgres -c max_connections=250
docker run -p 5432:5432 --name onyx_postgres -e POSTGRES_PASSWORD=password -d -v $POSTGRES_VOLUME:/var/lib/postgresql/data postgres -c max_connections=250
else
docker run -p 5432:5432 --name onyx_postgres -e POSTGRES_PASSWORD=password -d postgres -c max_connections=250
fi
@@ -54,7 +54,7 @@ fi
# Start the Vespa container with optional volume
echo "Starting Vespa container..."
if [[ -n "$VESPA_VOLUME" ]]; then
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 -v "$VESPA_VOLUME":/opt/vespa/var vespaengine/vespa:8
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 -v $VESPA_VOLUME:/opt/vespa/var vespaengine/vespa:8
else
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8
fi
@@ -85,7 +85,7 @@ docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-en
# Start the Redis container with optional volume
echo "Starting Redis container..."
if [[ -n "$REDIS_VOLUME" ]]; then
docker run --detach --name onyx_redis --publish 6379:6379 -v "$REDIS_VOLUME":/data redis
docker run --detach --name onyx_redis --publish 6379:6379 -v $REDIS_VOLUME:/data redis
else
docker run --detach --name onyx_redis --publish 6379:6379 redis
fi
@@ -93,7 +93,7 @@ fi
# Start the MinIO container with optional volume
echo "Starting MinIO container..."
if [[ -n "$MINIO_VOLUME" ]]; then
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin -v "$MINIO_VOLUME":/data minio/minio server /data --console-address ":9001"
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin -v $MINIO_VOLUME:/data minio/minio server /data --console-address ":9001"
else
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
fi
@@ -111,7 +111,6 @@ sleep 1
# Alembic should be configured in the virtualenv for this repo
if [[ -f "../.venv/bin/activate" ]]; then
# shellcheck source=/dev/null
source ../.venv/bin/activate
else
echo "Warning: Python virtual environment not found at .venv/bin/activate; alembic may not work."

View File

@@ -99,14 +99,6 @@ STRICT_CHUNK_TOKEN_LIMIT = (
# Set up Sentry integration (for error logging)
SENTRY_DSN = os.environ.get("SENTRY_DSN")
# Celery task spans dominate ingestion volume (~94%), so default celery
# tracing to 0. Web/API traces stay at a small non-zero rate so http.server
# traces remain available. Both are env-tunable without a code change.
SENTRY_TRACES_SAMPLE_RATE = float(os.environ.get("SENTRY_TRACES_SAMPLE_RATE", "0.01"))
SENTRY_CELERY_TRACES_SAMPLE_RATE = float(
os.environ.get("SENTRY_CELERY_TRACES_SAMPLE_RATE", "0.0")
)
# Fields which should only be set on new search setting
PRESERVED_SEARCH_FIELDS = [

View File

@@ -9,10 +9,8 @@ import pytest
from onyx.configs.constants import BlobType
from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
@@ -113,18 +111,15 @@ def test_blob_s3_connector(
for doc in all_docs:
section = doc.sections[0]
if is_tabular_file(doc.semantic_identifier):
assert isinstance(section, TabularSection)
assert len(section.text) > 0
continue
assert isinstance(section, TextSection)
file_extension = get_file_ext(doc.semantic_identifier)
if file_extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS:
assert len(section.text) > 0
else:
assert len(section.text) == 0
continue
# unknown extension
assert len(section.text) == 0
@patch(

View File

@@ -1,130 +0,0 @@
"""External dependency tests for onyx.file_store.staging.
Exercises the raw-file persistence hook used by the docfetching pipeline
against a real file store (Postgres + MinIO/S3), since mocking the store
would defeat the point of verifying that metadata round-trips through
FileRecord.
"""
from collections.abc import Generator
from io import BytesIO
from typing import Any
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.connectors.interfaces import BaseConnector
from onyx.db.file_record import delete_filerecord_by_file_id
from onyx.db.file_record import get_filerecord_by_file_id
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.staging import build_raw_file_callback
from onyx.file_store.staging import stage_raw_file
@pytest.fixture(scope="function")
def cleanup_file_ids(
    db_session: Session,
) -> Generator[list[str], None, None]:
    """Yield a list that tests append file ids to; delete those files afterwards.

    Each id is removed through the default file store. If that raises, the
    FileRecord row is deleted directly and the session is committed.
    """
    tracked_ids: list[str] = []
    yield tracked_ids

    store = get_default_file_store()
    for tracked_id in tracked_ids:
        try:
            store.delete_file(tracked_id)
        except Exception:
            # Fall back to removing just the DB record when the store-level
            # delete fails.
            delete_filerecord_by_file_id(file_id=tracked_id, db_session=db_session)
            db_session.commit()
def test_stage_raw_file_persists_with_origin_and_metadata(
    db_session: Session,
    tenant_context: None,  # noqa: ARG001
    initialize_file_store: None,  # noqa: ARG001
    cleanup_file_ids: list[str],
) -> None:
    """stage_raw_file must store a FileRecord whose origin is INDEXING_STAGING
    and whose content type and metadata match what was passed in."""
    expected_type = "application/pdf"
    payload = b"hello raw file"
    expected_metadata: dict[str, Any] = {
        "index_attempt_id": 42,
        "cc_pair_id": 7,
        "tenant_id": "tenant-abc",
        "extra": "payload",
    }

    staged_id = stage_raw_file(
        content=BytesIO(payload),
        content_type=expected_type,
        metadata=expected_metadata,
    )
    # Register for teardown before anything else can fail.
    cleanup_file_ids.append(staged_id)
    db_session.commit()

    stored = get_filerecord_by_file_id(file_id=staged_id, db_session=db_session)
    assert stored.file_origin == FileOrigin.INDEXING_STAGING
    assert stored.file_type == expected_type
    assert stored.file_metadata == expected_metadata
def test_build_raw_file_callback_binds_attempt_context_per_call(
    db_session: Session,
    tenant_context: None,  # noqa: ARG001
    initialize_file_store: None,  # noqa: ARG001
    cleanup_file_ids: list[str],
) -> None:
    """Every FileRecord produced by the callback from build_raw_file_callback
    must carry the attempt-level context, and separate invocations must not
    share state."""
    stage = build_raw_file_callback(
        index_attempt_id=1001,
        cc_pair_id=202,
        tenant_id="tenant-xyz",
    )

    first_id = stage(BytesIO(b"alpha"), "text/plain")
    second_id = stage(BytesIO(b"beta"), "application/octet-stream")
    cleanup_file_ids.extend([first_id, second_id])
    db_session.commit()

    assert first_id != second_id

    expected_types = [
        (first_id, "text/plain"),
        (second_id, "application/octet-stream"),
    ]
    for staged_id, content_type in expected_types:
        stored = get_filerecord_by_file_id(file_id=staged_id, db_session=db_session)
        assert stored.file_origin == FileOrigin.INDEXING_STAGING
        assert stored.file_type == content_type
        assert stored.file_metadata == {
            "index_attempt_id": 1001,
            "cc_pair_id": 202,
            "tenant_id": "tenant-xyz",
        }
def test_set_raw_file_callback_on_base_connector() -> None:
    """set_raw_file_callback must store the given callable on the connector
    instance so it can be invoked through ``raw_file_callback``."""

    class _MinimalConnector(BaseConnector):
        def load_credentials(
            self,
            credentials: dict[str, Any],  # noqa: ARG002
        ) -> dict[str, Any] | None:
            return None

    expected_file_id = f"sentinel-{uuid4().hex[:8]}"

    def _fake_callback(_content: Any, _content_type: str) -> str:
        return expected_file_id

    connector = _MinimalConnector()
    # No callback is installed until one is set explicitly.
    assert connector.raw_file_callback is None

    connector.set_raw_file_callback(_fake_callback)

    assert connector.raw_file_callback is _fake_callback
    assert connector.raw_file_callback(BytesIO(b""), "text/plain") == expected_file_id

View File

@@ -1,405 +0,0 @@
"""External dependency unit tests for `index_doc_batch_prepare`.
Validates the file_id lifecycle that runs alongside the document upsert:
* `document.file_id` is written on insert AND on conflict (upsert path)
* Newly-staged files get promoted from INDEXING_STAGING -> CONNECTOR
* Replaced files are deleted from both `file_record` and S3
* No-op when the file_id is unchanged
Uses real PostgreSQL + real S3/MinIO via the file store.
"""
from collections.abc import Generator
from io import BytesIO
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import InputType
from onyx.connectors.models import TextSection
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.file_record import get_filerecord_by_file_id_optional
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Document as DBDocument
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import FileRecord
from onyx.file_store.file_store import get_default_file_store
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_doc(doc_id: str, file_id: str | None = None) -> Document:
    """Build a minimal Document for indexing-pipeline tests.

    The source is MOCK_CONNECTOR so the hierarchy-node linking branch
    (NOTION/CONFLUENCE only) is not triggered.
    """
    single_section = TextSection(text="content", link=None)
    return Document(
        id=doc_id,
        source=DocumentSource.MOCK_CONNECTOR,
        semantic_identifier=f"semantic-{doc_id}",
        sections=[single_section],
        metadata={},
        file_id=file_id,
    )
def _stage_file(content: bytes = b"raw bytes") -> str:
    """Save ``content`` to the file store as INDEXING_STAGING and return its file_id.

    Mirrors what the connector raw_file_callback would do during fetch.
    """
    store = get_default_file_store()
    staged_id = store.save_file(
        content=BytesIO(content),
        display_name=None,
        file_origin=FileOrigin.INDEXING_STAGING,
        file_type="application/octet-stream",
        file_metadata={"test": True},
    )
    return staged_id
def _get_doc_row(db_session: Session, doc_id: str) -> DBDocument | None:
    """Fetch the document row by id after expiring the session, so the
    result reflects post-upsert database state rather than cached objects."""
    db_session.expire_all()
    matching_rows = db_session.query(DBDocument).filter(DBDocument.id == doc_id)
    return matching_rows.one_or_none()
def _get_filerecord(db_session: Session, file_id: str) -> FileRecord | None:
    """Look up the FileRecord for ``file_id`` (or ``None``) after expiring
    any state cached on the session."""
    db_session.expire_all()
    record = get_filerecord_by_file_id_optional(file_id=file_id, db_session=db_session)
    return record
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def cc_pair(
    db_session: Session,
    tenant_context: None,  # noqa: ARG001
    initialize_file_store: None,  # noqa: ARG001
) -> Generator[ConnectorCredentialPair, None, None]:
    """Create a connector + credential + cc_pair backing the index attempt.

    Teardown sweeps everything the test created under this cc_pair: the
    `document_by_connector_credential_pair` join rows, the `Document` rows
    they point at, the `FileRecord` + blob for each doc's `file_id`, and
    finally the cc_pair / connector / credential themselves. Without this,
    every run would leave orphan rows in the dev DB and orphan blobs in
    MinIO.

    Yields:
        The committed and refreshed ConnectorCredentialPair.
    """
    connector = Connector(
        name=f"test-connector-{uuid4().hex[:8]}",
        source=DocumentSource.MOCK_CONNECTOR,
        input_type=InputType.LOAD_STATE,
        connector_specific_config={},
        refresh_freq=None,
        prune_freq=None,
        indexing_start=None,
    )
    db_session.add(connector)
    # Flush so connector.id is populated before it is referenced below.
    db_session.flush()

    credential = Credential(
        source=DocumentSource.MOCK_CONNECTOR,
        credential_json={},
    )
    db_session.add(credential)
    db_session.flush()

    pair = ConnectorCredentialPair(
        connector_id=connector.id,
        credential_id=credential.id,
        name=f"test-cc-pair-{uuid4().hex[:8]}",
        status=ConnectorCredentialPairStatus.ACTIVE,
        access_type=AccessType.PUBLIC,
        auto_sync_options=None,
    )
    db_session.add(pair)
    db_session.commit()
    db_session.refresh(pair)

    # Snapshot every identifier teardown needs as plain values. Teardown
    # expires the session, so reading attributes off `pair` afterwards would
    # force a reload of the ORM instance; the cached values avoid depending
    # on that reload succeeding.
    pair_id = pair.id
    connector_id = pair.connector_id
    credential_id = pair.credential_id

    try:
        yield pair
    finally:
        db_session.expire_all()

        # Collect every doc indexed under this cc_pair so we can delete its
        # file_record + blob before dropping the Document row itself.
        doc_ids: list[str] = [
            row[0]
            for row in db_session.query(DocumentByConnectorCredentialPair.id)
            .filter(
                DocumentByConnectorCredentialPair.connector_id == connector_id,
                DocumentByConnectorCredentialPair.credential_id == credential_id,
            )
            .all()
        ]
        file_ids: list[str] = [
            row[0]
            for row in db_session.query(DBDocument.file_id)
            .filter(DBDocument.id.in_(doc_ids), DBDocument.file_id.isnot(None))
            .all()
        ]

        file_store = get_default_file_store()
        for fid in file_ids:
            try:
                file_store.delete_file(fid, error_on_missing=False)
            except Exception:
                # Best-effort cleanup: a failed blob delete must not stop the
                # remaining rows from being removed.
                pass

        if doc_ids:
            db_session.query(DocumentByConnectorCredentialPair).filter(
                DocumentByConnectorCredentialPair.id.in_(doc_ids)
            ).delete(synchronize_session="fetch")
            db_session.query(DBDocument).filter(DBDocument.id.in_(doc_ids)).delete(
                synchronize_session="fetch"
            )

        db_session.query(ConnectorCredentialPair).filter(
            ConnectorCredentialPair.id == pair_id
        ).delete(synchronize_session="fetch")
        db_session.query(Connector).filter(Connector.id == connector_id).delete(
            synchronize_session="fetch"
        )
        db_session.query(Credential).filter(Credential.id == credential_id).delete(
            synchronize_session="fetch"
        )
        db_session.commit()
@pytest.fixture
def attempt_metadata(cc_pair: ConnectorCredentialPair) -> IndexAttemptMetadata:
    """IndexAttemptMetadata pointing at the ``cc_pair`` fixture's connector
    and credential, with no attempt id and a fixed request id."""
    metadata = IndexAttemptMetadata(
        connector_id=cc_pair.connector_id,
        credential_id=cc_pair.credential_id,
        attempt_id=None,
        request_id="test-request",
    )
    return metadata
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestNewDocuments:
    """First-time inserts — no previous file_id to reconcile against."""

    def test_new_doc_without_file_id(
        self,
        db_session: Session,
        attempt_metadata: IndexAttemptMetadata,
    ) -> None:
        """A new document without a file_id is stored with a NULL file_id."""
        fresh_doc = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=None)

        index_doc_batch_prepare(
            documents=[fresh_doc],
            index_attempt_metadata=attempt_metadata,
            db_session=db_session,
            ignore_time_skip=True,
        )
        db_session.commit()

        stored = _get_doc_row(db_session, fresh_doc.id)
        assert stored is not None
        assert stored.file_id is None

    def test_new_doc_with_staged_file_id_promotes_to_connector(
        self,
        db_session: Session,
        attempt_metadata: IndexAttemptMetadata,
    ) -> None:
        """A new document with a staged file keeps that file_id and the
        file's origin becomes CONNECTOR."""
        staged_id = _stage_file()
        fresh_doc = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=staged_id)

        index_doc_batch_prepare(
            documents=[fresh_doc],
            index_attempt_metadata=attempt_metadata,
            db_session=db_session,
            ignore_time_skip=True,
        )
        db_session.commit()

        stored = _get_doc_row(db_session, fresh_doc.id)
        assert stored is not None
        assert stored.file_id == staged_id

        file_record = _get_filerecord(db_session, staged_id)
        assert file_record is not None
        assert file_record.file_origin == FileOrigin.CONNECTOR
class TestExistingDocuments:
    """Re-index path — a `document` row already exists with some file_id."""

    def test_unchanged_file_id_is_noop(
        self,
        db_session: Session,
        attempt_metadata: IndexAttemptMetadata,
    ) -> None:
        """Re-indexing with the same file_id leaves the file and row intact."""
        staged_id = _stage_file()
        document = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=staged_id)

        # Pass 1 inserts the row + promotes the file; pass 2 reuses the same
        # file_id and should not delete or re-promote.
        for _pass in range(2):
            index_doc_batch_prepare(
                documents=[document],
                index_attempt_metadata=attempt_metadata,
                db_session=db_session,
                ignore_time_skip=True,
            )
            db_session.commit()

        file_record = _get_filerecord(db_session, staged_id)
        assert file_record is not None
        assert file_record.file_origin == FileOrigin.CONNECTOR

        stored = _get_doc_row(db_session, document.id)
        assert stored is not None
        assert stored.file_id == staged_id

    def test_swapping_file_id_promotes_new_and_deletes_old(
        self,
        db_session: Session,
        attempt_metadata: IndexAttemptMetadata,
    ) -> None:
        """Replacing a document's file promotes the new file and removes the
        old file record."""
        previous_id = _stage_file(content=b"old bytes")
        original = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=previous_id)
        index_doc_batch_prepare(
            documents=[original],
            index_attempt_metadata=attempt_metadata,
            db_session=db_session,
            ignore_time_skip=True,
        )
        db_session.commit()

        # Re-fetch produces a new staged file_id for the same doc.
        replacement_id = _stage_file(content=b"new bytes")
        updated = _make_doc(original.id, file_id=replacement_id)
        index_doc_batch_prepare(
            documents=[updated],
            index_attempt_metadata=attempt_metadata,
            db_session=db_session,
            ignore_time_skip=True,
        )
        db_session.commit()

        stored = _get_doc_row(db_session, original.id)
        assert stored is not None
        assert stored.file_id == replacement_id

        replacement_record = _get_filerecord(db_session, replacement_id)
        assert replacement_record is not None
        assert replacement_record.file_origin == FileOrigin.CONNECTOR

        # Old file_record + S3 object are gone.
        assert _get_filerecord(db_session, previous_id) is None

    def test_clearing_file_id_deletes_old_and_nulls_column(
        self,
        db_session: Session,
        attempt_metadata: IndexAttemptMetadata,
    ) -> None:
        """Re-indexing without a file_id nulls the column and removes the
        previously attached file record."""
        previous_id = _stage_file()
        original = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=previous_id)
        index_doc_batch_prepare(
            documents=[original],
            index_attempt_metadata=attempt_metadata,
            db_session=db_session,
            ignore_time_skip=True,
        )
        db_session.commit()

        # Connector opts out on next run — yields the doc without a file_id.
        cleared = _make_doc(original.id, file_id=None)
        index_doc_batch_prepare(
            documents=[cleared],
            index_attempt_metadata=attempt_metadata,
            db_session=db_session,
            ignore_time_skip=True,
        )
        db_session.commit()

        stored = _get_doc_row(db_session, original.id)
        assert stored is not None
        assert stored.file_id is None
        assert _get_filerecord(db_session, previous_id) is None
class TestBatchHandling:
    """Mixed batches — multiple docs at different lifecycle states in one call."""

    def test_mixed_batch_each_doc_handled_independently(
        self,
        db_session: Session,
        attempt_metadata: IndexAttemptMetadata,
    ) -> None:
        """One batch containing a file swap, a new doc with a file and a new
        doc without a file: each document is reconciled on its own."""
        # Pre-seed an existing doc with a file_id we'll swap.
        seeded_file_id = _stage_file(content=b"existing-old")
        seeded_doc = _make_doc(f"doc-{uuid4().hex[:8]}", file_id=seeded_file_id)
        index_doc_batch_prepare(
            documents=[seeded_doc],
            index_attempt_metadata=attempt_metadata,
            db_session=db_session,
            ignore_time_skip=True,
        )
        db_session.commit()

        # Now: swap the existing one, add a brand-new doc with file_id, and a
        # brand-new doc without file_id.
        swapped_file_id = _stage_file(content=b"existing-new")
        attached_file_id = _stage_file(content=b"new-with-file")
        batch = [
            _make_doc(seeded_doc.id, file_id=swapped_file_id),
            _make_doc(f"doc-{uuid4().hex[:8]}", file_id=attached_file_id),
            _make_doc(f"doc-{uuid4().hex[:8]}", file_id=None),
        ]
        _, doc_with_file, doc_without_file = batch

        index_doc_batch_prepare(
            documents=batch,
            index_attempt_metadata=attempt_metadata,
            db_session=db_session,
            ignore_time_skip=True,
        )
        db_session.commit()

        # Existing doc was swapped: old file gone, new file promoted.
        seeded_row = _get_doc_row(db_session, seeded_doc.id)
        assert seeded_row is not None
        assert seeded_row.file_id == swapped_file_id
        assert _get_filerecord(db_session, seeded_file_id) is None
        swapped_record = _get_filerecord(db_session, swapped_file_id)
        assert swapped_record is not None
        assert swapped_record.file_origin == FileOrigin.CONNECTOR

        # New doc with file_id: row exists, file promoted.
        with_file_row = _get_doc_row(db_session, doc_with_file.id)
        assert with_file_row is not None
        assert with_file_row.file_id == attached_file_id
        attached_record = _get_filerecord(db_session, attached_file_id)
        assert attached_record is not None
        assert attached_record.file_origin == FileOrigin.CONNECTOR

        # New doc without file_id: row exists, no file_record involvement.
        without_file_row = _get_doc_row(db_session, doc_without_file.id)
        assert without_file_row is not None
        assert without_file_row.file_id is None

View File

@@ -446,107 +446,10 @@ class TestOpenSearchClient:
test_client.create_index(mappings=mappings, settings=settings)
def test_update_settings(self, test_client: OpenSearchIndexClient) -> None:
"""Tests updating index settings on an existing index."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings_based_on_environment()
test_client.create_index(mappings=mappings, settings=settings)
# Assert that the current number of replicas is not the desired test
# number we are updating to.
test_num_replicas = 0
current_settings = test_client.get_settings()
assert current_settings["index"]["number_of_replicas"] != f"{test_num_replicas}"
# Under test.
# Should not raise. number_of_replicas is a dynamic setting that can be
# changed without closing the index.
test_client.update_settings(
settings={"index": {"number_of_replicas": test_num_replicas}}
)
# Postcondition.
current_settings = test_client.get_settings()
assert current_settings["index"]["number_of_replicas"] == f"{test_num_replicas}"
def test_update_settings_on_nonexistent_index(
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests updating settings on a nonexistent index raises an error."""
"""Tests that update_settings raises NotImplementedError."""
# Under test and postcondition.
with pytest.raises(Exception, match="index_not_found_exception|404"):
test_client.update_settings(settings={"index": {"number_of_replicas": 0}})
def test_get_settings(self, test_client: OpenSearchIndexClient) -> None:
    """Check the shape of the settings returned for an existing index."""
    # Arrange: create the index from the standard schema and settings.
    test_client.create_index(
        mappings=DocumentSchema.get_document_schema(
            vector_dimension=128, multitenant=True
        ),
        settings=DocumentSchema.get_index_settings_based_on_environment(),
    )

    # Act.
    fetched = test_client.get_settings()

    # Assert: the "index" section exists and carries the expected keys.
    assert "index" in fetched
    index_section = fetched["index"]
    for always_present_key in ("number_of_shards", "number_of_replicas"):
        assert always_present_key in index_section
    assert index_section["provided_name"] == test_client._index_name
def test_get_settings_on_nonexistent_index(
    self, test_client: OpenSearchIndexClient
) -> None:
    """Check that reading settings fails when the index was never created."""
    # The raised error's text must match either alternative of this pattern.
    not_found_pattern = "index_not_found_exception|404"
    with pytest.raises(Exception, match=not_found_pattern):
        test_client.get_settings()
def test_close_and_open_index(self, test_client: OpenSearchIndexClient) -> None:
    """Check that an index can be closed and then reopened.

    While closed, a search must fail; after reopening, the same search must
    succeed and return no document ids.
    """

    def _match_all_body() -> dict:
        # A fresh request body per search, as the two searches are independent.
        return {"_source": False, "query": {"match_all": {}}}

    # Arrange: create the index from the standard schema and settings.
    test_client.create_index(
        mappings=DocumentSchema.get_document_schema(
            vector_dimension=128, multitenant=True
        ),
        settings=DocumentSchema.get_index_settings_based_on_environment(),
    )

    # Act: closing must not raise.
    test_client.close_index()

    # Assert: searching a closed index fails.
    with pytest.raises(Exception, match="index_closed_exception|closed"):
        test_client.search_for_document_ids(body=_match_all_body())

    # Act: reopening must not raise.
    test_client.open_index()

    # Assert: searching works again and finds nothing.
    ids_after_reopen = test_client.search_for_document_ids(body=_match_all_body())
    assert ids_after_reopen == []
def test_close_nonexistent_index(self, test_client: OpenSearchIndexClient) -> None:
    """Check that closing an index that was never created raises."""
    # The raised error's text must match either alternative of this pattern.
    not_found_pattern = "index_not_found_exception|404"
    with pytest.raises(Exception, match=not_found_pattern):
        test_client.close_index()
def test_open_nonexistent_index(self, test_client: OpenSearchIndexClient) -> None:
    """Check that opening an index that was never created raises."""
    # The raised error's text must match either alternative of this pattern.
    not_found_pattern = "index_not_found_exception|404"
    with pytest.raises(Exception, match=not_found_pattern):
        test_client.open_index()
with pytest.raises(NotImplementedError):
test_client.update_settings(settings={})
def test_create_and_delete_search_pipeline(
self, test_client: OpenSearchIndexClient

View File

@@ -1,210 +0,0 @@
"""Tests for `cloud_beat_task_generator`'s tenant work-gating logic.
Exercises the gate-read path end-to-end against real Redis. The Celery
`.app.send_task` is mocked so we can count dispatches without actually
sending messages.
Requires a running Redis instance. Run with::
python -m dotenv -f .vscode/.env run -- pytest \
backend/tests/external_dependency_unit/tenant_work_gating/test_gate_generator.py
"""
from collections.abc import Generator
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from ee.onyx.background.celery.tasks.cloud import tasks as cloud_tasks
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.redis import redis_tenant_work_gating as twg
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_tenant_work_gating import _SET_KEY
from onyx.redis.redis_tenant_work_gating import mark_tenant_active
# Three fixed tenant ids that make up the tenant population in these tests.
_TENANT_A = "tenant_aaaa0000-0000-0000-0000-000000000001"
_TENANT_B = "tenant_bbbb0000-0000-0000-0000-000000000002"
_TENANT_C = "tenant_cccc0000-0000-0000-0000-000000000003"
# The full list handed to the generator as "all tenants".
_ALL_TEST_TENANTS = [_TENANT_A, _TENANT_B, _TENANT_C]
# Local alias for the cloud task module's full-fanout timestamp key prefix;
# the tests append ":test_task" to it to address the key for their task name.
_FANOUT_KEY_PREFIX = cloud_tasks._FULL_FANOUT_TIMESTAMP_KEY_PREFIX
@pytest.fixture(autouse=True)
def _multi_tenant_true() -> Generator[None, None, None]:
    """Hold `twg.MULTI_TENANT` at True for the duration of each test."""
    multi_tenant_patch = patch.object(twg, "MULTI_TENANT", True)
    multi_tenant_patch.start()
    try:
        yield
    finally:
        # Always restore the original attribute, even if the test failed.
        multi_tenant_patch.stop()
@pytest.fixture(autouse=True)
def _clean_redis() -> Generator[None, None, None]:
    """Delete the Redis keys these tests touch, both before and after a test.

    Covers the active-tenant set, the full-fanout timestamp for `test_task`,
    and the two runtime flag keys.
    """
    redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
    keys_to_clear = (
        _SET_KEY,
        f"{_FANOUT_KEY_PREFIX}:test_task",
        "runtime:tenant_work_gating:enabled",
        "runtime:tenant_work_gating:enforce",
    )

    def _purge() -> None:
        # One delete per key, in a fixed order.
        for key in keys_to_clear:
            redis_client.delete(key)

    _purge()
    yield
    _purge()
def _invoke_generator(
    *,
    work_gated: bool,
    enabled: bool,
    enforce: bool,
    tenant_ids: list[str],
    full_fanout_interval_seconds: int = 1200,
    ttl_seconds: int = 1800,
) -> MagicMock:
    """Run `cloud_beat_task_generator` once with its collaborators patched.

    The task's `.app` is swapped for a MagicMock, the tenant listing and the
    gated-tenant lookup are stubbed, and the four runtime getters return the
    values passed in. The mock app is returned so callers can inspect the
    `send_task` calls that were made.
    """
    celery_app = MagicMock()

    # (dotted OnyxRuntime getter, value it should return) for each runtime flag.
    runtime_stubs = [
        (
            "onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_enabled",
            enabled,
        ),
        (
            "onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_enforce",
            enforce,
        ),
        (
            "onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_full_fanout_interval_seconds",
            full_fanout_interval_seconds,
        ),
        (
            "onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_ttl_seconds",
            ttl_seconds,
        ),
    ]

    # `.run()` binds `self` to the task itself, so replacing the task's `.app`
    # routes `self.app.send_task` to the mock.
    patchers = [
        patch.object(cloud_tasks.cloud_beat_task_generator, "app", celery_app),
        patch.object(cloud_tasks, "get_all_tenant_ids", return_value=list(tenant_ids)),
        patch.object(cloud_tasks, "get_gated_tenants", return_value=set()),
    ]
    patchers.extend(
        patch(target, return_value=stub_value) for target, stub_value in runtime_stubs
    )

    started = []
    try:
        for patcher in patchers:
            patcher.start()
            started.append(patcher)
        cloud_tasks.cloud_beat_task_generator.run(
            task_name="test_task",
            work_gated=work_gated,
        )
    finally:
        # Undo in reverse order, and only what was actually applied.
        for patcher in reversed(started):
            patcher.stop()
    return celery_app
def _dispatched_tenants(mock_app: MagicMock) -> list[str]:
"""Pull tenant_ids out of each send_task call for assertion."""
return [c.kwargs["kwargs"]["tenant_id"] for c in mock_app.send_task.call_args_list]
def _seed_recent_full_fanout_timestamp() -> None:
    """Store "now" as the last full-fanout time for `test_task`.

    With a recent timestamp already present, the interval-elapsed check
    reports False, so the gate enforces normally rather than treating the
    first invocation as a full fanout.
    """
    import time as _clock

    # The value is written as a stringified integer number of milliseconds.
    now_ms = int(_clock.time() * 1000)
    redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
    redis_client.set(f"{_FANOUT_KEY_PREFIX}:test_task", str(now_ms))
def test_enforce_skips_unmarked_tenants() -> None:
    """Gate enabled and enforced, interval not elapsed: only the tenant in
    the active set is dispatched."""
    mark_tenant_active(_TENANT_A)
    _seed_recent_full_fanout_timestamp()

    generator_kwargs = dict(
        work_gated=True,
        enabled=True,
        enforce=True,
        tenant_ids=_ALL_TEST_TENANTS,
        full_fanout_interval_seconds=3600,
    )
    celery_app = _invoke_generator(**generator_kwargs)

    assert _dispatched_tenants(celery_app) == [_TENANT_A]
def test_shadow_mode_dispatches_all_tenants() -> None:
    """Gate enabled but not enforced: every tenant is still dispatched."""
    mark_tenant_active(_TENANT_A)
    _seed_recent_full_fanout_timestamp()

    generator_kwargs = dict(
        work_gated=True,
        enabled=True,
        enforce=False,
        tenant_ids=_ALL_TEST_TENANTS,
        full_fanout_interval_seconds=3600,
    )
    celery_app = _invoke_generator(**generator_kwargs)

    # Order is not asserted, only the set of tenants reached.
    assert set(_dispatched_tenants(celery_app)) == set(_ALL_TEST_TENANTS)
def test_full_fanout_cycle_dispatches_all_tenants() -> None:
    """No timestamp is seeded, so the first invocation counts as a full
    fanout and every tenant is dispatched even with enforce on."""
    mark_tenant_active(_TENANT_A)

    generator_kwargs = dict(
        work_gated=True,
        enabled=True,
        enforce=True,
        tenant_ids=_ALL_TEST_TENANTS,
    )
    celery_app = _invoke_generator(**generator_kwargs)

    assert set(_dispatched_tenants(celery_app)) == set(_ALL_TEST_TENANTS)
def test_redis_unavailable_fails_open() -> None:
    """`get_active_tenants` returning None (standing in for a Redis outage)
    makes the gate dispatch every tenant, even with a recent fanout
    timestamp and enforce on."""
    mark_tenant_active(_TENANT_A)
    _seed_recent_full_fanout_timestamp()

    generator_kwargs = dict(
        work_gated=True,
        enabled=True,
        enforce=True,
        tenant_ids=_ALL_TEST_TENANTS,
        full_fanout_interval_seconds=3600,
    )
    outage = patch.object(cloud_tasks, "get_active_tenants", return_value=None)
    with outage:
        celery_app = _invoke_generator(**generator_kwargs)

    assert set(_dispatched_tenants(celery_app)) == set(_ALL_TEST_TENANTS)
def test_work_gated_false_bypasses_gate_entirely() -> None:
    """A task invoked with `work_gated=False` dispatches to every tenant
    whatever the flags say."""
    # Enforce is on and no tenant is marked active, yet all must dispatch.
    generator_kwargs = dict(
        work_gated=False,
        enabled=True,
        enforce=True,
        tenant_ids=_ALL_TEST_TENANTS,
    )
    celery_app = _invoke_generator(**generator_kwargs)

    assert set(_dispatched_tenants(celery_app)) == set(_ALL_TEST_TENANTS)
def test_gate_disabled_dispatches_everyone_regardless_of_enforce() -> None:
    """With `enabled=False`, every tenant is dispatched even though enforce
    is on."""
    # Deliberately leave the active set empty.
    generator_kwargs = dict(
        work_gated=True,
        enabled=False,
        enforce=True,
        tenant_ids=_ALL_TEST_TENANTS,
    )
    celery_app = _invoke_generator(**generator_kwargs)

    assert set(_dispatched_tenants(celery_app)) == set(_ALL_TEST_TENANTS)

View File

@@ -247,7 +247,7 @@ class DATestSettings(BaseModel):
gpu_enabled: bool | None = None
product_gating: DATestGatingType = DATestGatingType.NONE
anonymous_user_enabled: bool | None = None
image_extraction_and_analysis_enabled: bool | None = True
image_extraction_and_analysis_enabled: bool | None = False
search_time_image_analysis_enabled: bool | None = False

View File

@@ -16,14 +16,12 @@ from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult
from mcp.types import TextContent
from pydantic import AnyUrl
from onyx.db.enums import AccessType
from tests.integration.common_utils.constants import MCP_SERVER_URL
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.document_set import DocumentSetManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.pat import PATManager
from tests.integration.common_utils.managers.user import UserManager
@@ -36,7 +34,6 @@ from tests.integration.common_utils.test_models import DATestUser
# Constants
MCP_SEARCH_TOOL = "search_indexed_documents"
INDEXED_SOURCES_RESOURCE_URI = "resource://indexed_sources"
DOCUMENT_SETS_RESOURCE_URI = "resource://document_sets"
DEFAULT_SEARCH_LIMIT = 5
STREAMABLE_HTTP_URL = f"{MCP_SERVER_URL.rstrip('/')}/?transportType=streamable-http"
@@ -76,22 +73,19 @@ def _extract_tool_payload(result: CallToolResult) -> dict[str, Any]:
def _call_search_tool(
headers: dict[str, str],
query: str,
limit: int = DEFAULT_SEARCH_LIMIT,
document_set_names: list[str] | None = None,
headers: dict[str, str], query: str, limit: int = DEFAULT_SEARCH_LIMIT
) -> CallToolResult:
"""Call the search_indexed_documents tool via MCP."""
async def _action(session: ClientSession) -> CallToolResult:
await session.initialize()
arguments: dict[str, Any] = {
"query": query,
"limit": limit,
}
if document_set_names is not None:
arguments["document_set_names"] = document_set_names
return await session.call_tool(MCP_SEARCH_TOOL, arguments)
return await session.call_tool(
MCP_SEARCH_TOOL,
{
"query": query,
"limit": limit,
},
)
return _run_with_mcp_session(headers, _action)
@@ -244,106 +238,3 @@ def test_mcp_search_respects_acl_filters(
blocked_payload = _extract_tool_payload(blocked_result)
assert blocked_payload["total_results"] == 0
assert blocked_payload["documents"] == []
def test_mcp_search_filters_by_document_set(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
) -> None:
    """Searching with `document_set_names` must restrict results to that set.

    Two connector pairs each get one document sharing a phrase; only one pair
    belongs to the document set. The test checks the set is discoverable via
    the document-sets resource, that an unfiltered search sees both documents,
    that a filtered search sees only the in-set one, and that an empty filter
    list behaves like no filter.
    """
    LLMProviderManager.create(user_performing_action=admin_user)
    api_key = APIKeyManager.create(user_performing_action=admin_user)

    cc_pair_in_set = CCPairManager.create_from_scratch(
        user_performing_action=admin_user,
    )
    cc_pair_out_of_set = CCPairManager.create_from_scratch(
        user_performing_action=admin_user,
    )

    shared_phrase = "document-set-filter-shared-phrase"
    in_set_content = f"{shared_phrase} inside curated set"
    out_of_set_content = f"{shared_phrase} outside curated set"

    # Seed one document per connector pair, in-set first.
    seed_plan = (
        (cc_pair_in_set, in_set_content),
        (cc_pair_out_of_set, out_of_set_content),
    )
    for cc_pair, content in seed_plan:
        _seed_document_and_wait_for_indexing(
            cc_pair=cc_pair,
            content=content,
            api_key=api_key,
            user_performing_action=admin_user,
        )

    doc_set = DocumentSetManager.create(
        cc_pair_ids=[cc_pair_in_set.id],
        user_performing_action=admin_user,
    )
    DocumentSetManager.wait_for_sync(
        user_performing_action=admin_user,
        document_sets_to_check=[doc_set],
    )

    headers = _auth_headers(admin_user, name="mcp-doc-set-filter")

    # The document-sets resource is how an MCP client learns which names it
    # may pass as document_set_names, so the new set must be listed there.
    async def _list_resources(session: ClientSession) -> Any:
        await session.initialize()
        listed = await session.list_resources()
        doc_sets_body = await session.read_resource(
            AnyUrl(DOCUMENT_SETS_RESOURCE_URI)
        )
        return listed, doc_sets_body

    resources_result, doc_sets_contents = _run_with_mcp_session(
        headers, _list_resources
    )
    advertised_uris = {str(resource.uri) for resource in resources_result.resources}
    assert DOCUMENT_SETS_RESOURCE_URI in advertised_uris

    doc_sets_payload = json.loads(doc_sets_contents.contents[0].text)
    assert doc_set.name in {entry["name"] for entry in doc_sets_payload}

    def _contents_of(payload: dict[str, Any]) -> list[str]:
        # Missing or None content is normalised to an empty string.
        return [doc.get("content") or "" for doc in payload["documents"]]

    def _mentions(needle: str, haystacks: list[str]) -> bool:
        return any(needle in haystack for haystack in haystacks)

    # No filter: both documents come back.
    unfiltered_payload = _extract_tool_payload(
        _call_search_tool(headers, shared_phrase, limit=10)
    )
    unfiltered_contents = _contents_of(unfiltered_payload)
    assert _mentions(in_set_content, unfiltered_contents)
    assert _mentions(out_of_set_content, unfiltered_contents)

    # Filter by the document set: only the in-set document comes back.
    filtered_payload = _extract_tool_payload(
        _call_search_tool(
            headers,
            shared_phrase,
            limit=10,
            document_set_names=[doc_set.name],
        )
    )
    filtered_contents = _contents_of(filtered_payload)
    assert filtered_payload["total_results"] >= 1
    assert _mentions(in_set_content, filtered_contents)
    assert not _mentions(out_of_set_content, filtered_contents)

    # Empty filter list: treated as "no filter" rather than "match no sets".
    empty_list_payload = _extract_tool_payload(
        _call_search_tool(
            headers,
            shared_phrase,
            limit=10,
            document_set_names=[],
        )
    )
    empty_list_contents = _contents_of(empty_list_payload)
    assert _mentions(in_set_content, empty_list_contents)
    assert _mentions(out_of_set_content, empty_list_contents)

View File

@@ -8,10 +8,8 @@ import io
from typing import NamedTuple
import pytest
import requests
from onyx.file_store.models import FileDescriptor
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.file import FileManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
@@ -121,31 +119,3 @@ def test_public_assistant_with_user_files(
assert (
len(chat_history) >= 2
), "Expected at least 2 messages (user message and assistant response)"
def test_cannot_download_other_users_file_via_chat_file_endpoint(
    user_file_setup: UserFileTestSetup,
) -> None:
    """The owner can download their file via `/chat/file/{id}`; a different
    user is refused (403 or 404) for both the storage id and the user-file id,
    and never receives the owner's bytes."""
    storage_file_id = user_file_setup.user1_file_descriptor["id"]
    user_file_id = user_file_setup.user1_file_id

    # Baseline: the owner gets the file.
    owner_url = f"{API_SERVER_URL}/chat/file/{storage_file_id}"
    owner_response = requests.get(
        owner_url, headers=user_file_setup.user1_file_owner.headers
    )
    assert owner_response.status_code == 200
    assert owner_response.content, "Owner should receive the file contents"

    # The non-owner tries both identifiers that refer to the same file.
    candidate_ids = (storage_file_id, user_file_id)
    for file_id in candidate_ids:
        user2_response = requests.get(
            f"{API_SERVER_URL}/chat/file/{file_id}",
            headers=user_file_setup.user2_non_owner.headers,
        )
        denied_statuses = (403, 404)
        assert user2_response.status_code in denied_statuses, (
            f"Expected access denied for non-owner, got {user2_response.status_code} "
            f"when fetching file_id={file_id}"
        )
        assert user2_response.content != owner_response.content

View File

@@ -1,78 +0,0 @@
"""Unit tests for the require-score check in verify_captcha_token."""
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.auth import captcha as captcha_module
from onyx.auth.captcha import CaptchaVerificationError
from onyx.auth.captcha import verify_captcha_token
def _fake_httpx_client_returning(payload: dict) -> MagicMock:
resp = MagicMock()
resp.raise_for_status = MagicMock()
resp.json = MagicMock(return_value=payload)
client = MagicMock()
client.post = AsyncMock(return_value=resp)
client.__aenter__ = AsyncMock(return_value=client)
client.__aexit__ = AsyncMock(return_value=None)
return client
@pytest.mark.asyncio
async def test_rejects_when_score_missing() -> None:
    """A siteverify response that has no `score` field must be rejected with
    a "missing score" error, closing the accidental 'test secret in prod'
    bypass path."""
    scoreless_payload = {"success": True, "hostname": "testkey.google.com"}
    fake_client = _fake_httpx_client_returning(scoreless_payload)

    with patch.object(captcha_module, "is_captcha_enabled", return_value=True):
        with patch.object(
            captcha_module.httpx, "AsyncClient", return_value=fake_client
        ):
            with pytest.raises(CaptchaVerificationError, match="missing score"):
                await verify_captcha_token("test-token", expected_action="signup")
@pytest.mark.asyncio
async def test_accepts_when_score_present_and_above_threshold() -> None:
    """Happy path: a successful response with a 0.9 score and the matching
    action passes verification without raising."""
    passing_payload = {
        "success": True,
        "score": 0.9,
        "action": "signup",
        "hostname": "cloud.onyx.app",
    }
    fake_client = _fake_httpx_client_returning(passing_payload)

    with patch.object(captcha_module, "is_captcha_enabled", return_value=True):
        with patch.object(
            captcha_module.httpx, "AsyncClient", return_value=fake_client
        ):
            # Completing without an exception is the assertion.
            await verify_captcha_token("fresh-token", expected_action="signup")
@pytest.mark.asyncio
async def test_rejects_when_score_below_threshold() -> None:
    """A response whose score is present but low (0.1) is rejected with the
    "suspicious activity detected" error; guards the pre-existing behaviour
    against regression."""
    low_score_payload = {
        "success": True,
        "score": 0.1,
        "action": "signup",
        "hostname": "cloud.onyx.app",
    }
    fake_client = _fake_httpx_client_returning(low_score_payload)

    with patch.object(captcha_module, "is_captcha_enabled", return_value=True):
        with patch.object(
            captcha_module.httpx, "AsyncClient", return_value=fake_client
        ):
            with pytest.raises(
                CaptchaVerificationError, match="suspicious activity detected"
            ):
                await verify_captcha_token(
                    "low-score-token", expected_action="signup"
                )

View File

@@ -69,9 +69,6 @@ class TestGongConnectorCheckpoint:
workspace_ids=["ws1", None],
workspace_index=1,
cursor="abc123",
pending_transcripts={"call1": _make_transcript("call1")},
pending_call_details_attempts=2,
pending_retry_after=1234567890.5,
)
json_str = original.model_dump_json()
restored = connector.validate_checkpoint_json(json_str)
@@ -234,11 +231,7 @@ class TestGongConnectorCheckpoint:
mock_request: MagicMock,
connector: GongConnector,
) -> None:
"""Missing call details persist across checkpoint invocations and
eventually yield ConnectorFailure once MAX_CALL_DETAILS_ATTEMPTS is hit.
No in-call sleep — retries happen on subsequent invocations, gated by
the wall-clock retry-after deadline on the checkpoint.
"""
"""When call details are missing after retries, yield ConnectorFailure."""
transcript_response = MagicMock()
transcript_response.status_code = 200
transcript_response.json.return_value = {
@@ -264,42 +257,23 @@ class TestGongConnectorCheckpoint:
failures: list[ConnectorFailure] = []
docs: list[Document] = []
# Jump the clock past any retry deadline on each invocation so we
# exercise the retry path without real sleeping. The test for the
# backoff-gate itself lives in test_backoff_gate_prevents_retry_too_soon.
fake_now = [1_000_000.0]
def _advance_clock() -> float:
fake_now[0] += 10_000.0
return fake_now[0]
invocation_cap = GongConnector.MAX_CALL_DETAILS_ATTEMPTS + 5
with patch(
"onyx.connectors.gong.connector.time.time", side_effect=_advance_clock
):
for _ in range(invocation_cap):
if not checkpoint.has_more:
break
generator = connector.load_from_checkpoint(0, fake_now[0], checkpoint)
try:
while True:
item = next(generator)
if isinstance(item, ConnectorFailure):
failures.append(item)
elif isinstance(item, Document):
docs.append(item)
except StopIteration as e:
checkpoint = e.value
with patch("onyx.connectors.gong.connector.time.sleep"):
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
try:
while True:
item = next(generator)
if isinstance(item, ConnectorFailure):
failures.append(item)
elif isinstance(item, Document):
docs.append(item)
except StopIteration as e:
checkpoint = e.value
assert len(docs) == 0
assert len(failures) == 1
assert failures[0].failed_document is not None
assert failures[0].failed_document.document_id == "call1"
assert checkpoint.has_more is False
assert checkpoint.pending_transcripts == {}
assert checkpoint.pending_call_details_attempts == 0
assert checkpoint.pending_retry_after is None
assert mock_request.call_count == 1 + GongConnector.MAX_CALL_DETAILS_ATTEMPTS
@patch.object(GongConnector, "_throttled_request")
def test_multi_workspace_iteration(
@@ -407,14 +381,12 @@ class TestGongConnectorCheckpoint:
assert checkpoint.workspace_index == 1
@patch.object(GongConnector, "_throttled_request")
def test_partial_details_defers_and_resolves_next_invocation(
def test_retry_only_fetches_missing_ids(
self,
mock_request: MagicMock,
connector: GongConnector,
) -> None:
"""A transcript whose call details are missing gets stashed into
pending_transcripts and resolves on a later checkpoint invocation.
Resolved docs are yielded in the order they become available."""
"""Retry for missing call details should only re-request the missing IDs."""
transcript_response = MagicMock()
transcript_response.status_code = 200
transcript_response.json.return_value = {
@@ -432,7 +404,7 @@ class TestGongConnectorCheckpoint:
"calls": [_make_call_detail("call1", "Call One")]
}
# Second fetch (next invocation): returns call2
# Second fetch (retry): returns call2
missing_details = MagicMock()
missing_details.status_code = 200
missing_details.json.return_value = {
@@ -452,48 +424,19 @@ class TestGongConnectorCheckpoint:
)
docs: list[Document] = []
fake_now = [1_000_000.0]
def _advance_clock() -> float:
fake_now[0] += 10_000.0
return fake_now[0]
with patch(
"onyx.connectors.gong.connector.time.time", side_effect=_advance_clock
):
# Invocation 1: fetches page + details, yields call1, stashes call2
generator = connector.load_from_checkpoint(0, fake_now[0], checkpoint)
with patch("onyx.connectors.gong.connector.time.sleep"):
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
try:
while True:
item = next(generator)
if isinstance(item, Document):
docs.append(item)
except StopIteration as e:
checkpoint = e.value
assert len(docs) == 1
assert docs[0].semantic_identifier == "Call One"
assert "call2" in checkpoint.pending_transcripts
assert checkpoint.pending_call_details_attempts == 1
assert checkpoint.pending_retry_after is not None
assert checkpoint.has_more is True
# Invocation 2: retries missing (only call2), yields it, clears pending
generator = connector.load_from_checkpoint(0, fake_now[0], checkpoint)
try:
while True:
item = next(generator)
if isinstance(item, Document):
docs.append(item)
except StopIteration as e:
checkpoint = e.value
except StopIteration:
pass
assert len(docs) == 2
assert docs[0].semantic_identifier == "Call One"
assert docs[1].semantic_identifier == "Call Two"
assert checkpoint.pending_transcripts == {}
assert checkpoint.pending_call_details_attempts == 0
assert checkpoint.pending_retry_after is None
# Verify: 3 API calls total (1 transcript + 1 full details + 1 retry for missing only)
assert mock_request.call_count == 3
@@ -501,107 +444,6 @@ class TestGongConnectorCheckpoint:
retry_call_body = mock_request.call_args_list[2][1]["json"]
assert retry_call_body["filter"]["callIds"] == ["call2"]
@patch.object(GongConnector, "_throttled_request")
def test_backoff_gate_prevents_retry_too_soon(
    self,
    mock_request: MagicMock,
    connector: GongConnector,
) -> None:
    """While the checkpoint's retry-after deadline is still in the future, a
    checkpoint invocation must issue no request and must leave the pending
    state untouched, so tight re-invocation cannot burn through
    MAX_CALL_DETAILS_ATTEMPTS.
    """
    fixed_now = 1_000_000.0
    # The stored deadline sits 30 seconds after the frozen clock value.
    retry_after = fixed_now + 30

    checkpoint = GongConnectorCheckpoint(
        has_more=True,
        workspace_ids=[None],
        workspace_index=0,
        pending_transcripts={"call1": _make_transcript("call1")},
        pending_call_details_attempts=1,
        pending_retry_after=retry_after,
    )

    with patch("onyx.connectors.gong.connector.time.time", return_value=fixed_now):
        checkpoint_gen = connector.load_from_checkpoint(0, fixed_now, checkpoint)
        # Drain the generator; its return value is the updated checkpoint.
        while True:
            try:
                next(checkpoint_gen)
            except StopIteration as stop:
                checkpoint = stop.value
                break

    # Inside the backoff window: no API traffic at all.
    mock_request.assert_not_called()
    # Pending state is carried forward unchanged for a later retry.
    assert "call1" in checkpoint.pending_transcripts
    assert checkpoint.pending_call_details_attempts == 1
    assert checkpoint.pending_retry_after == retry_after
    assert checkpoint.has_more is True
@patch.object(GongConnector, "_throttled_request")
def test_pending_retry_does_not_block_on_time_sleep(
    self,
    mock_request: MagicMock,
    connector: GongConnector,
) -> None:
    """Retrying pending transcripts must never sleep for longer than
    MIN_REQUEST_INTERVAL: spacing between retries comes from the wall-clock
    deadline on the checkpoint, not from blocking in load_from_checkpoint.
    """
    transcript_page = MagicMock()
    transcript_page.status_code = 200
    transcript_page.json.return_value = {
        "callTranscripts": [_make_transcript("call1")],
        "records": {},
    }

    no_details = MagicMock()
    no_details.status_code = 200
    no_details.json.return_value = {"calls": []}

    # One transcript page, then an empty details response per allowed attempt.
    mock_request.side_effect = [
        transcript_page,
        *([no_details] * GongConnector.MAX_CALL_DETAILS_ATTEMPTS),
    ]

    checkpoint = GongConnectorCheckpoint(
        has_more=True,
        workspace_ids=[None],
        workspace_index=0,
    )

    clock = 1_000_000.0

    def _advance_clock() -> float:
        # Every read of the patched time.time() jumps the clock forward.
        nonlocal clock
        clock += 10_000.0
        return clock

    sleep_patch = patch("onyx.connectors.gong.connector.time.sleep")
    time_patch = patch(
        "onyx.connectors.gong.connector.time.time", side_effect=_advance_clock
    )
    with sleep_patch as mock_sleep:
        with time_patch:
            invocation_cap = GongConnector.MAX_CALL_DETAILS_ATTEMPTS + 5
            for _ in range(invocation_cap):
                if not checkpoint.has_more:
                    break
                checkpoint_gen = connector.load_from_checkpoint(0, clock, checkpoint)
                # Drain the generator; its return value is the next checkpoint.
                while True:
                    try:
                        next(checkpoint_gen)
                    except StopIteration as stop:
                        checkpoint = stop.value
                        break

    # Any sleep that did happen must be no longer than the request throttle
    # interval, i.e. nowhere near a per-retry backoff delay.
    for sleep_call in mock_sleep.call_args_list:
        delay_arg = sleep_call.args[0] if sleep_call.args else 0
        assert delay_arg <= GongConnector.MIN_REQUEST_INTERVAL
@patch.object(GongConnector, "_throttled_request")
def test_expired_cursor_restarts_workspace(
self,

View File

@@ -287,140 +287,3 @@ class TestFailedFolderIdsByEmail:
)
assert len(failed_map) == 0
class TestOrphanedPathBackfill:
    """Tests for how `_get_new_ancestors_for_files` records folder ids in
    `failed_folder_ids_by_email` when an ancestor walk ends at a folder that
    is already marked as failed."""

    def _make_failed_map(
        self, entries: dict[str, set[str]]
    ) -> ThreadSafeDict[str, ThreadSafeSet[str]]:
        """Wrap a plain `{email: {folder_id, ...}}` dict in the thread-safe
        containers the connector expects."""
        wrapped = {}
        for email, folder_ids in entries.items():
            wrapped[email] = ThreadSafeSet(folder_ids)
        return ThreadSafeDict(wrapped)

    def _make_file(self, parent_id: str) -> MagicMock:
        """Return a mock retrieved file owned by the retriever account whose
        only parent is `parent_id`."""
        retrieved = MagicMock()
        retrieved.configure_mock(
            user_email="retriever@example.com",
            drive_file={"parents": [parent_id]},
        )
        return retrieved

    def test_backfills_intermediate_folders_into_failed_map(self) -> None:
        """A walk that dead-ends at a confirmed orphan must add every
        intermediate folder id to the failed map for both emails, so later
        files can short-circuit on the cached ids."""
        connector = _make_connector()

        # Chain: folderA -> folderB -> folderC, with folderC already failed
        # for both accounts.
        failed_map = self._make_failed_map(
            {
                "retriever@example.com": {"folderC"},
                "admin@example.com": {"folderC"},
            }
        )
        folders_by_id = {
            "folderA": {"id": "folderA", "name": "A", "parents": ["folderB"]},
            "folderB": {"id": "folderB", "name": "B", "parents": ["folderC"]},
        }

        def mock_get_folder(
            _service: MagicMock, folder_id: str, _field_type: DriveFileFieldType
        ) -> dict | None:
            # Unknown ids (including folderC) resolve to None.
            return folders_by_id.get(folder_id)

        with patch(
            "onyx.connectors.google_drive.connector.get_drive_service",
            return_value=MagicMock(),
        ):
            with patch(
                "onyx.connectors.google_drive.connector.get_folder_metadata",
                side_effect=mock_get_folder,
            ):
                connector._get_new_ancestors_for_files(
                    files=[self._make_file("folderA")],
                    seen_hierarchy_node_raw_ids=ThreadSafeSet(),
                    fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
                    failed_folder_ids_by_email=failed_map,
                )

        # Both emails had folderC recorded, so both receive the whole path.
        for email in ("retriever@example.com", "admin@example.com"):
            cached = failed_map.get(email, ThreadSafeSet())
            for expected_id in ("folderA", "folderB", "folderC"):
                assert expected_id in cached

    def test_backfills_only_for_confirming_email(self) -> None:
        """Only the email that had the orphan recorded gets the path
        backfilled; an email with no entry stays absent from the map."""
        connector = _make_connector()

        # Only the retriever has folderC recorded; admin has no entry at all.
        failed_map = self._make_failed_map({"retriever@example.com": {"folderC"}})
        folders_by_id = {
            "folderA": {"id": "folderA", "name": "A", "parents": ["folderB"]},
            "folderB": {"id": "folderB", "name": "B", "parents": ["folderC"]},
        }

        def mock_get_folder(
            _service: MagicMock, folder_id: str, _field_type: DriveFileFieldType
        ) -> dict | None:
            # Unknown ids (including folderC) resolve to None.
            return folders_by_id.get(folder_id)

        with patch(
            "onyx.connectors.google_drive.connector.get_drive_service",
            return_value=MagicMock(),
        ):
            with patch(
                "onyx.connectors.google_drive.connector.get_folder_metadata",
                side_effect=mock_get_folder,
            ):
                connector._get_new_ancestors_for_files(
                    files=[self._make_file("folderA")],
                    seen_hierarchy_node_raw_ids=ThreadSafeSet(),
                    fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
                    failed_folder_ids_by_email=failed_map,
                )

        retriever_cached = failed_map.get("retriever@example.com", ThreadSafeSet())
        for expected_id in ("folderA", "folderB"):
            assert expected_id in retriever_cached
        # admin never had the orphan recorded, so it must not gain an entry.
        assert failed_map.get("admin@example.com") is None

    def test_short_circuits_on_backfilled_intermediate(self) -> None:
        """When a file's parent is already in the failed map, no folder
        metadata lookup may be made."""
        connector = _make_connector()

        # folderA is already recorded as failed for both accounts.
        failed_map = self._make_failed_map(
            {
                "retriever@example.com": {"folderA"},
                "admin@example.com": {"folderA"},
            }
        )

        with patch(
            "onyx.connectors.google_drive.connector.get_drive_service",
            return_value=MagicMock(),
        ):
            with patch(
                "onyx.connectors.google_drive.connector.get_folder_metadata"
            ) as mock_api:
                connector._get_new_ancestors_for_files(
                    files=[self._make_file("folderA")],
                    seen_hierarchy_node_raw_ids=ThreadSafeSet(),
                    fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
                    failed_folder_ids_by_email=failed_map,
                )

        mock_api.assert_not_called()

View File

@@ -1,101 +0,0 @@
"""Unit tests for SharepointConnector.load_credentials sp_tenant_domain resolution."""
from __future__ import annotations
import base64
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.sharepoint.connector import SharepointConnector
# Site URL handed to every connector built in this module.
SITE_URL = "https://mytenant.sharepoint.com/sites/MySite"
# Value the tests expect ``sp_tenant_domain`` to hold after ``load_credentials``
# (it matches the first host label of SITE_URL).
EXPECTED_TENANT_DOMAIN = "mytenant"
# Credential payload for the ``client_secret`` authentication method; all values are fakes.
CLIENT_SECRET_CREDS = {
    "authentication_method": "client_secret",
    "sp_client_id": "fake-client-id",
    "sp_client_secret": "fake-client-secret",
    "sp_directory_id": "fake-directory-id",
}
# Credential payload for the ``certificate`` authentication method; the private
# key is base64-encoded placeholder bytes, not a real PFX.
CERTIFICATE_CREDS = {
    "authentication_method": "certificate",
    "sp_client_id": "fake-client-id",
    "sp_directory_id": "fake-directory-id",
    "sp_private_key": base64.b64encode(b"fake-pfx-data").decode(),
    "sp_certificate_password": "fake-password",
}
def _make_mock_msal() -> MagicMock:
mock_app = MagicMock()
mock_app.acquire_token_for_client.return_value = {"access_token": "fake-token"}
return mock_app
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
@patch("onyx.connectors.sharepoint.connector.GraphClient")
def test_client_secret_with_site_pages_sets_tenant_domain(
    _mock_graph_client: MagicMock,
    mock_msal_cls: MagicMock,
) -> None:
    """With client_secret auth and include_site_pages=True, load_credentials
    must leave sp_tenant_domain set to the expected tenant."""
    # Any MSAL application the connector builds returns a canned token.
    mock_msal_cls.return_value = _make_mock_msal()

    sharepoint = SharepointConnector(sites=[SITE_URL], include_site_pages=True)
    sharepoint.load_credentials(CLIENT_SECRET_CREDS)

    resolved_domain = sharepoint.sp_tenant_domain
    assert resolved_domain == EXPECTED_TENANT_DOMAIN
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
@patch("onyx.connectors.sharepoint.connector.GraphClient")
def test_client_secret_without_site_pages_still_sets_tenant_domain(
    _mock_graph_client: MagicMock,
    mock_msal_cls: MagicMock,
) -> None:
    """With client_secret auth and include_site_pages=False, sp_tenant_domain
    must still be resolved, because _create_rest_client_context is also
    called for drive items."""
    # Any MSAL application the connector builds returns a canned token.
    mock_msal_cls.return_value = _make_mock_msal()

    sharepoint = SharepointConnector(sites=[SITE_URL], include_site_pages=False)
    sharepoint.load_credentials(CLIENT_SECRET_CREDS)

    resolved_domain = sharepoint.sp_tenant_domain
    assert resolved_domain == EXPECTED_TENANT_DOMAIN
@patch("onyx.connectors.sharepoint.connector.load_certificate_from_pfx")
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
@patch("onyx.connectors.sharepoint.connector.GraphClient")
def test_certificate_with_site_pages_sets_tenant_domain(
    _mock_graph_client: MagicMock,
    mock_msal_cls: MagicMock,
    mock_load_cert: MagicMock,
) -> None:
    """With certificate auth and include_site_pages=True, load_credentials
    must leave sp_tenant_domain set to the expected tenant."""
    # The PFX loader is stubbed, so the placeholder key bytes are never parsed.
    mock_load_cert.return_value = MagicMock()
    mock_msal_cls.return_value = _make_mock_msal()

    sharepoint = SharepointConnector(sites=[SITE_URL], include_site_pages=True)
    sharepoint.load_credentials(CERTIFICATE_CREDS)

    resolved_domain = sharepoint.sp_tenant_domain
    assert resolved_domain == EXPECTED_TENANT_DOMAIN
@patch("onyx.connectors.sharepoint.connector.load_certificate_from_pfx")
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
@patch("onyx.connectors.sharepoint.connector.GraphClient")
def test_certificate_without_site_pages_sets_tenant_domain(
    _mock_graph_client: MagicMock,
    mock_msal_cls: MagicMock,
    mock_load_cert: MagicMock,
) -> None:
    """With certificate auth and include_site_pages=False, sp_tenant_domain
    must still be resolved, because _create_rest_client_context is also
    called for drive items."""
    # The PFX loader is stubbed, so the placeholder key bytes are never parsed.
    mock_load_cert.return_value = MagicMock()
    mock_msal_cls.return_value = _make_mock_msal()

    sharepoint = SharepointConnector(sites=[SITE_URL], include_site_pages=False)
    sharepoint.load_credentials(CERTIFICATE_CREDS)

    resolved_domain = sharepoint.sp_tenant_domain
    assert resolved_domain == EXPECTED_TENANT_DOMAIN

View File

@@ -1,194 +0,0 @@
"""Unit tests for SharepointConnector site-page slim resilience and
validate_connector_settings RoleAssignments permission probe."""
from __future__ import annotations
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.sharepoint.connector import SharepointConnector
SITE_URL = "https://tenant.sharepoint.com/sites/MySite"


def _make_connector() -> SharepointConnector:
    """Return a SharepointConnector for SITE_URL with its MSAL app, tenant
    domain, credential JSON and Graph client pre-populated with fakes."""
    sharepoint = SharepointConnector(sites=[SITE_URL])
    # Assigned in a fixed order; each attribute is set directly on the instance.
    fake_state = {
        "msal_app": MagicMock(),
        "sp_tenant_domain": "tenant",
        "_credential_json": {"sp_client_id": "x", "sp_directory_id": "y"},
        "_graph_client": MagicMock(),
    }
    for attr_name, attr_value in fake_state.items():
        setattr(sharepoint, attr_name, attr_value)
    return sharepoint
# ---------------------------------------------------------------------------
# _fetch_slim_documents_from_sharepoint — site page error resilience
# ---------------------------------------------------------------------------
@patch("onyx.connectors.sharepoint.connector._convert_sitepage_to_slim_document")
@patch(
    "onyx.connectors.sharepoint.connector.SharepointConnector._create_rest_client_context"
)
@patch("onyx.connectors.sharepoint.connector.SharepointConnector._fetch_site_pages")
@patch("onyx.connectors.sharepoint.connector.SharepointConnector._fetch_driveitems")
@patch("onyx.connectors.sharepoint.connector.SharepointConnector.fetch_sites")
def test_site_page_error_does_not_crash(
    mock_fetch_sites: MagicMock,
    mock_fetch_driveitems: MagicMock,
    mock_fetch_site_pages: MagicMock,
    _mock_create_ctx: MagicMock,
    mock_convert: MagicMock,
) -> None:
    """An exception (here a 401) while converting one site page is swallowed
    and the remaining pages are still processed."""
    from onyx.connectors.models import SlimDocument

    connector = _make_connector()
    connector.include_site_documents = False
    connector.include_site_pages = True

    mock_fetch_sites.return_value = [MagicMock(url=SITE_URL)]
    mock_fetch_driveitems.return_value = iter([])

    # The failing page comes first so the good page is only reached if the
    # failure is handled.
    page_bad = {"id": "2", "webUrl": SITE_URL + "/SitePages/Bad.aspx"}
    page_ok = {"id": "1", "webUrl": SITE_URL + "/SitePages/Good.aspx"}
    mock_fetch_site_pages.return_value = [page_bad, page_ok]

    good_slim = SlimDocument(id="1")

    def _convert_side_effect(page: dict, *_args: object, **_kwargs: object) -> SlimDocument:  # noqa: ANN001
        if page["id"] != "2":
            return good_slim
        from office365.runtime.client_request import ClientRequestException

        raise ClientRequestException(MagicMock(status_code=401), None)

    mock_convert.side_effect = _convert_side_effect

    slim_docs = []
    for batch in connector._fetch_slim_documents_from_sharepoint():
        slim_docs.extend(doc for doc in batch if isinstance(doc, SlimDocument))

    # Only the good page makes it through; the bad page is skipped without raising.
    seen_ids = [doc.id for doc in slim_docs]
    assert "1" in seen_ids
    assert "2" not in seen_ids
@patch("onyx.connectors.sharepoint.connector._convert_sitepage_to_slim_document")
@patch(
    "onyx.connectors.sharepoint.connector.SharepointConnector._create_rest_client_context"
)
@patch("onyx.connectors.sharepoint.connector.SharepointConnector._fetch_site_pages")
@patch("onyx.connectors.sharepoint.connector.SharepointConnector._fetch_driveitems")
@patch("onyx.connectors.sharepoint.connector.SharepointConnector.fetch_sites")
def test_all_site_pages_fail_does_not_crash(
    mock_fetch_sites: MagicMock,
    mock_fetch_driveitems: MagicMock,
    mock_fetch_site_pages: MagicMock,
    _mock_create_ctx: MagicMock,
    mock_convert: MagicMock,
) -> None:
    """If conversion fails for every site page, the slim-document generator
    still runs to completion and yields no SlimDocuments."""
    connector = _make_connector()
    connector.include_site_documents = False
    connector.include_site_pages = True

    mock_fetch_sites.return_value = [MagicMock(url=SITE_URL)]
    mock_fetch_driveitems.return_value = iter([])
    mock_fetch_site_pages.return_value = [
        {"id": page_id, "webUrl": SITE_URL + "/SitePages/" + file_name}
        for page_id, file_name in (("1", "A.aspx"), ("2", "B.aspx"))
    ]
    # Every conversion attempt raises.
    mock_convert.side_effect = RuntimeError("context error")

    from onyx.connectors.models import SlimDocument

    # Must not raise; any non-SlimDocument items (e.g. hierarchy nodes) are filtered out.
    slim_results = []
    for batch in connector._fetch_slim_documents_from_sharepoint():
        slim_results.extend(doc for doc in batch if isinstance(doc, SlimDocument))

    assert slim_results == []
# ---------------------------------------------------------------------------
# validate_connector_settings — RoleAssignments permission probe
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("status_code", [401, 403])
@patch("onyx.connectors.sharepoint.connector.requests.get")
@patch("onyx.connectors.sharepoint.connector.validate_outbound_http_url")
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
def test_validate_raises_on_401_or_403(
    mock_acquire: MagicMock,
    _mock_validate_url: MagicMock,
    mock_get: MagicMock,
    status_code: int,
) -> None:
    """A 401 or 403 from the probe request makes validate_connector_settings
    raise ConnectorValidationError mentioning Sites.FullControl.All."""
    mock_acquire.return_value = MagicMock(accessToken="tok")
    probe_response = MagicMock(status_code=status_code)
    mock_get.return_value = probe_response

    sharepoint = _make_connector()
    with pytest.raises(ConnectorValidationError, match="Sites.FullControl.All"):
        sharepoint.validate_connector_settings()
@patch("onyx.connectors.sharepoint.connector.requests.get")
@patch("onyx.connectors.sharepoint.connector.validate_outbound_http_url")
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
def test_validate_passes_on_200(
    mock_acquire: MagicMock,
    _mock_validate_url: MagicMock,
    mock_get: MagicMock,
) -> None:
    """A 200 from the probe request lets validate_connector_settings return
    without raising."""
    mock_acquire.return_value = MagicMock(accessToken="tok")
    probe_response = MagicMock(status_code=200)
    mock_get.return_value = probe_response

    sharepoint = _make_connector()
    # Passing means this call completes without an exception.
    sharepoint.validate_connector_settings()
@patch("onyx.connectors.sharepoint.connector.requests.get")
@patch("onyx.connectors.sharepoint.connector.validate_outbound_http_url")
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
def test_validate_passes_on_network_error(
    mock_acquire: MagicMock,
    _mock_validate_url: MagicMock,
    mock_get: MagicMock,
) -> None:
    """An exception raised by the probe request itself is non-blocking:
    validate_connector_settings must still return normally."""
    mock_acquire.return_value = MagicMock(accessToken="tok")
    # The HTTP call raises instead of returning a response.
    mock_get.side_effect = Exception("timeout")

    sharepoint = _make_connector()
    # Passing means this call completes without an exception.
    sharepoint.validate_connector_settings()
@patch("onyx.connectors.sharepoint.connector.validate_outbound_http_url")
@patch("onyx.connectors.sharepoint.connector.acquire_token_for_rest")
def test_validate_skips_probe_without_credentials(
    mock_acquire: MagicMock,
    _mock_validate_url: MagicMock,
) -> None:
    """With no credentials loaded, validation neither raises nor tries to
    acquire a REST token."""
    # Built directly (not via _make_connector) so load_credentials was never
    # called and no fake msal_app / sp_tenant_domain has been assigned.
    bare_connector = SharepointConnector(sites=[SITE_URL])

    bare_connector.validate_connector_settings()

    mock_acquire.assert_not_called()

View File

@@ -1,243 +0,0 @@
"""Unit tests for WebConnector.retrieve_all_slim_docs (slim pruning path)."""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.models import SlimDocument
from onyx.connectors.web.connector import WEB_CONNECTOR_VALID_SETTINGS
from onyx.connectors.web.connector import WebConnector
# Root of the fake site; tests append "/" or a page path to build full URLs.
BASE_URL = "http://example.com"
# Body served for the single-page tests.
SINGLE_PAGE_HTML = (
    "<html><body><p>Content that should not appear in slim output</p></body></html>"
)
# Root page for the recursive test: it links to /page2 and /page3 with
# relative hrefs.
RECURSIVE_ROOT_HTML = """
<html><body>
<a href="/page2">Page 2</a>
<a href="/page3">Page 3</a>
</body></html>
"""
# Leaf pages reached from the recursive root; neither contains links.
PAGE2_HTML = "<html><body><p>page 2</p></body></html>"
PAGE3_HTML = "<html><body><p>page 3</p></body></html>"
def _make_playwright_context_mock(url_to_html: dict[str, str]) -> MagicMock:
"""Return a BrowserContext mock whose pages respond based on goto URL."""
context = MagicMock()
def _new_page() -> MagicMock:
page = MagicMock()
visited: list[str] = []
def _goto(url: str, **kwargs: Any) -> MagicMock: # noqa: ARG001
visited.append(url)
page.url = url
response = MagicMock()
response.status = 200
response.header_value.return_value = None # no cf-ray
return response
def _content() -> str:
return url_to_html.get(
visited[-1] if visited else "", "<html><body></body></html>"
)
page.goto.side_effect = _goto
page.content.side_effect = _content
return page
context.new_page.side_effect = _new_page
return context
def _make_playwright_mock() -> MagicMock:
playwright = MagicMock()
playwright.stop = MagicMock()
return playwright
def _make_page_mock(
    html: str, cf_ray: str | None = None, status: int = 200
) -> MagicMock:
    """Return a Playwright page mock with a configurable response status and
    ``cf-ray`` header.

    ``goto`` always returns the same response object. Its ``header_value``
    yields ``cf_ray`` for the "cf-ray" header and None for any other name.
    ``content`` always returns ``html``.
    """

    def _header_value(header_name: str) -> str | None:
        if header_name == "cf-ray":
            return cf_ray
        return None

    response = MagicMock()
    response.status = status
    response.header_value.side_effect = _header_value

    page = MagicMock()
    page.url = f"{BASE_URL}/"
    page.goto.return_value = response
    page.content.return_value = html
    return page
@patch("onyx.connectors.web.connector.check_internet_connection")
@patch("onyx.connectors.web.connector.requests.head")
@patch("onyx.connectors.web.connector.start_playwright")
def test_slim_yields_slim_documents(
    mock_start_playwright: MagicMock,
    mock_head: MagicMock,
    _mock_check: MagicMock,
) -> None:
    """In SINGLE mode retrieve_all_slim_docs yields exactly one SlimDocument
    whose id is the page URL."""
    root_url = BASE_URL + "/"
    browser_context = _make_playwright_context_mock({root_url: SINGLE_PAGE_HTML})
    mock_start_playwright.return_value = (_make_playwright_mock(), browser_context)
    mock_head.return_value.headers = {"content-type": "text/html"}

    crawler = WebConnector(
        base_url=root_url,
        web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
    )

    docs = []
    for batch in crawler.retrieve_all_slim_docs():
        docs.extend(batch)

    assert len(docs) == 1
    only_doc = docs[0]
    assert isinstance(only_doc, SlimDocument)
    assert only_doc.id == BASE_URL + "/"
@patch("onyx.connectors.web.connector.check_internet_connection")
@patch("onyx.connectors.web.connector.requests.head")
@patch("onyx.connectors.web.connector.start_playwright")
def test_slim_skips_content_extraction(
    mock_start_playwright: MagicMock,
    mock_head: MagicMock,
    _mock_check: MagicMock,
) -> None:
    """Slim retrieval must never call web_html_cleanup."""
    root_url = BASE_URL + "/"
    browser_context = _make_playwright_context_mock({root_url: SINGLE_PAGE_HTML})
    mock_start_playwright.return_value = (_make_playwright_mock(), browser_context)
    mock_head.return_value.headers = {"content-type": "text/html"}

    crawler = WebConnector(
        base_url=root_url,
        web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
    )

    cleanup_patch = patch("onyx.connectors.web.connector.web_html_cleanup")
    with cleanup_patch as mock_cleanup:
        # Drain the generator so the whole slim crawl actually runs.
        list(crawler.retrieve_all_slim_docs())

    mock_cleanup.assert_not_called()
@patch("onyx.connectors.web.connector.check_internet_connection")
@patch("onyx.connectors.web.connector.requests.head")
@patch("onyx.connectors.web.connector.start_playwright")
def test_slim_discovers_links_recursively(
    mock_start_playwright: MagicMock,
    mock_head: MagicMock,
    _mock_check: MagicMock,
) -> None:
    """In RECURSIVE mode the internal <a href> links on the root page are
    followed, and every visited URL is yielded as a SlimDocument id."""
    root_url = BASE_URL + "/"
    page2_url = BASE_URL + "/page2"
    page3_url = BASE_URL + "/page3"
    site_pages = {
        root_url: RECURSIVE_ROOT_HTML,
        page2_url: PAGE2_HTML,
        page3_url: PAGE3_HTML,
    }
    browser_context = _make_playwright_context_mock(site_pages)
    mock_start_playwright.return_value = (_make_playwright_mock(), browser_context)
    mock_head.return_value.headers = {"content-type": "text/html"}

    crawler = WebConnector(
        base_url=root_url,
        web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value,
    )

    found_ids = set()
    for batch in crawler.retrieve_all_slim_docs():
        for doc in batch:
            if isinstance(doc, SlimDocument):
                found_ids.add(doc.id)

    assert found_ids == {root_url, page2_url, page3_url}
@patch("onyx.connectors.web.connector.check_internet_connection")
@patch("onyx.connectors.web.connector.requests.head")
@patch("onyx.connectors.web.connector.start_playwright")
def test_normal_200_skips_5s_wait(
    mock_start_playwright: MagicMock,
    mock_head: MagicMock,
    _mock_check: MagicMock,
) -> None:
    """A plain 200 response with no cf-ray header must not trigger
    ``page.wait_for_timeout``."""
    plain_page = _make_page_mock(SINGLE_PAGE_HTML, cf_ray=None, status=200)
    browser_context = MagicMock()
    browser_context.new_page.return_value = plain_page
    mock_start_playwright.return_value = (_make_playwright_mock(), browser_context)
    mock_head.return_value.headers = {"content-type": "text/html"}

    crawler = WebConnector(
        base_url=BASE_URL + "/",
        web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
    )
    # Drain the generator so the page is actually visited.
    list(crawler.retrieve_all_slim_docs())

    plain_page.wait_for_timeout.assert_not_called()
@patch("onyx.connectors.web.connector.check_internet_connection")
@patch("onyx.connectors.web.connector.requests.head")
@patch("onyx.connectors.web.connector.start_playwright")
def test_cloudflare_applies_5s_wait(
    mock_start_playwright: MagicMock,
    mock_head: MagicMock,
    _mock_check: MagicMock,
) -> None:
    """A response carrying a cf-ray header triggers exactly one
    ``page.wait_for_timeout(5000)`` call."""
    cloudflare_page = _make_page_mock(SINGLE_PAGE_HTML, cf_ray="abc123-LAX")
    browser_context = MagicMock()
    browser_context.new_page.return_value = cloudflare_page
    mock_start_playwright.return_value = (_make_playwright_mock(), browser_context)
    mock_head.return_value.headers = {"content-type": "text/html"}

    crawler = WebConnector(
        base_url=BASE_URL + "/",
        web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
    )
    # Drain the generator so the page is actually visited.
    list(crawler.retrieve_all_slim_docs())

    cloudflare_page.wait_for_timeout.assert_called_once_with(5000)
@patch("onyx.connectors.web.connector.time")
@patch("onyx.connectors.web.connector.check_internet_connection")
@patch("onyx.connectors.web.connector.requests.head")
@patch("onyx.connectors.web.connector.start_playwright")
def test_403_applies_5s_wait(
    mock_start_playwright: MagicMock,
    mock_head: MagicMock,
    _mock_check: MagicMock,
    _mock_time: MagicMock,
) -> None:
    """A 403 response (a common bot-detection challenge entry point) triggers
    ``page.wait_for_timeout(5000)``."""
    # NOTE(review): the connector module's ``time`` is patched, presumably so
    # retry back-off does not really sleep; confirm against the connector.
    forbidden_page = _make_page_mock(SINGLE_PAGE_HTML, cf_ray=None, status=403)
    browser_context = MagicMock()
    browser_context.new_page.return_value = forbidden_page
    mock_start_playwright.return_value = (_make_playwright_mock(), browser_context)
    mock_head.return_value.headers = {"content-type": "text/html"}

    crawler = WebConnector(
        base_url=BASE_URL + "/",
        web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
    )

    # Every attempt returns 403, so the crawl may end in RuntimeError with no
    # documents. That outcome is expected; only the wait call matters here.
    try:
        list(crawler.retrieve_all_slim_docs())
    except RuntimeError:
        pass

    forbidden_page.wait_for_timeout.assert_called_with(5000)

View File

@@ -1,11 +1,12 @@
import io
from typing import cast
from unittest.mock import patch
from unittest.mock import MagicMock
import openpyxl
from openpyxl.worksheet.worksheet import Worksheet
from onyx.file_processing.extract_file_text import _sheet_to_csv
from onyx.file_processing.extract_file_text import _clean_worksheet_matrix
from onyx.file_processing.extract_file_text import _worksheet_to_matrix
from onyx.file_processing.extract_file_text import xlsx_sheet_extraction
from onyx.file_processing.extract_file_text import xlsx_to_text
@@ -201,179 +202,50 @@ class TestXlsxToText:
assert "r3c1" in lines[2] and "r3c2" in lines[2]
class TestSheetToCsvJaggedRows:
"""openpyxl's read-only mode yields rows of differing widths when
trailing cells are empty. These tests exercise ``_sheet_to_csv``
directly because ``_make_xlsx`` (via ``ws.append``) normalizes row
widths, so jagged input can only be produced in-memory."""
class TestWorksheetToMatrixJaggedRows:
"""openpyxl read_only mode can yield rows of differing widths when
trailing cells are empty. The matrix must be padded to a rectangle
so downstream column cleanup can index safely."""
def test_shorter_trailing_rows_padded_in_output(self) -> None:
csv_text = _sheet_to_csv(
iter(
[
("A", "B", "C"),
("X", "Y"),
("P",),
]
)
def test_pads_shorter_trailing_rows(self) -> None:
ws = MagicMock()
ws.iter_rows.return_value = iter(
[
("A", "B", "C"),
("X", "Y"),
("P",),
]
)
assert csv_text.split("\n") == ["A,B,C", "X,Y,", "P,,"]
matrix = _worksheet_to_matrix(ws)
assert matrix == [["A", "B", "C"], ["X", "Y", ""], ["P", "", ""]]
def test_shorter_leading_row_padded_in_output(self) -> None:
csv_text = _sheet_to_csv(
iter(
[
("A",),
("X", "Y", "Z"),
]
)
def test_pads_when_first_row_is_shorter(self) -> None:
ws = MagicMock()
ws.iter_rows.return_value = iter(
[
("A",),
("X", "Y", "Z"),
]
)
assert csv_text.split("\n") == ["A,,", "X,Y,Z"]
matrix = _worksheet_to_matrix(ws)
assert matrix == [["A", "", ""], ["X", "Y", "Z"]]
def test_no_index_error_on_jagged_rows(self) -> None:
"""Regression: the original dense-matrix version raised IndexError
when a later row was shorter than an earlier row whose out-of-range
columns happened to be empty."""
csv_text = _sheet_to_csv(
iter(
[
("A", "", "", "B"),
("X", "Y"),
]
)
def test_clean_worksheet_matrix_no_index_error_on_jagged_rows(self) -> None:
"""Regression: previously raised IndexError when a later row was
shorter than the first row and the out-of-range column on the
first row was empty (so the short-circuit in `all()` did not
save us)."""
ws = MagicMock()
ws.iter_rows.return_value = iter(
[
("A", "", "", "B"),
("X", "Y"),
]
)
assert csv_text.split("\n") == ["A,,,B", "X,Y,,"]
class TestSheetToCsvStreaming:
"""Pin the memory-safe streaming contract: empty rows are skipped
cheaply, empty-row/column runs are collapsed to at most 2, and sheets
with no data return the empty string."""
def test_empty_rows_between_data_capped_at_two(self) -> None:
csv_text = _sheet_to_csv(
iter(
[
("A", "B"),
(None, None),
(None, None),
(None, None),
(None, None),
(None, None),
("C", "D"),
]
)
)
# 5 empty rows collapsed to 2
assert csv_text.split("\n") == ["A,B", ",", ",", "C,D"]
def test_empty_rows_at_or_below_cap_preserved(self) -> None:
csv_text = _sheet_to_csv(
iter(
[
("A", "B"),
(None, None),
(None, None),
("C", "D"),
]
)
)
assert csv_text.split("\n") == ["A,B", ",", ",", "C,D"]
def test_empty_column_run_capped_at_two(self) -> None:
csv_text = _sheet_to_csv(
iter(
[
("A", None, None, None, None, "B"),
("C", None, None, None, None, "D"),
]
)
)
# 4 empty cols between A and B collapsed to 2
assert csv_text.split("\n") == ["A,,,B", "C,,,D"]
def test_completely_empty_stream_returns_empty_string(self) -> None:
assert _sheet_to_csv(iter([])) == ""
def test_all_rows_empty_returns_empty_string(self) -> None:
csv_text = _sheet_to_csv(
iter(
[
(None, None),
("", ""),
(None,),
]
)
)
assert csv_text == ""
def test_trailing_empty_rows_dropped(self) -> None:
csv_text = _sheet_to_csv(
iter(
[
("A",),
("B",),
(None,),
(None,),
(None,),
]
)
)
# Trailing empties are never emitted (no subsequent non-empty row
# to flush them against).
assert csv_text.split("\n") == ["A", "B"]
def test_leading_empty_rows_capped_at_two(self) -> None:
csv_text = _sheet_to_csv(
iter(
[
(None, None),
(None, None),
(None, None),
(None, None),
(None, None),
("A", "B"),
]
)
)
# 5 leading empty rows collapsed to 2
assert csv_text.split("\n") == [",", ",", "A,B"]
def test_cell_cap_truncates_and_appends_marker(self) -> None:
"""When total non-empty cells exceeds the cap, scanning stops and
a truncation marker row is appended so downstream indexing sees
the sheet was cut off."""
with patch(
"onyx.file_processing.extract_file_text.MAX_XLSX_CELLS_PER_SHEET", 5
):
csv_text = _sheet_to_csv(
iter(
[
("A", "B", "C"),
("D", "E", "F"),
("G", "H", "I"),
("J", "K", "L"),
]
)
)
lines = csv_text.split("\n")
assert lines[-1] == "[truncated: sheet exceeded cell limit]"
# First two rows (6 cells) trip the cap=5 check after row 2; the
# third and fourth rows are never scanned.
assert "G" not in csv_text
assert "J" not in csv_text
def test_cell_cap_not_hit_no_marker(self) -> None:
"""Under the cap, no truncation marker is appended."""
csv_text = _sheet_to_csv(
iter(
[
("A", "B"),
("C", "D"),
]
)
)
assert "[truncated" not in csv_text
matrix = _worksheet_to_matrix(ws)
# Must not raise.
cleaned = _clean_worksheet_matrix(matrix)
assert cleaned == [["A", "", "", "B"], ["X", "Y", "", ""]]
class TestXlsxSheetExtraction:
@@ -408,9 +280,10 @@ class TestXlsxSheetExtraction:
assert "a" in csv_text
assert "b" in csv_text
def test_empty_sheet_included_with_empty_csv(self) -> None:
"""Every sheet in the workbook appears in the result; an empty
sheet contributes an empty csv_text alongside its title."""
def test_empty_sheet_is_skipped(self) -> None:
"""A sheet whose CSV output is empty/whitespace-only should NOT
appear in the result — the `if csv_text.strip():` guard filters
it out."""
xlsx = _make_xlsx(
{
"Data": [["a", "b"]],
@@ -418,17 +291,14 @@ class TestXlsxSheetExtraction:
}
)
sheets = xlsx_sheet_extraction(xlsx)
assert len(sheets) == 2
titles = [title for _csv, title in sheets]
assert titles == ["Data", "Empty"]
empty_csv = next(csv_text for csv_text, title in sheets if title == "Empty")
assert empty_csv == ""
assert len(sheets) == 1
assert sheets[0][1] == "Data"
def test_empty_workbook_returns_one_tuple_per_sheet(self) -> None:
"""All sheets empty → one empty-csv tuple per sheet."""
def test_empty_workbook_returns_empty_list(self) -> None:
"""All sheets empty → empty list (not a list of empty tuples)."""
xlsx = _make_xlsx({"Sheet1": [], "Sheet2": []})
sheets = xlsx_sheet_extraction(xlsx)
assert sheets == [("", "Sheet1"), ("", "Sheet2")]
assert sheets == []
def test_single_sheet(self) -> None:
xlsx = _make_xlsx({"Only": [["x", "y"], ["1", "2"]]})
@@ -451,17 +321,6 @@ class TestXlsxSheetExtraction:
sheets = xlsx_sheet_extraction(bad_file, file_name="~$temp.xlsx")
assert sheets == []
def test_known_openpyxl_bug_max_value_returns_empty(self) -> None:
"""openpyxl's strict descriptor validation rejects font family
values >14 with 'Max value is 14'. Treat as a known openpyxl bug
and skip the file rather than fail the whole connector batch."""
with patch(
"onyx.file_processing.extract_file_text.openpyxl.load_workbook",
side_effect=ValueError("Max value is 14"),
):
sheets = xlsx_sheet_extraction(io.BytesIO(b""), file_name="bad_font.xlsx")
assert sheets == []
def test_csv_content_matches_xlsx_to_text_per_sheet(self) -> None:
"""For a single-sheet workbook, xlsx_to_text output should equal
the csv_text from xlsx_sheet_extraction — they share the same

View File

@@ -1,257 +0,0 @@
"""Tests for embedding Prometheus metrics."""
from unittest.mock import patch
from onyx.server.metrics.embedding import _client_duration
from onyx.server.metrics.embedding import _embedding_input_chars_total
from onyx.server.metrics.embedding import _embedding_requests_total
from onyx.server.metrics.embedding import _embedding_texts_total
from onyx.server.metrics.embedding import _embeddings_in_progress
from onyx.server.metrics.embedding import LOCAL_PROVIDER_LABEL
from onyx.server.metrics.embedding import observe_embedding_client
from onyx.server.metrics.embedding import provider_label
from onyx.server.metrics.embedding import PROVIDER_LABEL_NAME
from onyx.server.metrics.embedding import TEXT_TYPE_LABEL_NAME
from onyx.server.metrics.embedding import track_embedding_in_progress
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
class TestProviderLabel:
    """Checks how ``provider_label`` turns a provider into a metric label."""

    def test_none_maps_to_local(self) -> None:
        # No provider means the local label.
        label = provider_label(None)
        assert label == LOCAL_PROVIDER_LABEL

    def test_enum_maps_to_value(self) -> None:
        expectations = (
            (EmbeddingProvider.OPENAI, "openai"),
            (EmbeddingProvider.COHERE, "cohere"),
        )
        for provider, expected_label in expectations:
            assert provider_label(provider) == expected_label
class TestObserveEmbeddingClient:
    """Exercises ``observe_embedding_client`` by comparing metric values read
    before and after each call. The underlying metrics are module-level, so
    the tests assert on deltas rather than absolute values."""

    def test_success_records_all_counters(self) -> None:
        # Precondition: snapshot every series the call is expected to touch.
        provider = EmbeddingProvider.OPENAI
        text_type = EmbedTextType.QUERY
        labels = {
            PROVIDER_LABEL_NAME: provider.value,
            TEXT_TYPE_LABEL_NAME: text_type.value,
        }
        requests_before = _embedding_requests_total.labels(
            **labels, status="success"
        )._value.get()
        texts_before = _embedding_texts_total.labels(**labels)._value.get()
        chars_before = _embedding_input_chars_total.labels(**labels)._value.get()
        duration_sum_before = _client_duration.labels(**labels)._sum.get()

        duration_s = 0.123
        num_texts = 4
        num_chars = 200

        # Under test.
        observe_embedding_client(
            provider=provider,
            text_type=text_type,
            duration_s=duration_s,
            num_texts=num_texts,
            num_chars=num_chars,
            success=True,
        )

        # Postcondition: each series moved by exactly the reported amount.
        requests_after = _embedding_requests_total.labels(
            **labels, status="success"
        )._value.get()
        assert requests_after == requests_before + 1

        texts_after = _embedding_texts_total.labels(**labels)._value.get()
        assert texts_after == texts_before + num_texts

        chars_after = _embedding_input_chars_total.labels(**labels)._value.get()
        assert chars_after == chars_before + num_chars

        duration_sum_after = _client_duration.labels(**labels)._sum.get()
        assert duration_sum_after == duration_sum_before + duration_s

    def test_failure_records_duration_and_failure_counter_only(self) -> None:
        # Precondition.
        provider = EmbeddingProvider.COHERE
        text_type = EmbedTextType.PASSAGE
        labels = {
            PROVIDER_LABEL_NAME: provider.value,
            TEXT_TYPE_LABEL_NAME: text_type.value,
        }
        failures_before = _embedding_requests_total.labels(
            **labels, status="failure"
        )._value.get()
        texts_before = _embedding_texts_total.labels(**labels)._value.get()
        chars_before = _embedding_input_chars_total.labels(**labels)._value.get()
        duration_sum_before = _client_duration.labels(**labels)._sum.get()

        duration_s = 0.5
        num_texts = 3
        num_chars = 150

        # Under test.
        observe_embedding_client(
            provider=provider,
            text_type=text_type,
            duration_s=duration_s,
            num_texts=num_texts,
            num_chars=num_chars,
            success=False,
        )

        # Postcondition.
        # The failure counter is incremented.
        failures_after = _embedding_requests_total.labels(
            **labels, status="failure"
        )._value.get()
        assert failures_after == failures_before + 1

        # The duration is still recorded.
        duration_sum_after = _client_duration.labels(**labels)._sum.get()
        assert duration_sum_after == duration_sum_before + duration_s

        # Throughput counters are NOT bumped on failure.
        texts_after = _embedding_texts_total.labels(**labels)._value.get()
        assert texts_after == texts_before
        chars_after = _embedding_input_chars_total.labels(**labels)._value.get()
        assert chars_after == chars_before

    def test_local_provider_uses_local_label(self) -> None:
        # Precondition.
        text_type = EmbedTextType.QUERY
        labels = {
            PROVIDER_LABEL_NAME: LOCAL_PROVIDER_LABEL,
            TEXT_TYPE_LABEL_NAME: text_type.value,
        }
        requests_before = _embedding_requests_total.labels(
            **labels, status="success"
        )._value.get()

        # Under test: provider=None should be recorded under the local label.
        observe_embedding_client(
            provider=None,
            text_type=text_type,
            duration_s=0.05,
            num_texts=1,
            num_chars=10,
            success=True,
        )

        # Postcondition.
        requests_after = _embedding_requests_total.labels(
            **labels, status="success"
        )._value.get()
        assert requests_after == requests_before + 1

    def test_exceptions_do_not_propagate(self) -> None:
        # Make the request counter's ``labels`` lookup blow up.
        broken_labels = patch.object(
            _embedding_requests_total,
            "labels",
            side_effect=RuntimeError("boom"),
        )
        with broken_labels:
            # Must not raise.
            observe_embedding_client(
                provider=EmbeddingProvider.OPENAI,
                text_type=EmbedTextType.QUERY,
                duration_s=0.1,
                num_texts=1,
                num_chars=10,
                success=True,
            )
class TestTrackEmbeddingInProgress:
    """Exercises the ``track_embedding_in_progress`` context manager by
    reading the in-progress gauge before, during and after the block."""

    def test_gauge_increments_and_decrements(self) -> None:
        # Precondition.
        provider = EmbeddingProvider.OPENAI
        text_type = EmbedTextType.QUERY
        labels = {
            PROVIDER_LABEL_NAME: provider.value,
            TEXT_TYPE_LABEL_NAME: text_type.value,
        }
        gauge_before = _embeddings_in_progress.labels(**labels)._value.get()

        # Under test.
        with track_embedding_in_progress(provider, text_type):
            gauge_during = _embeddings_in_progress.labels(**labels)._value.get()
            assert gauge_during == gauge_before + 1

        # Postcondition: the gauge is back at its starting value.
        gauge_after = _embeddings_in_progress.labels(**labels)._value.get()
        assert gauge_after == gauge_before

    def test_gauge_decrements_on_exception(self) -> None:
        # Precondition.
        provider = EmbeddingProvider.COHERE
        text_type = EmbedTextType.PASSAGE
        labels = {
            PROVIDER_LABEL_NAME: provider.value,
            TEXT_TYPE_LABEL_NAME: text_type.value,
        }
        gauge_before = _embeddings_in_progress.labels(**labels)._value.get()

        # Under test: the error must pass through the context manager.
        error_escaped = False
        try:
            with track_embedding_in_progress(provider, text_type):
                raise ValueError("simulated embedding failure")
        except ValueError:
            error_escaped = True
        assert error_escaped

        # Postcondition: the gauge is restored even though the body raised.
        gauge_after = _embeddings_in_progress.labels(**labels)._value.get()
        assert gauge_after == gauge_before

    def test_local_provider_uses_local_label(self) -> None:
        # Precondition.
        text_type = EmbedTextType.QUERY
        labels = {
            PROVIDER_LABEL_NAME: LOCAL_PROVIDER_LABEL,
            TEXT_TYPE_LABEL_NAME: text_type.value,
        }
        gauge_before = _embeddings_in_progress.labels(**labels)._value.get()

        # Under test: provider=None should be tracked under the local label.
        with track_embedding_in_progress(None, text_type):
            gauge_during = _embeddings_in_progress.labels(**labels)._value.get()
            assert gauge_during == gauge_before + 1

        # Postcondition.
        gauge_after = _embeddings_in_progress.labels(**labels)._value.get()
        assert gauge_after == gauge_before

    def test_inc_exception_does_not_break_call(self) -> None:
        # Precondition.
        provider = EmbeddingProvider.VOYAGE
        text_type = EmbedTextType.QUERY
        labels = {
            PROVIDER_LABEL_NAME: provider.value,
            TEXT_TYPE_LABEL_NAME: text_type.value,
        }
        gauge_before = _embeddings_in_progress.labels(**labels)._value.get()

        # Under test: make ``inc`` on this labelled gauge child raise.
        failing_inc = patch.object(
            _embeddings_in_progress.labels(**labels),
            "inc",
            side_effect=RuntimeError("boom"),
        )
        with failing_inc:
            # The context manager should still yield, without decrementing.
            with track_embedding_in_progress(provider, text_type):
                gauge_during = _embeddings_in_progress.labels(**labels)._value.get()
                assert gauge_during == gauge_before

        # Postcondition: the failed increment left the gauge untouched.
        gauge_after = _embeddings_in_progress.labels(**labels)._value.get()
        assert gauge_after == gauge_before

View File

@@ -129,36 +129,12 @@ class TestWorkerHealthCollector:
up = families[1]
assert up.name == "onyx_celery_worker_up"
assert len(up.samples) == 3
label_pairs = {
(s.labels["worker_type"], s.labels["hostname"]) for s in up.samples
}
assert label_pairs == {
("primary", "host1"),
("docfetching", "host1"),
("monitoring", "host1"),
}
# Labels use short names (before @)
labels = {s.labels["worker"] for s in up.samples}
assert labels == {"primary", "docfetching", "monitoring"}
for sample in up.samples:
assert sample.value == 1
def test_replicas_of_same_worker_type_are_distinct(self) -> None:
"""Regression: ``docprocessing@pod-1`` and ``docprocessing@pod-2`` must
produce separate samples, not collapse into one duplicate-timestamp
series.
"""
monitor = WorkerHeartbeatMonitor(MagicMock())
monitor._on_heartbeat({"hostname": "docprocessing@pod-1"})
monitor._on_heartbeat({"hostname": "docprocessing@pod-2"})
monitor._on_heartbeat({"hostname": "docprocessing@pod-3"})
collector = WorkerHealthCollector(cache_ttl=0)
collector.set_monitor(monitor)
up = collector.collect()[1]
assert len(up.samples) == 3
hostnames = {s.labels["hostname"] for s in up.samples}
assert hostnames == {"pod-1", "pod-2", "pod-3"}
assert all(s.labels["worker_type"] == "docprocessing" for s in up.samples)
def test_reports_dead_worker(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())
monitor._on_heartbeat({"hostname": "primary@host1"})
@@ -175,9 +151,9 @@ class TestWorkerHealthCollector:
assert active.samples[0].value == 1
up = families[1]
samples_by_type = {s.labels["worker_type"]: s.value for s in up.samples}
assert samples_by_type["primary"] == 1
assert samples_by_type["monitoring"] == 0
samples_by_name = {s.labels["worker"]: s.value for s in up.samples}
assert samples_by_name["primary"] == 1
assert samples_by_name["monitoring"] == 0
def test_empty_monitor_returns_zero(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())

View File

@@ -58,7 +58,8 @@ SERVICE_ORDER=(
validate_template() {
local template_file=$1
echo "Validating template: $template_file..."
if ! aws cloudformation validate-template --template-body file://"$template_file" --region "$AWS_REGION" > /dev/null; then
aws cloudformation validate-template --template-body file://"$template_file" --region "$AWS_REGION" > /dev/null
if [ $? -ne 0 ]; then
echo "Error: Validation failed for $template_file. Exiting."
exit 1
fi
@@ -107,15 +108,13 @@ deploy_stack() {
fi
# Create temporary parameters file for this template
local temp_params_file
temp_params_file=$(create_parameters_from_json "$template_file")
local temp_params_file=$(create_parameters_from_json "$template_file")
# Special handling for SubnetIDs parameter if needed
if grep -q "SubnetIDs" "$template_file"; then
echo "Template uses SubnetIDs parameter, ensuring it's properly formatted..."
# Make sure we're passing SubnetIDs as a comma-separated list
local subnet_ids
subnet_ids=$(remove_comments "$CONFIG_FILE" | jq -r '.SubnetIDs // empty')
local subnet_ids=$(remove_comments "$CONFIG_FILE" | jq -r '.SubnetIDs // empty')
if [ -n "$subnet_ids" ]; then
echo "Using SubnetIDs from config: $subnet_ids"
else
@@ -124,13 +123,15 @@ deploy_stack() {
fi
echo "Deploying stack: $stack_name with template: $template_file and generated config from: $CONFIG_FILE..."
if ! aws cloudformation deploy \
aws cloudformation deploy \
--stack-name "$stack_name" \
--template-file "$template_file" \
--parameter-overrides file://"$temp_params_file" \
--capabilities CAPABILITY_IAM CAPABILITY_NAMED_IAM CAPABILITY_AUTO_EXPAND \
--region "$AWS_REGION" \
--no-cli-auto-prompt > /dev/null; then
--no-cli-auto-prompt > /dev/null
if [ $? -ne 0 ]; then
echo "Error: Deployment failed for $stack_name. Exiting."
exit 1
fi

View File

@@ -52,9 +52,11 @@ delete_stack() {
--region "$AWS_REGION"
echo "Waiting for stack $stack_name to be deleted..."
if aws cloudformation wait stack-delete-complete \
aws cloudformation wait stack-delete-complete \
--stack-name "$stack_name" \
--region "$AWS_REGION"; then
--region "$AWS_REGION"
if [ $? -eq 0 ]; then
echo "Stack $stack_name deleted successfully."
sleep 10
else

Some files were not shown because too many files have changed in this diff Show More