Compare commits

...

60 Commits

Author SHA1 Message Date
Nikolas Garza
546da624a1 feat(metrics): add connector_name label to push-based connector metrics (#10237) 2026-04-15 22:58:49 +00:00
Nikolas Garza
1a88dea760 fix(model-server): add missing onyx/configs to Dockerfile for sentry support (#10236) 2026-04-15 22:42:00 +00:00
Justin Tahara
53d2d647c5 fix(deletion): Commit Session in per-doc cleanup (#10193) 2026-04-15 22:37:00 +00:00
Justin Tahara
560a8f7ab4 feat(mt): Infra setup for Redis Set (1/3) (#10209) 2026-04-15 22:29:49 +00:00
Bo-Onyx
eaabb19c72 fix(pruning): GitHub connector pruning timeout via SlimConnector (#10205) 2026-04-15 22:25:48 +00:00
Bo-Onyx
d3e5e16150 fix(pruning): Resolve hierarchy node FK error for Confluence and Notion (#10235) 2026-04-15 22:25:34 +00:00
Danelegend
d3739611ba feat(connectors): Connectors output TabularSections for tabular files (#10096) 2026-04-15 22:09:28 +00:00
Justin Tahara
73f9a47364 fix(xlsx): Openpyxl Formatting Issues (#10230) 2026-04-15 21:22:58 +00:00
Raunak Bhagat
a808445d96 feat: opalify MessageCard (#10223) 2026-04-15 21:11:18 +00:00
Nikolas Garza
c31215197a fix(chat): hide incomplete citation links during streaming (#10224) 2026-04-15 21:10:06 +00:00
Nikolas Garza
9ebd9ebd73 fix(chat): snap typewriter to full content on tab re-focus (#10226) 2026-04-15 21:07:00 +00:00
Nikolas Garza
f0bb0a6bb0 fix(chat): only header click selects preferred in multi-model panels (#10198) 2026-04-15 21:06:19 +00:00
Ben Wu
01bec19d19 feat(canvas): checkpoint logic (3/4) (#9807) 2026-04-15 20:48:16 +00:00
Danelegend
7b40c2cde7 feat(indexing): CSV Chunker - Field-Value Implementation (#10099) 2026-04-15 19:57:50 +00:00
Jamison Lahman
e2c38d2899 chore(devtools): connect databases and github remote to devcontainer (#10222) 2026-04-15 19:50:11 +00:00
Nikolas Garza
24768f9e4f feat(metrics): replace pull-based connector metrics with push-based for multi-tenant (#10189) 2026-04-15 18:15:34 +00:00
Bo-Onyx
aec1c169b6 feat(pruning): pruning grafana dashboard for single tenant (#10208) 2026-04-15 17:50:28 +00:00
Jamison Lahman
5a16ad3473 chore(tests): avoid openapi client import in tests (#10220) 2026-04-15 17:38:24 +00:00
dependabot[bot]
7e28e59f23 chore(deps): bump transformers from 4.53.0 to 5.5.4 (#9987)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-04-15 10:39:50 -07:00
Nikolas Garza
879ae6c02d feat(monitoring): add local Prometheus + Grafana docker-compose stack (#9627) 2026-04-15 17:25:28 +00:00
Nikolas Garza
f84f367eb4 fix(voice): send TTS text in POST body instead of query params (#10213) 2026-04-15 17:20:29 +00:00
Jamison Lahman
d81efe3877 fix(ollama): always include model tag in display name (#10218)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-04-15 09:17:37 -07:00
Nikolas Garza
d4619f93c4 feat(indexing): notify admins when connector enters repeated error state (#10207) 2026-04-15 06:10:25 +00:00
Nikolas Garza
70fcfb1d73 feat(indexing): add admin API for failed documents (#10204) 2026-04-15 06:10:06 +00:00
Nikolas Garza
32ba393b32 fix(chat): keep model selector popover open until max models reached (#10203) 2026-04-15 06:09:24 +00:00
Nikolas Garza
f9d2bf78ed fix(chat): disable hover/pointer states on multi-model panels during streaming (#10202) 2026-04-15 06:09:11 +00:00
Nikolas Garza
5567a078fe fix(chat): fix fade gradient missing on last multi-model panel (#10199) 2026-04-15 06:08:48 +00:00
Raunak Bhagat
fc0e8560bc feat: opalify Tooltip component, migrate all consumers (#10210) 2026-04-15 03:42:15 +00:00
Nikolas Garza
60b2701eed feat(indexing): add diagnostic logging to check_for_indexing beat task (#10200) 2026-04-14 20:29:47 -07:00
Jamison Lahman
3682d9844b fix(fe): handle file attachment overflow (#10211) 2026-04-15 02:00:58 +00:00
Raunak Bhagat
a420f9a37c feat: add ref forwarding to input layout components (#10206) 2026-04-15 00:20:50 +00:00
Jamison Lahman
20c5107ba6 chore(devtools): install java runtime into devcontainer (#10197) 2026-04-14 23:10:12 +00:00
Nikolas Garza
357bc91aee feat(indexing): capture swallowed per-doc exceptions in Sentry (#10149) 2026-04-14 23:01:42 +00:00
Nikolas Garza
09653872a2 fix(chat): render inline citation chips in multi-model panels (#10196) 2026-04-14 22:59:10 +00:00
dependabot[bot]
ff01a53f83 chore(deps): bump next from 16.1.7 to 16.2.3 in /web (#10195)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-04-14 22:49:31 +00:00
Danelegend
03ddd5ca9b feat(indexing): Add TabularSection (#10095) 2026-04-14 22:16:35 +00:00
Bo-Onyx
8c49e4573c fix(pruning): Skip Permission Sync During Google Drive Pruning (#10185) 2026-04-14 22:14:09 +00:00
Jamison Lahman
f1696ffa16 chore(deps): upgrade playwright: 1.55.0->1.58.0 (#10194) 2026-04-14 15:12:14 -07:00
Jamison Lahman
a427cb5b0c chore(deps): upgrade python patch version in docker (#10192) 2026-04-14 15:10:00 -07:00
Evan Lohn
f7e4be18dd fix: uploaded files as knowledge source (#10167) 2026-04-14 21:51:00 +00:00
acaprau
0f31c490fa chore(opensearch): Add debug log for when the migration task releases its lock (#10190) 2026-04-14 14:08:48 -07:00
Wenxi
c9a4a6e42b fix: text shimmer animation nice and fast (#10184) 2026-04-14 20:59:00 +00:00
Nikolas Garza
558c9df3c7 fix(chat): eliminate long-lived DB session in multi-model worker threads (#10159) 2026-04-14 20:37:05 +00:00
Jamison Lahman
30003036d3 chore(fe): Toast logs to the console by default in dev (#10183) 2026-04-14 20:34:04 +00:00
Nikolas Garza
4b2f18c239 fix(chat): speed up text gen (#10186) 2026-04-14 13:41:29 -07:00
Wenxi
4290b097f5 fix: auth logout modal on fresh load (#10007) 2026-04-14 18:43:34 +00:00
Justin Tahara
b0f621a08b fix(llm): Fix the Auto Fetch workflow (#10181) 2026-04-14 18:06:47 +00:00
Raunak Bhagat
112edf41c5 refactor: replace Radix Slot with div wrapper in @opal/core.Disabled (#10119) 2026-04-14 17:40:32 +00:00
SubashMohan
74eb1d7212 feat(notifications): announce upcoming group-based permissions migration (#10178) 2026-04-14 16:23:33 +00:00
dependabot[bot]
e62d592b11 chore(deps): bump alembic from 1.10.4 to 1.18.4 in /backend (#9768)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-04-14 16:04:58 +00:00
Wenxi
57a0d25321 fix: use static provider list instead of querying be (#10166) 2026-04-14 15:34:57 +00:00
dependabot[bot]
887f79d7a5 chore(deps-dev): bump langchain-core from 1.2.22 to 1.2.28 (#10010)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-04-14 08:23:30 -07:00
Evan Lohn
65fd1c3ec8 fix: document set name patch (#10162) 2026-04-14 01:53:40 +00:00
Danelegend
6e3ee287b9 fix(files): Lower log level on file store cache miss (#10164) 2026-04-14 01:46:46 +00:00
Raunak Bhagat
dee0b7867e refactor: opalify input layouts with trinary withLabel prop (#10144) 2026-04-14 01:28:37 +00:00
Wenxi
77beb8044e fix(google): handle JSON credential payloads in KV storage (@jack-larch) (#10160)
Co-authored-by: Jack Larch <jack.larch@biograph.com>
2026-04-14 01:20:44 +00:00
Wenxi
750d3ac4ed fix: llm popover should refresh on admin provider edit (#10152) 2026-04-14 01:13:50 +00:00
Bo-Onyx
6c02087ba4 chore(pruning): Add Celery task queue wait time metric (#10161) 2026-04-14 01:08:25 +00:00
Wenxi
0425283ed0 fix: show correct knowledge toggle status on agent edit page (#10151) 2026-04-14 01:07:21 +00:00
Justin Tahara
da97a57c58 feat(metrics): Add Deletion-specific Prometheus Metrics (#10157) 2026-04-14 00:57:16 +00:00
269 changed files with 12086 additions and 3838 deletions

View File

@@ -2,6 +2,7 @@ FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0a
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
default-jre \
fd-find \
fzf \
git \

View File

@@ -1,7 +1,7 @@
{
"name": "Onyx Dev Sandbox",
"image": "onyxdotapp/onyx-devcontainer@sha256:12184169c5bcc9cca0388286d5ffe504b569bc9c37bfa631b76ee8eee2064055",
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW"],
"image": "onyxdotapp/onyx-devcontainer@sha256:0f02d9299928849c7b15f3b348dcfdcdcb64411ff7a4580cbc026a6ee7aa1554",
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW", "--network=onyx_default"],
"mounts": [
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
@@ -12,10 +12,13 @@
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
],
"containerEnv": {
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock"
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock",
"POSTGRES_HOST": "relational_db",
"REDIS_HOST": "cache"
},
"remoteUser": "${localEnv:DEVCONTAINER_REMOTE_USER:dev}",
"updateRemoteUserUID": false,
"initializeCommand": "docker network create onyx_default 2>/dev/null || true",
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
"workspaceFolder": "/workspace",
"postStartCommand": "sudo bash /workspace/.devcontainer/init-dev-user.sh && sudo bash /workspace/.devcontainer/init-firewall.sh",

View File

@@ -4,22 +4,12 @@ set -euo pipefail
echo "Setting up firewall..."
# Preserve docker dns resolution
DOCKER_DNS_RULES=$(iptables-save | grep -E "^-A.*-d 127.0.0.11/32" || true)
# Flush all rules
iptables -t nat -F
iptables -t nat -X
iptables -t mangle -F
iptables -t mangle -X
# Only flush the filter table. The nat and mangle tables are managed by Docker
# (DNS DNAT to 127.0.0.11, container networking, etc.) and must not be touched —
# flushing them breaks Docker's embedded DNS resolver.
iptables -F
iptables -X
# Restore docker dns rules
if [ -n "$DOCKER_DNS_RULES" ]; then
echo "$DOCKER_DNS_RULES" | iptables-restore -n
fi
# Create ipset for allowed destinations
ipset create allowed-domains hash:net || true
ipset flush allowed-domains
@@ -34,6 +24,7 @@ done
# Resolve allowed domains
ALLOWED_DOMAINS=(
"github.com"
"registry.npmjs.org"
"api.anthropic.com"
"api-staging.anthropic.com"
@@ -65,6 +56,14 @@ if [ -n "$DOCKER_GATEWAY" ]; then
fi
fi
# Allow traffic to all attached Docker network subnets so the container can
# reach sibling services (e.g. relational_db, cache) on shared compose networks.
for subnet in $(ip -4 -o addr show scope global | awk '{print $4}'); do
if ! ipset add allowed-domains "$subnet" -exist 2>&1; then
echo "warning: failed to add Docker subnet $subnet to allowlist" >&2
fi
done
# Set default policies to DROP
iptables -P FORWARD DROP
iptables -P INPUT DROP

.vscode/launch.json (vendored): 12 changed lines
View File

@@ -475,6 +475,18 @@
"order": 0
}
},
{
"name": "Start Monitoring Stack (Prometheus + Grafana)",
"type": "node",
"request": "launch",
"runtimeExecutable": "docker",
"runtimeArgs": ["compose", "up", "-d"],
"cwd": "${workspaceFolder}/profiling",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",

View File

@@ -1,4 +1,4 @@
FROM python:3.11.7-slim-bookworm
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47
LABEL com.danswer.maintainer="founders@onyx.app"
LABEL com.danswer.description="This image is the web/frontend container of Onyx which \

View File

@@ -1,5 +1,5 @@
# Base stage with dependencies
FROM python:3.11.7-slim-bookworm AS base
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47 AS base
ENV DANSWER_RUNNING_IN_DOCKER="true" \
HF_HOME=/app/.cache/huggingface
@@ -50,6 +50,10 @@ COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py
COPY ./onyx/utils/middleware.py /app/onyx/utils/middleware.py
COPY ./onyx/utils/tenant.py /app/onyx/utils/tenant.py
# Sentry configuration (used when SENTRY_DSN is set)
COPY ./onyx/configs/__init__.py /app/onyx/configs/__init__.py
COPY ./onyx/configs/sentry.py /app/onyx/configs/sentry.py
# Place to fetch version information
COPY ./onyx/__init__.py /app/onyx/__init__.py

View File

@@ -208,7 +208,7 @@ def do_run_migrations(
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
target_metadata=target_metadata,
version_table_schema=schema_name,
include_schemas=True,
compare_type=True,
@@ -380,7 +380,7 @@ def run_migrations_offline() -> None:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
target_metadata=target_metadata,
literal_binds=True,
version_table_schema=schema,
include_schemas=True,
@@ -421,7 +421,7 @@ def run_migrations_offline() -> None:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
target_metadata=target_metadata,
literal_binds=True,
version_table_schema=schema,
include_schemas=True,
@@ -464,7 +464,7 @@ def run_migrations_online() -> None:
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
target_metadata=target_metadata,
version_table_schema=schema_name,
include_schemas=True,
compare_type=True,

View File

@@ -25,7 +25,7 @@ def upgrade() -> None:
# Use batch mode to modify the enum type
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column( # type: ignore[attr-defined]
batch_op.alter_column(
"role",
type_=sa.Enum(
"BASIC",
@@ -71,7 +71,7 @@ def downgrade() -> None:
op.drop_column("user__user_group", "is_curator")
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column( # type: ignore[attr-defined]
batch_op.alter_column(
"role",
type_=sa.Enum(
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20

View File

@@ -63,7 +63,7 @@ def upgrade() -> None:
"time_created",
existing_type=postgresql.TIMESTAMP(timezone=True),
nullable=False,
existing_server_default=sa.text("now()"), # type: ignore
existing_server_default=sa.text("now()"),
)
op.alter_column(
"index_attempt",
@@ -85,7 +85,7 @@ def downgrade() -> None:
"time_created",
existing_type=postgresql.TIMESTAMP(timezone=True),
nullable=True,
existing_server_default=sa.text("now()"), # type: ignore
existing_server_default=sa.text("now()"),
)
op.drop_index(op.f("ix_accesstoken_created_at"), table_name="accesstoken")
op.drop_table("accesstoken")

View File

@@ -19,7 +19,7 @@ depends_on: None = None
def upgrade() -> None:
sequence = Sequence("connector_credential_pair_id_seq")
op.execute(CreateSequence(sequence)) # type: ignore
op.execute(CreateSequence(sequence))
op.add_column(
"connector_credential_pair",
sa.Column(

View File

@@ -49,7 +49,7 @@ def run_migrations_offline() -> None:
url = build_connection_string()
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
@@ -61,7 +61,7 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore[arg-type]
target_metadata=target_metadata,
)
with context.begin_transaction():

View File

@@ -96,11 +96,14 @@ def get_model_app() -> FastAPI:
title="Onyx Model Server", version=__version__, lifespan=lifespan
)
if SENTRY_DSN:
from onyx.configs.sentry import _add_instance_tags
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
before_send=_add_instance_tags,
)
logger.info("Sentry initialized")
else:

View File

@@ -10,6 +10,7 @@ from celery import bootsteps # type: ignore
from celery import Task
from celery.app import trace
from celery.exceptions import WorkerShutdown
from celery.signals import before_task_publish
from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.states import READY_STATES
@@ -62,11 +63,14 @@ logger = setup_logger()
task_logger = get_task_logger(__name__)
if SENTRY_DSN:
from onyx.configs.sentry import _add_instance_tags
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.1,
release=__version__,
before_send=_add_instance_tags,
)
logger.info("Sentry initialized")
else:
@@ -94,6 +98,17 @@ class TenantAwareTask(Task):
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
@before_task_publish.connect
def on_before_task_publish(
headers: dict[str, Any] | None = None,
**kwargs: Any, # noqa: ARG001
) -> None:
"""Stamp the current wall-clock time into the task message headers so that
workers can compute queue wait time (time between publish and execution)."""
if headers is not None:
headers["enqueued_at"] = time.time()
@task_prerun.connect
def on_task_prerun(
sender: Any | None = None, # noqa: ARG001
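
The body of on_task_prerun is truncated above. As a rough illustration only, not code from this PR, a prerun handler could turn the enqueued_at header stamped by on_before_task_publish into a queue-wait measurement along these lines; where Celery surfaces custom message headers on the task request varies by protocol and version, so the lookup below is deliberately defensive and should be treated as an assumption.

import time
from typing import Any

from celery.signals import task_prerun


@task_prerun.connect
def record_queue_wait_time(
    sender: Any | None = None,
    task: Any | None = None,
    **kwargs: Any,
) -> None:
    # `enqueued_at` was stamped into the message headers at publish time by
    # on_before_task_publish above. The attribute paths checked here are
    # assumptions; adjust to wherever your Celery version exposes custom headers.
    request = getattr(task, "request", None)
    headers = getattr(request, "headers", None) or {}
    enqueued_at = headers.get("enqueued_at") if isinstance(headers, dict) else None
    if enqueued_at is None:
        enqueued_at = getattr(request, "enqueued_at", None)
    if enqueued_at is None:
        return
    wait_seconds = time.time() - float(enqueued_at)
    # In the actual change (#10161) this presumably feeds a Prometheus metric;
    # a log line stands in for that here.
    print(f"queue wait for {sender!r}: {wait_seconds:.3f}s")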

View File

@@ -59,6 +59,11 @@ from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
from onyx.server.metrics.deletion_metrics import inc_deletion_started
from onyx.server.metrics.deletion_metrics import observe_deletion_taskset_duration
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
@@ -300,6 +305,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
recent_index_attempts
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
):
inc_deletion_blocked(tenant_id, "indexing")
raise TaskDependencyError(
"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
@@ -307,11 +313,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
)
if redis_connector.prune.fenced:
inc_deletion_blocked(tenant_id, "pruning")
raise TaskDependencyError(
f"Connector deletion - Delayed (pruning in progress): cc_pair={cc_pair_id}"
)
if redis_connector.permissions.fenced:
inc_deletion_blocked(tenant_id, "permissions")
raise TaskDependencyError(
f"Connector deletion - Delayed (permissions in progress): cc_pair={cc_pair_id}"
)
@@ -359,6 +367,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
# set this only after all tasks have been added
fence_payload.num_tasks = tasks_generated
redis_connector.delete.set_fence(fence_payload)
inc_deletion_started(tenant_id)
return tasks_generated
@@ -527,6 +536,12 @@ def monitor_connector_deletion_taskset(
num_docs_synced=fence_data.num_tasks,
)
duration = (
datetime.now(timezone.utc) - fence_data.submitted
).total_seconds()
observe_deletion_taskset_duration(tenant_id, "success", duration)
inc_deletion_completed(tenant_id, "success")
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
@@ -545,6 +560,11 @@ def monitor_connector_deletion_taskset(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
)
duration = (
datetime.now(timezone.utc) - fence_data.submitted
).total_seconds()
observe_deletion_taskset_duration(tenant_id, "failure", duration)
inc_deletion_completed(tenant_id, "failure")
raise e
task_logger.info(
@@ -721,5 +741,6 @@ def validate_connector_deletion_fence(
f"fence={fence_key}"
)
inc_deletion_fence_reset(tenant_id)
redis_connector.delete.reset()
return
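
The deletion metric helpers imported at the top of this file (inc_deletion_blocked, inc_deletion_started, inc_deletion_completed, inc_deletion_fence_reset, observe_deletion_taskset_duration) are not included in this compare. A minimal sketch of how such helpers could be backed by prometheus_client; the metric names and label sets below are assumptions, not the module's actual contents.

from prometheus_client import Counter, Histogram

_DELETION_BLOCKED = Counter(
    "connector_deletion_blocked_total",
    "Deletions delayed by a conflicting operation",
    ["tenant_id", "reason"],  # reason: "indexing" | "pruning" | "permissions"
)
_DELETION_STARTED = Counter(
    "connector_deletion_started_total", "Deletion fences set", ["tenant_id"]
)
_DELETION_COMPLETED = Counter(
    "connector_deletion_completed_total",
    "Deletion tasksets finished",
    ["tenant_id", "status"],  # status: "success" | "failure"
)
_DELETION_FENCE_RESET = Counter(
    "connector_deletion_fence_reset_total", "Invalid deletion fences reset", ["tenant_id"]
)
_DELETION_DURATION = Histogram(
    "connector_deletion_taskset_duration_seconds",
    "Time from fence submission to taskset completion",
    ["tenant_id", "status"],
)


def inc_deletion_blocked(tenant_id: str, reason: str) -> None:
    _DELETION_BLOCKED.labels(tenant_id, reason).inc()


def inc_deletion_started(tenant_id: str) -> None:
    _DELETION_STARTED.labels(tenant_id).inc()


def inc_deletion_completed(tenant_id: str, status: str) -> None:
    _DELETION_COMPLETED.labels(tenant_id, status).inc()


def inc_deletion_fence_reset(tenant_id: str) -> None:
    _DELETION_FENCE_RESET.labels(tenant_id).inc()


def observe_deletion_taskset_duration(tenant_id: str, status: str, duration: float) -> None:
    _DELETION_DURATION.labels(tenant_id, status).observe(duration)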

View File

@@ -135,10 +135,13 @@ def _docfetching_task(
# Since connector_indexing_proxy_task spawns a new process using this function as
# the entrypoint, we init Sentry here.
if SENTRY_DSN:
from onyx.configs.sentry import _add_instance_tags
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
release=__version__,
before_send=_add_instance_tags,
)
logger.info("Sentry initialized")
else:

View File

@@ -3,6 +3,7 @@ import os
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -50,6 +51,7 @@ from onyx.configs.constants import AuthType
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NotificationType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -85,6 +87,8 @@ from onyx.db.indexing_coordination import INDEXING_PROGRESS_TIMEOUT_HOURS
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.models import SearchSettings
from onyx.db.notification import create_notification
from onyx.db.notification import get_notifications
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.swap_index import check_and_perform_index_swap
@@ -105,6 +109,9 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_utils import is_fence
from onyx.server.metrics.connector_health_metrics import on_connector_error_state_change
from onyx.server.metrics.connector_health_metrics import on_connector_indexing_success
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
@@ -400,7 +407,6 @@ def check_indexing_completion(
tenant_id: str,
task: Task,
) -> None:
logger.info(
f"Checking for indexing completion: attempt={index_attempt_id} tenant={tenant_id}"
)
@@ -521,13 +527,25 @@ def check_indexing_completion(
# Update CC pair status if successful
cc_pair = get_connector_credential_pair_from_id(
db_session, attempt.connector_credential_pair_id
db_session,
attempt.connector_credential_pair_id,
eager_load_connector=True,
)
if cc_pair is None:
raise RuntimeError(
f"CC pair {attempt.connector_credential_pair_id} not found in database"
)
source = cc_pair.connector.source.value
connector_name = cc_pair.connector.name or f"cc_pair_{cc_pair.id}"
on_index_attempt_status_change(
tenant_id=tenant_id,
source=source,
cc_pair_id=cc_pair.id,
connector_name=connector_name,
status=attempt.status.value,
)
if attempt.status.is_successful():
# NOTE: we define the last successful index time as the time the last successful
# attempt finished. This is distinct from the poll_range_end of the last successful
@@ -548,10 +566,41 @@ def check_indexing_completion(
event=MilestoneRecordType.CONNECTOR_SUCCEEDED,
)
on_connector_indexing_success(
tenant_id=tenant_id,
source=source,
cc_pair_id=cc_pair.id,
connector_name=connector_name,
docs_indexed=attempt.new_docs_indexed or 0,
success_timestamp=attempt.time_updated.timestamp(),
)
# Clear repeated error state on success
if cc_pair.in_repeated_error_state:
cc_pair.in_repeated_error_state = False
# Delete any existing error notification for this CC pair so a
# fresh one is created if the connector fails again later.
for notif in get_notifications(
user=None,
db_session=db_session,
notif_type=NotificationType.CONNECTOR_REPEATED_ERRORS,
include_dismissed=True,
):
if (
notif.additional_data
and notif.additional_data.get("cc_pair_id") == cc_pair.id
):
db_session.delete(notif)
db_session.commit()
on_connector_error_state_change(
tenant_id=tenant_id,
source=source,
cc_pair_id=cc_pair.id,
connector_name=connector_name,
in_error=False,
)
if attempt.status == IndexingStatus.SUCCESS:
logger.info(
@@ -608,6 +657,27 @@ def active_indexing_attempt(
return bool(active_indexing_attempt)
@dataclass
class _KickoffResult:
"""Tracks diagnostic counts from a _kickoff_indexing_tasks run."""
created: int = 0
skipped_active: int = 0
skipped_not_found: int = 0
skipped_not_indexable: int = 0
failed_to_create: int = 0
@property
def evaluated(self) -> int:
return (
self.created
+ self.skipped_active
+ self.skipped_not_found
+ self.skipped_not_indexable
+ self.failed_to_create
)
def _kickoff_indexing_tasks(
celery_app: Celery,
db_session: Session,
@@ -617,12 +687,12 @@ def _kickoff_indexing_tasks(
redis_client: Redis,
lock_beat: RedisLock,
tenant_id: str,
) -> int:
) -> _KickoffResult:
"""Kick off indexing tasks for the given cc_pair_ids and search_settings.
Returns the number of tasks successfully created.
Returns a _KickoffResult with diagnostic counts.
"""
tasks_created = 0
result = _KickoffResult()
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
@@ -633,6 +703,7 @@ def _kickoff_indexing_tasks(
search_settings_id=search_settings.id,
db_session=db_session,
):
result.skipped_active += 1
continue
cc_pair = get_connector_credential_pair_from_id(
@@ -643,6 +714,7 @@ def _kickoff_indexing_tasks(
task_logger.warning(
f"_kickoff_indexing_tasks - CC pair not found: cc_pair={cc_pair_id}"
)
result.skipped_not_found += 1
continue
# Heavyweight check after fetching cc pair
@@ -657,6 +729,7 @@ def _kickoff_indexing_tasks(
f"search_settings={search_settings.id}, "
f"secondary_index_building={secondary_index_building}"
)
result.skipped_not_indexable += 1
continue
task_logger.debug(
@@ -696,13 +769,14 @@ def _kickoff_indexing_tasks(
task_logger.info(
f"Connector indexing queued: index_attempt={attempt_id} cc_pair={cc_pair.id} search_settings={search_settings.id}"
)
tasks_created += 1
result.created += 1
else:
task_logger.error(
f"Failed to create indexing task: cc_pair={cc_pair.id} search_settings={search_settings.id}"
)
result.failed_to_create += 1
return tasks_created
return result
@shared_task(
@@ -728,6 +802,8 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
task_logger.warning("check_for_indexing - Starting")
tasks_created = 0
primary_result = _KickoffResult()
secondary_result: _KickoffResult | None = None
locked = False
redis_client = get_redis_client()
redis_client_replica = get_redis_replica_client()
@@ -848,6 +924,43 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
cc_pair_id=cc_pair_id,
in_repeated_error_state=True,
)
error_connector_name = (
cc_pair.connector.name or f"cc_pair_{cc_pair.id}"
)
on_connector_error_state_change(
tenant_id=tenant_id,
source=cc_pair.connector.source.value,
cc_pair_id=cc_pair_id,
connector_name=error_connector_name,
in_error=True,
)
connector_name = (
cc_pair.name
or cc_pair.connector.name
or f"CC pair {cc_pair.id}"
)
source = cc_pair.connector.source.value
connector_url = f"/admin/connector/{cc_pair.id}"
create_notification(
user_id=None,
notif_type=NotificationType.CONNECTOR_REPEATED_ERRORS,
db_session=db_session,
title=f"Connector '{connector_name}' has entered repeated error state",
description=(
f"The {source} connector has failed repeatedly and "
f"has been flagged. View indexing history in the "
f"Advanced section: {connector_url}"
),
additional_data={"cc_pair_id": cc_pair.id},
)
task_logger.error(
f"Connector entered repeated error state: "
f"cc_pair={cc_pair.id} "
f"connector={cc_pair.connector.name} "
f"source={source}"
)
# When entering repeated error state, also pause the connector
# to prevent continued indexing retry attempts burning through embedding credits.
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
@@ -863,7 +976,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# Heavy check, should_index(), is called in _kickoff_indexing_tasks
with get_session_with_current_tenant() as db_session:
# Primary first
tasks_created += _kickoff_indexing_tasks(
primary_result = _kickoff_indexing_tasks(
celery_app=self.app,
db_session=db_session,
search_settings=current_search_settings,
@@ -873,6 +986,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
lock_beat=lock_beat,
tenant_id=tenant_id,
)
tasks_created += primary_result.created
# Secondary indexing (only if secondary search settings exist and switchover_type is not INSTANT)
if (
@@ -880,7 +994,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
and secondary_search_settings.switchover_type != SwitchoverType.INSTANT
and secondary_cc_pair_ids
):
tasks_created += _kickoff_indexing_tasks(
secondary_result = _kickoff_indexing_tasks(
celery_app=self.app,
db_session=db_session,
search_settings=secondary_search_settings,
@@ -890,6 +1004,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
lock_beat=lock_beat,
tenant_id=tenant_id,
)
tasks_created += secondary_result.created
elif (
secondary_search_settings
and secondary_search_settings.switchover_type == SwitchoverType.INSTANT
@@ -1002,7 +1117,26 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
redis_lock_dump(lock_beat, redis_client)
time_elapsed = time.monotonic() - time_start
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
task_logger.info(
f"check_for_indexing finished: "
f"elapsed={time_elapsed:.2f}s "
f"primary=[evaluated={primary_result.evaluated} "
f"created={primary_result.created} "
f"skipped_active={primary_result.skipped_active} "
f"skipped_not_found={primary_result.skipped_not_found} "
f"skipped_not_indexable={primary_result.skipped_not_indexable} "
f"failed={primary_result.failed_to_create}]"
+ (
f" secondary=[evaluated={secondary_result.evaluated} "
f"created={secondary_result.created} "
f"skipped_active={secondary_result.skipped_active} "
f"skipped_not_found={secondary_result.skipped_not_found} "
f"skipped_not_indexable={secondary_result.skipped_not_indexable} "
f"failed={secondary_result.failed_to_create}]"
if secondary_result
else ""
)
)
return tasks_created

View File

@@ -172,6 +172,10 @@ def migrate_chunks_from_vespa_to_opensearch_task(
search_settings = get_current_search_settings(db_session)
indexing_setting = IndexingSetting.from_db_model(search_settings)
task_logger.debug(
"Verified tenant info, migration record, and search settings."
)
# 2.e. Build sanitized to original doc ID mapping to check for
# conflicts in the event we sanitize a doc ID to an
# already-existing doc ID.
@@ -325,6 +329,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
finally:
if lock.owned():
lock.release()
task_logger.debug("Released the OpenSearch migration lock.")
else:
task_logger.warning(
"The OpenSearch migration lock was not owned on completion of the migration task."

View File

@@ -51,7 +51,6 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
@@ -643,16 +642,6 @@ def connector_pruning_generator_task(
raw_id_to_parent=all_connector_doc_ids,
)
# Link hierarchy nodes to documents for sources where pages can be
# both hierarchy nodes AND documents (e.g. Notion, Confluence)
all_doc_id_list = list(all_connector_doc_ids.keys())
link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=all_doc_id_list,
source=source,
commit=True,
)
diff_start = time.monotonic()
try:
# a list of docs in our local index

View File

@@ -248,6 +248,7 @@ def document_by_cc_pair_cleanup_task(
),
)
mark_document_as_modified(document_id, db_session)
db_session.commit()
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)

View File

@@ -5,6 +5,7 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
import sentry_sdk
from celery import Celery
from sqlalchemy.orm import Session
@@ -68,6 +69,7 @@ from onyx.redis.redis_pool import get_redis_client
from onyx.server.features.build.indexing.persistent_document_writer import (
get_persistent_document_writer,
)
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
@@ -267,6 +269,14 @@ def run_docfetching_entrypoint(
)
credential_id = attempt.connector_credential_pair.credential_id
on_index_attempt_status_change(
tenant_id=tenant_id,
source=attempt.connector_credential_pair.connector.source.value,
cc_pair_id=connector_credential_pair_id,
connector_name=connector_name or f"cc_pair_{connector_credential_pair_id}",
status="in_progress",
)
logger.info(
f"Docfetching starting{tenant_str}: "
f"connector='{connector_name}' "
@@ -556,6 +566,27 @@ def connector_document_extraction(
# save record of any failures at the connector level
if failure is not None:
if failure.exception is not None:
with sentry_sdk.new_scope() as scope:
scope.set_tag("stage", "connector_fetch")
scope.set_tag("connector_source", db_connector.source.value)
scope.set_tag("cc_pair_id", str(cc_pair_id))
scope.set_tag("index_attempt_id", str(index_attempt_id))
scope.set_tag("tenant_id", tenant_id)
if failure.failed_document:
scope.set_tag(
"doc_id", failure.failed_document.document_id
)
if failure.failed_entity:
scope.set_tag(
"entity_id", failure.failed_entity.entity_id
)
scope.fingerprint = [
"connector-fetch-failure",
db_connector.source.value,
type(failure.exception).__name__,
]
sentry_sdk.capture_exception(failure.exception)
total_failures += 1
with get_session_with_current_tenant() as db_session:
create_index_attempt_error(

View File

@@ -364,7 +364,7 @@ def _get_or_extract_plaintext(
plaintext_io = file_store.read_file(plaintext_key, mode="b")
return plaintext_io.read().decode("utf-8")
except Exception:
logger.exception(f"Error when reading file, id={file_id}")
logger.info(f"Cache miss for file with id={file_id}")
# Cache miss — extract and store.
content_text = extract_fn()

View File

@@ -4,8 +4,6 @@ from collections.abc import Callable
from typing import Any
from typing import Literal
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import create_tool_call_failure_messages
from onyx.chat.citation_processor import CitationMapping
@@ -635,7 +633,6 @@ def run_llm_loop(
user_memory_context: UserMemoryContext | None,
llm: LLM,
token_counter: Callable[[str], int],
db_session: Session,
forced_tool_id: int | None = None,
user_identity: LLMUserIdentity | None = None,
chat_session_id: str | None = None,
@@ -1020,20 +1017,16 @@ def run_llm_loop(
persisted_memory_id: int | None = None
if user_memory_context and user_memory_context.user_id:
if tool_response.rich_response.index_to_replace is not None:
memory = update_memory_at_index(
persisted_memory_id = update_memory_at_index(
user_id=user_memory_context.user_id,
index=tool_response.rich_response.index_to_replace,
new_text=tool_response.rich_response.memory_text,
db_session=db_session,
)
persisted_memory_id = memory.id if memory else None
else:
memory = add_memory(
persisted_memory_id = add_memory(
user_id=user_memory_context.user_id,
memory_text=tool_response.rich_response.memory_text,
db_session=db_session,
)
persisted_memory_id = memory.id
operation: Literal["add", "update"] = (
"update"
if tool_response.rich_response.index_to_replace is not None

View File

@@ -67,7 +67,6 @@ from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import reserve_multi_model_message_ids
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
@@ -1006,93 +1005,86 @@ def _run_models(
model_llm = setup.llms[model_idx]
try:
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
# Do NOT write to the outer db_session (or any shared DB state) from here;
# all DB writes in this thread must go through thread_db_session.
with get_session_with_current_tenant() as thread_db_session:
thread_tool_dict = construct_tools(
persona=setup.persona,
db_session=thread_db_session,
emitter=model_emitter,
user=user,
llm=model_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=setup.new_msg_req.internal_search_filters,
project_id_filter=setup.search_params.project_id_filter,
persona_id_filter=setup.search_params.persona_id_filter,
bypass_acl=setup.bypass_acl,
slack_context=setup.slack_context,
enable_slack_search=_should_enable_slack_search(
setup.persona, setup.new_msg_req.internal_search_filters
),
# Each function opens short-lived DB sessions on demand.
# Do NOT pass a long-lived session here — it would hold a
# connection for the entire LLM loop (minutes), and cloud
# infrastructure may drop idle connections.
thread_tool_dict = construct_tools(
persona=setup.persona,
emitter=model_emitter,
user=user,
llm=model_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=setup.new_msg_req.internal_search_filters,
project_id_filter=setup.search_params.project_id_filter,
persona_id_filter=setup.search_params.persona_id_filter,
bypass_acl=setup.bypass_acl,
slack_context=setup.slack_context,
enable_slack_search=_should_enable_slack_search(
setup.persona, setup.new_msg_req.internal_search_filters
),
custom_tool_config=CustomToolConfig(
chat_session_id=setup.chat_session.id,
message_id=setup.user_message.id,
additional_headers=setup.custom_tool_additional_headers,
mcp_headers=setup.mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
user_file_ids=setup.available_files.user_file_ids,
chat_file_ids=setup.available_files.chat_file_ids,
),
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=setup.search_params.search_usage,
),
custom_tool_config=CustomToolConfig(
chat_session_id=setup.chat_session.id,
message_id=setup.user_message.id,
additional_headers=setup.custom_tool_additional_headers,
mcp_headers=setup.mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
user_file_ids=setup.available_files.user_file_ids,
chat_file_ids=setup.available_files.chat_file_ids,
),
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=setup.search_params.search_usage,
)
model_tools = [
tool for tool_list in thread_tool_dict.values() for tool in tool_list
]
if setup.forced_tool_id and setup.forced_tool_id not in {
tool.id for tool in model_tools
}:
raise ValueError(
f"Forced tool {setup.forced_tool_id} not found in tools"
)
model_tools = [
tool
for tool_list in thread_tool_dict.values()
for tool in tool_list
]
if setup.forced_tool_id and setup.forced_tool_id not in {
tool.id for tool in model_tools
}:
raise ValueError(
f"Forced tool {setup.forced_tool_id} not found in tools"
)
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
if n_models == 1 and setup.new_msg_req.deep_research:
if setup.chat_session.project_id:
raise RuntimeError(
"Deep research is not supported for projects"
)
run_deep_research_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
db_session=thread_db_session,
skip_clarification=setup.skip_clarification,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
all_injected_file_metadata=setup.all_injected_file_metadata,
)
else:
run_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
context_files=setup.extracted_context_files,
persona=setup.persona,
user_memory_context=setup.user_memory_context,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
db_session=thread_db_session,
forced_tool_id=setup.forced_tool_id,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
chat_files=setup.chat_files_for_tools,
include_citations=setup.new_msg_req.include_citations,
all_injected_file_metadata=setup.all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
if n_models == 1 and setup.new_msg_req.deep_research:
if setup.chat_session.project_id:
raise RuntimeError("Deep research is not supported for projects")
run_deep_research_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
skip_clarification=setup.skip_clarification,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
all_injected_file_metadata=setup.all_injected_file_metadata,
)
else:
run_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
context_files=setup.extracted_context_files,
persona=setup.persona,
user_memory_context=setup.user_memory_context,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
forced_tool_id=setup.forced_tool_id,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
chat_files=setup.chat_files_for_tools,
include_citations=setup.new_msg_req.include_citations,
all_injected_file_metadata=setup.all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
model_succeeded[model_idx] = True
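
The comment above replaces a single long-lived thread_db_session with sessions opened on demand inside each helper. As a generic illustration of that pattern using plain SQLAlchemy rather than Onyx's own session helpers; the DSN and table name below are placeholders, not the project's schema.

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine("postgresql+psycopg2://user:pass@relational_db/onyx")  # placeholder DSN
SessionLocal = sessionmaker(bind=engine)


def fetch_tool_ids_for_persona(persona_id: int) -> list[int]:
    # The connection is checked out only for the duration of this query, so no
    # idle connection is held open while the surrounding LLM loop streams for
    # minutes (which cloud infrastructure may otherwise drop).
    with SessionLocal() as session:
        rows = session.execute(
            text("SELECT tool_id FROM persona__tool WHERE persona_id = :pid"),  # illustrative table
            {"pid": persona_id},
        )
        return [row[0] for row in rows]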

View File

@@ -1125,6 +1125,32 @@ DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
# Number of pre-provisioned tenants to maintain
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
# Master switch for the tenant work-gating feature. Controls the `enabled`
# axis only — flipping this True puts the feature in shadow mode (compute
# the gate, log skip counts, but do not actually skip). The `enforce` axis
# is Redis-only with a hard-coded default of False, so this env flag alone
# cannot cause real tenants to be skipped. Default off.
ENABLE_TENANT_WORK_GATING = (
os.environ.get("ENABLE_TENANT_WORK_GATING", "").lower() == "true"
)
# Membership TTL for the `active_tenants` sorted set. Members older than this
# are treated as inactive by the gate read path. Must be > the full-fanout
# interval so self-healing re-adds a genuinely-working tenant before their
# membership expires. Default 30 min.
TENANT_WORK_GATING_TTL_SECONDS = int(
os.environ.get("TENANT_WORK_GATING_TTL_SECONDS", 30 * 60)
)
# Minimum wall-clock interval between full-fanout cycles. When this many
# seconds have elapsed since the last bypass, the generator ignores the gate
# on the next invocation and dispatches to every non-gated tenant, letting
# consumers re-populate the active set. Schedule-independent so beat drift
# or backlog can't make the self-heal bursty or sparse. Default 20 min.
TENANT_WORK_GATING_FULL_FANOUT_INTERVAL_SECONDS = int(
os.environ.get("TENANT_WORK_GATING_FULL_FANOUT_INTERVAL_SECONDS", 20 * 60)
)
# Image summarization configuration
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get(
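
These flags belong to the Redis Set infra work (1/3, #10209); the gate read path itself is not part of this compare. A hypothetical sketch of how the two axes and the active_tenants sorted set described above could combine on the dispatch side; the Redis key name for the enforce flag and the score semantics are assumptions.

import time

from redis import Redis

ENABLE_TENANT_WORK_GATING = True          # env-driven "enabled" axis (shadow mode)
TENANT_WORK_GATING_TTL_SECONDS = 30 * 60  # membership TTL from the config above


def should_dispatch(redis_client: Redis, tenant_id: str) -> bool:
    """Return False only when gating is enabled, enforcement is on, and the
    tenant has not been seen in the active set within the TTL window."""
    if not ENABLE_TENANT_WORK_GATING:
        return True
    # The "enforce" axis lives only in Redis and defaults to off (shadow mode);
    # the key name here is hypothetical.
    enforce = redis_client.get("tenant_work_gating:enforce") == b"true"
    # Assumes members of active_tenants are scored by their last-activity timestamp.
    score = redis_client.zscore("active_tenants", tenant_id)
    is_active = score is not None and (time.time() - score) < TENANT_WORK_GATING_TTL_SECONDS
    if enforce:
        return is_active
    # Shadow mode: compute the decision and (in real code) log or count the
    # would-be skips, but never actually skip the tenant.
    return True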

View File

@@ -283,6 +283,7 @@ class NotificationType(str, Enum):
RELEASE_NOTES = "release_notes"
ASSISTANT_FILES_READY = "assistant_files_ready"
FEATURE_ANNOUNCEMENT = "feature_announcement"
CONNECTOR_REPEATED_ERRORS = "connector_repeated_errors"
class BlobType(str, Enum):

View File

@@ -0,0 +1,48 @@
from typing import Any
from sentry_sdk.types import Event
from onyx.utils.logger import setup_logger
logger = setup_logger()
_instance_id_resolved = False
def _add_instance_tags(
event: Event,
hint: dict[str, Any], # noqa: ARG001
) -> Event | None:
"""Sentry before_send hook that lazily attaches instance identification tags.
On the first event, resolves the instance UUID from the KV store (requires DB)
and sets it as a global Sentry tag. Subsequent events pick it up automatically.
"""
global _instance_id_resolved
if _instance_id_resolved:
return event
try:
import sentry_sdk
from shared_configs.configs import MULTI_TENANT
if MULTI_TENANT:
instance_id = "multi-tenant-cloud"
else:
from onyx.utils.telemetry import get_or_generate_uuid
instance_id = get_or_generate_uuid()
sentry_sdk.set_tag("instance_id", instance_id)
# Also set on this event since set_tag won't retroactively apply
event.setdefault("tags", {})["instance_id"] = instance_id
# Only mark resolved after success — if DB wasn't ready, retry next event
_instance_id_resolved = True
except Exception:
logger.debug("Failed to resolve instance_id for Sentry tagging")
return event

View File

@@ -26,6 +26,10 @@ from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
process_onyx_metadata,
)
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
@@ -38,6 +42,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
@@ -451,6 +456,40 @@ class BlobStorageConnector(LoadConnector, PollConnector):
logger.exception(f"Error processing image {key}")
continue
# Handle tabular files (xlsx, csv, tsv) — produce one
# TabularSection per sheet (or per file for csv/tsv)
# instead of a flat TextSection.
if is_tabular_file(file_name):
try:
downloaded_file = self._download_object(key)
if downloaded_file is None:
continue
tabular_sections = tabular_file_to_sections(
BytesIO(downloaded_file),
file_name=file_name,
link=link,
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
sections=(
tabular_sections
if tabular_sections
else [TabularSection(link=link, text="")]
),
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
doc_updated_at=last_modified,
metadata={},
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
except Exception:
logger.exception(f"Error processing tabular file {key}")
continue
# Handle text and document files
try:
downloaded_file = self._download_object(key)
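
For reference, the dispatch above hinges on is_tabular_file from tabular_section_utils, which is not included in this compare. A plausible sketch of that check based only on the xlsx/csv/tsv list in the comment; the real helper may differ.

_TABULAR_EXTENSIONS = {".xlsx", ".csv", ".tsv"}


def is_tabular_file(file_name: str) -> bool:
    # Extension-based check covering the formats listed above.
    return any(file_name.lower().endswith(ext) for ext in _TABULAR_EXTENSIONS)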

View File

@@ -27,16 +27,19 @@ _STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
404: OnyxErrorCode.BAD_GATEWAY,
429: OnyxErrorCode.RATE_LIMITED,
}
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
"""Map an HTTP status code to the appropriate OnyxErrorCode.
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
Expects a >= 400 status code. Known codes (401, 403, 404) are
mapped to specific error codes; all other codes (unrecognised 4xx
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
Note: 429 is intentionally omitted — the rl_requests wrapper
handles rate limits transparently at the HTTP layer, so 429
responses never reach this function.
"""
if status_code in _STATUS_TO_ERROR_CODE:
return _STATUS_TO_ERROR_CODE[status_code]

View File

@@ -1,10 +1,9 @@
from datetime import datetime
from datetime import timezone
from enum import StrEnum
from typing import Any
from typing import cast
from typing import Literal
from typing import NoReturn
from typing import TypeAlias
from pydantic import BaseModel
from retry import retry
@@ -25,8 +24,11 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.error_handling.exceptions import OnyxError
@@ -47,10 +49,6 @@ def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
raise InsufficientPermissionsError(
"Canvas API token does not have sufficient permissions (HTTP 403)."
)
elif e.status_code == 429:
raise ConnectorValidationError(
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
)
elif e.status_code >= 500:
raise UnexpectedValidationError(
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
@@ -61,6 +59,60 @@ def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
)
class CanvasStage(StrEnum):
PAGES = "pages"
ASSIGNMENTS = "assignments"
ANNOUNCEMENTS = "announcements"
_STAGE_CONFIG: dict[CanvasStage, dict[str, Any]] = {
CanvasStage.PAGES: {
"endpoint": "courses/{course_id}/pages",
"params": {
"per_page": "100",
"include[]": "body",
"published": "true",
"sort": "updated_at",
"order": "desc",
},
},
CanvasStage.ASSIGNMENTS: {
"endpoint": "courses/{course_id}/assignments",
"params": {"per_page": "100", "published": "true"},
},
CanvasStage.ANNOUNCEMENTS: {
"endpoint": "announcements",
"params": {
"per_page": "100",
"context_codes[]": "course_{course_id}",
"active_only": "true",
},
},
}
def _parse_canvas_dt(timestamp_str: str) -> datetime:
"""Parse a Canvas ISO-8601 timestamp (e.g. '2025-06-15T12:00:00Z')
into a timezone-aware UTC datetime.
Canvas returns timestamps with a trailing 'Z' instead of '+00:00',
so we normalise before parsing.
"""
return datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")).astimezone(
timezone.utc
)
def _unix_to_canvas_time(epoch: float) -> str:
"""Convert a Unix timestamp to Canvas ISO-8601 format (e.g. '2025-06-15T12:00:00Z')."""
return datetime.fromtimestamp(epoch, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
def _in_time_window(timestamp_str: str, start: float, end: float) -> bool:
"""Check whether a Canvas ISO-8601 timestamp falls within (start, end]."""
return start < _parse_canvas_dt(timestamp_str).timestamp() <= end
class CanvasCourse(BaseModel):
id: int
name: str | None = None
@@ -145,9 +197,6 @@ class CanvasAnnouncement(BaseModel):
)
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
class CanvasConnectorCheckpoint(ConnectorCheckpoint):
"""Checkpoint state for resumable Canvas indexing.
@@ -165,15 +214,30 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
course_ids: list[int] = []
current_course_index: int = 0
stage: CanvasStage = "pages"
stage: CanvasStage = CanvasStage.PAGES
next_url: str | None = None
def advance_course(self) -> None:
"""Move to the next course and reset within-course state."""
self.current_course_index += 1
self.stage = "pages"
self.stage = CanvasStage.PAGES
self.next_url = None
def advance_stage(self) -> None:
"""Advance past the current stage.
Moves to the next stage within the same course, or to the next
course if the current stage is the last one. Resets next_url so
the next call starts fresh on the new stage.
"""
self.next_url = None
stages: list[CanvasStage] = list(CanvasStage)
next_idx = stages.index(self.stage) + 1
if next_idx < len(stages):
self.stage = stages[next_idx]
else:
self.advance_course()
class CanvasConnector(
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
@@ -295,13 +359,7 @@ class CanvasConnector(
if body_text:
text_parts.append(body_text)
doc_updated_at = (
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
timezone.utc
)
if page.updated_at
else None
)
doc_updated_at = _parse_canvas_dt(page.updated_at) if page.updated_at else None
document = self._build_document(
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
@@ -325,17 +383,11 @@ class CanvasConnector(
if desc_text:
text_parts.append(desc_text)
if assignment.due_at:
due_dt = datetime.fromisoformat(
assignment.due_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
due_dt = _parse_canvas_dt(assignment.due_at)
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
doc_updated_at = (
datetime.fromisoformat(
assignment.updated_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if assignment.updated_at
else None
_parse_canvas_dt(assignment.updated_at) if assignment.updated_at else None
)
document = self._build_document(
@@ -361,11 +413,7 @@ class CanvasConnector(
text_parts.append(msg_text)
doc_updated_at = (
datetime.fromisoformat(
announcement.posted_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if announcement.posted_at
else None
_parse_canvas_dt(announcement.posted_at) if announcement.posted_at else None
)
document = self._build_document(
@@ -400,6 +448,314 @@ class CanvasConnector(
self._canvas_client = client
return None
def _fetch_stage_page(
self,
next_url: str | None,
endpoint: str,
params: dict[str, Any],
) -> tuple[list[Any], str | None]:
"""Fetch one page of API results for the current stage.
Returns (items, next_url). All error handling is done by the
caller (_load_from_checkpoint).
"""
if next_url:
# Resuming mid-pagination: the next_url from Canvas's
# Link header already contains endpoint + query params.
response, result_next_url = self.canvas_client.get(full_url=next_url)
else:
# First request for this stage: build from endpoint + params.
response, result_next_url = self.canvas_client.get(
endpoint=endpoint, params=params
)
return response or [], result_next_url
def _process_items(
self,
response: list[Any],
stage: CanvasStage,
course_id: int,
start: float,
end: float,
include_permissions: bool,
) -> tuple[list[Document | ConnectorFailure], bool]:
"""Process a page of API results into documents.
Returns (docs, early_exit). early_exit is True when pages
(sorted desc by updated_at) hit an item older than start,
signaling that pagination should stop.
"""
results: list[Document | ConnectorFailure] = []
early_exit = False
for item in response:
try:
if stage == CanvasStage.PAGES:
page = CanvasPage.from_api(item, course_id=course_id)
if not page.updated_at:
continue
# Pages are sorted by updated_at desc — once we see
# an item at or before `start`, all remaining items
# on this and subsequent pages are older too.
if not _in_time_window(page.updated_at, start, end):
if _parse_canvas_dt(page.updated_at).timestamp() <= start:
early_exit = True
break
# ts > end: page is newer than our window, skip it
continue
doc = self._convert_page_to_document(page)
results.append(
self._maybe_attach_permissions(
doc, course_id, include_permissions
)
)
elif stage == CanvasStage.ASSIGNMENTS:
assignment = CanvasAssignment.from_api(item, course_id=course_id)
if not assignment.updated_at or not _in_time_window(
assignment.updated_at, start, end
):
continue
doc = self._convert_assignment_to_document(assignment)
results.append(
self._maybe_attach_permissions(
doc, course_id, include_permissions
)
)
elif stage == CanvasStage.ANNOUNCEMENTS:
announcement = CanvasAnnouncement.from_api(
item, course_id=course_id
)
if not announcement.posted_at:
logger.debug(
f"Skipping announcement {announcement.id} in "
f"course {course_id}: no posted_at"
)
continue
if not _in_time_window(announcement.posted_at, start, end):
continue
doc = self._convert_announcement_to_document(announcement)
results.append(
self._maybe_attach_permissions(
doc, course_id, include_permissions
)
)
except Exception as e:
item_id = item.get("id") or item.get("page_id", "unknown")
if stage == CanvasStage.PAGES:
doc_link = (
f"{self.canvas_base_url}/courses/{course_id}"
f"/pages/{item.get('url', '')}"
)
else:
doc_link = item.get("html_url", "")
results.append(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=f"canvas-{stage.removesuffix('s')}-{course_id}-{item_id}",
document_link=doc_link,
),
failure_message=f"Failed to process {stage.removesuffix('s')}: {e}",
exception=e,
)
)
return results, early_exit
def _maybe_attach_permissions(
self,
document: Document,
course_id: int,
include_permissions: bool,
) -> Document:
if include_permissions:
document.external_access = self._get_course_permissions(course_id)
return document
def _load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
include_permissions: bool = False,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
"""Shared implementation for load_from_checkpoint and load_from_checkpoint_with_perm_sync."""
new_checkpoint = checkpoint.model_copy(deep=True)
# First call: materialize the list of course IDs.
# On failure, let the exception propagate so the framework fails the
# attempt cleanly. Swallowing errors here would leave the checkpoint
# state unchanged and cause an infinite retry loop.
if not new_checkpoint.course_ids:
try:
courses = self._list_courses()
except OnyxError as e:
if e.status_code in (401, 403):
_handle_canvas_api_error(e) # NoReturn — always raises
raise
new_checkpoint.course_ids = [c.id for c in courses]
logger.info(f"Found {len(courses)} Canvas courses to process")
new_checkpoint.has_more = len(new_checkpoint.course_ids) > 0
return new_checkpoint
# All courses done.
if new_checkpoint.current_course_index >= len(new_checkpoint.course_ids):
new_checkpoint.has_more = False
return new_checkpoint
course_id = new_checkpoint.course_ids[new_checkpoint.current_course_index]
try:
stage = CanvasStage(new_checkpoint.stage)
except ValueError as e:
raise ValueError(
f"Invalid checkpoint stage: {new_checkpoint.stage!r}. "
f"Valid stages: {[s.value for s in CanvasStage]}"
) from e
# Build endpoint + params from the static template.
config = _STAGE_CONFIG[stage]
endpoint = config["endpoint"].format(course_id=course_id)
params = {k: v.format(course_id=course_id) for k, v in config["params"].items()}
# Only the announcements API supports server-side date filtering
# (start_date/end_date). Pages support server-side sorting
# (sort=updated_at desc) enabling early exit, but not date
# filtering. Assignments support neither. Pages and assignments are
# filtered client-side via _in_time_window after fetching.
if stage == CanvasStage.ANNOUNCEMENTS:
params["start_date"] = _unix_to_canvas_time(start)
params["end_date"] = _unix_to_canvas_time(end)
try:
response, result_next_url = self._fetch_stage_page(
next_url=new_checkpoint.next_url,
endpoint=endpoint,
params=params,
)
except OnyxError as oe:
# Security errors from _parse_next_link (host/scheme
# mismatch on pagination URLs) have no status code override
# and must not be silenced.
is_api_error = oe._status_code_override is not None
if not is_api_error:
raise
if oe.status_code in (401, 403):
_handle_canvas_api_error(oe) # NoReturn — always raises
# 404 means the course itself is gone or inaccessible. The
# other stages on this course will hit the same 404, so skip
# the whole course rather than burning API calls on each stage.
if oe.status_code == 404:
logger.warning(
f"Canvas course {course_id} not found while fetching "
f"{stage} (HTTP 404). Skipping course."
)
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=f"canvas-course-{course_id}",
),
failure_message=(f"Canvas course {course_id} not found: {oe}"),
exception=oe,
)
new_checkpoint.advance_course()
else:
logger.warning(
f"Failed to fetch {stage} for course {course_id}: {oe}. "
f"Skipping remainder of this stage."
)
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=f"canvas-{stage}-{course_id}",
),
failure_message=(
f"Failed to fetch {stage} for course {course_id}: {oe}"
),
exception=oe,
)
new_checkpoint.advance_stage()
new_checkpoint.has_more = new_checkpoint.current_course_index < len(
new_checkpoint.course_ids
)
return new_checkpoint
except Exception as e:
# Unknown error — skip the stage and try to continue.
logger.warning(
f"Failed to fetch {stage} for course {course_id}: {e}. "
f"Skipping remainder of this stage."
)
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=f"canvas-{stage}-{course_id}",
),
failure_message=(
f"Failed to fetch {stage} for course {course_id}: {e}"
),
exception=e,
)
new_checkpoint.advance_stage()
new_checkpoint.has_more = new_checkpoint.current_course_index < len(
new_checkpoint.course_ids
)
return new_checkpoint
# Process fetched items
results, early_exit = self._process_items(
response, stage, course_id, start, end, include_permissions
)
for result in results:
yield result
# If we hit an item older than our window (pages sorted desc),
# skip remaining pagination and advance to the next stage.
if early_exit:
result_next_url = None
# If there are more pages, save the cursor and return
if result_next_url:
new_checkpoint.next_url = result_next_url
else:
# Stage complete — advance to next stage (or next course if last).
new_checkpoint.advance_stage()
new_checkpoint.has_more = new_checkpoint.current_course_index < len(
new_checkpoint.course_ids
)
return new_checkpoint
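For orientation, a minimal sketch (not part of this diff) of how a checkpointed connector like this one is typically driven: build a dummy checkpoint, then call load_from_checkpoint repeatedly while has_more is set, draining each generator and capturing the checkpoint it returns. Only CanvasConnector's build_dummy_checkpoint, load_from_checkpoint, and has_more come from the diff; the driver function itself is illustrative.
def drain_canvas_connector(connector, start: float, end: float) -> list:
    """Hedged sketch: drive the checkpoint protocol shown above to completion."""
    checkpoint = connector.build_dummy_checkpoint()
    results = []  # Documents and ConnectorFailures, in yield order
    while checkpoint.has_more:
        gen = connector.load_from_checkpoint(start, end, checkpoint)
        while True:
            try:
                results.append(next(gen))
            except StopIteration as stop:
                # A CheckpointOutput generator returns the next checkpoint.
                checkpoint = stop.value
                break
    return results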
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=False
)
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
"""Load documents from checkpoint with permission information included."""
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=True
)
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
return CanvasConnectorCheckpoint(has_more=True)
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
return CanvasConnectorCheckpoint.model_validate_json(checkpoint_json)
@override
def validate_connector_settings(self) -> None:
"""Validate Canvas connector settings by testing API access."""
@@ -415,38 +771,6 @@ class CanvasConnector(
f"Unexpected error during Canvas settings validation: {exc}"
)
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def retrieve_all_slim_docs_perm_sync(
self,

View File

@@ -171,7 +171,10 @@ class ClickupConnector(LoadConnector, PollConnector):
document.metadata[extra_field] = task[extra_field]
if self.retrieve_task_comments:
document.sections.extend(self._get_task_comments(task["id"]))
document.sections = [
*document.sections,
*self._get_task_comments(task["id"]),
]
doc_batch.append(document)

View File

@@ -0,0 +1,69 @@
import csv
import io
from typing import IO
from onyx.connectors.models import TabularSection
from onyx.file_processing.extract_file_text import file_io_to_text
from onyx.file_processing.extract_file_text import xlsx_sheet_extraction
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.utils.logger import setup_logger
logger = setup_logger()
def is_tabular_file(file_name: str) -> bool:
lowered = file_name.lower()
return any(lowered.endswith(ext) for ext in OnyxFileExtensions.TABULAR_EXTENSIONS)
def _tsv_to_csv(tsv_text: str) -> str:
"""Re-serialize tab-separated text as CSV so downstream parsers that
assume the default Excel dialect read the columns correctly."""
out = io.StringIO()
csv.writer(out, lineterminator="\n").writerows(
csv.reader(io.StringIO(tsv_text), dialect="excel-tab")
)
return out.getvalue().rstrip("\n")
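A quick illustration (not from the diff) of why the re-serialization matters: once rewritten with the default Excel dialect, a field that itself contains a comma is quoted, so downstream csv.reader calls keep the column count intact.
# Hedged example of _tsv_to_csv on a value containing a comma.
tsv = "name\tnote\nwidget\tsmall, blue"
assert _tsv_to_csv(tsv) == 'name,note\nwidget,"small, blue"'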
def tabular_file_to_sections(
file: IO[bytes],
file_name: str,
link: str = "",
) -> list[TabularSection]:
"""Convert a tabular file into one or more TabularSections.
- .xlsx → one TabularSection per non-empty sheet.
- .csv / .tsv → a single TabularSection containing the full decoded
file.
Returns an empty list when the file yields no extractable content.
"""
lowered = file_name.lower()
if lowered.endswith(".xlsx"):
return [
TabularSection(
link=link or file_name,
text=csv_text,
heading=f"{file_name} :: {sheet_title}",
)
for csv_text, sheet_title in xlsx_sheet_extraction(
file, file_name=file_name
)
]
if not lowered.endswith((".csv", ".tsv")):
raise ValueError(f"{file_name!r} is not a tabular file")
try:
text = file_io_to_text(file).strip()
except Exception:
logger.exception(f"Failure decoding {file_name}")
raise
if not text:
return []
if lowered.endswith(".tsv"):
text = _tsv_to_csv(text)
return [TabularSection(link=link or file_name, text=text)]
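A hedged usage sketch (payload and link invented, and assuming file_io_to_text decodes the bytes as UTF-8) of feeding an in-memory CSV through tabular_file_to_sections:
# Hedged usage sketch: a small CSV becomes a single TabularSection.
import io

csv_bytes = b"id,status\n1,open\n2,closed\n"
sections = tabular_file_to_sections(
    io.BytesIO(csv_bytes),
    file_name="tickets.csv",
    link="https://example.invalid/tickets.csv",
)
assert len(sections) == 1          # .csv/.tsv yield exactly one section
assert sections[0].text.startswith("id,status")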

View File

@@ -15,6 +15,10 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
)
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rate_limit_builder
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rl_requests
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.drupal_wiki.models import DrupalWikiCheckpoint
from onyx.connectors.drupal_wiki.models import DrupalWikiPage
from onyx.connectors.drupal_wiki.models import DrupalWikiPageResponse
@@ -33,6 +37,7 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
@@ -213,7 +218,7 @@ class DrupalWikiConnector(
attachment: dict[str, Any],
page_id: int,
download_url: str,
) -> tuple[list[TextSection | ImageSection], str | None]:
) -> tuple[list[TextSection | ImageSection | TabularSection], str | None]:
"""
Process a single attachment and return generated sections.
@@ -226,7 +231,7 @@ class DrupalWikiConnector(
Tuple of (sections, error_message). If error_message is not None, the
sections list should be treated as invalid.
"""
sections: list[TextSection | ImageSection] = []
sections: list[TextSection | ImageSection | TabularSection] = []
try:
if not self._validate_attachment_filetype(attachment):
@@ -273,6 +278,25 @@ class DrupalWikiConnector(
return sections, None
# Tabular attachments (xlsx, csv, tsv) — produce
# TabularSections instead of a flat TextSection.
if is_tabular_file(file_name):
try:
sections.extend(
tabular_file_to_sections(
BytesIO(raw_bytes),
file_name=file_name,
link=download_url,
)
)
except Exception:
logger.exception(
f"Failed to extract tabular sections from {file_name}"
)
if not sections:
return [], f"No content extracted from tabular file {file_name}"
return sections, None
image_counter = 0
def _store_embedded_image(image_data: bytes, image_name: str) -> None:
@@ -497,7 +521,7 @@ class DrupalWikiConnector(
page_url = build_drupal_wiki_document_id(self.base_url, page.id)
# Create sections with just the page content
sections: list[TextSection | ImageSection] = [
sections: list[TextSection | ImageSection | TabularSection] = [
TextSection(text=text_content, link=page_url)
]

View File

@@ -2,6 +2,7 @@ import json
import os
from datetime import datetime
from datetime import timezone
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import IO
@@ -12,11 +13,16 @@ from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
process_onyx_metadata,
)
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
@@ -179,8 +185,32 @@ def _process_file(
link = onyx_metadata.link or link
# Build sections: first the text as a single Section
sections: list[TextSection | ImageSection] = []
if extraction_result.text_content.strip():
sections: list[TextSection | ImageSection | TabularSection] = []
if is_tabular_file(file_name):
# Produce TabularSections
lowered_name = file_name.lower()
if lowered_name.endswith(".xlsx"):
file.seek(0)
tabular_source: IO[bytes] = file
else:
tabular_source = BytesIO(
extraction_result.text_content.encode("utf-8", errors="replace")
)
try:
sections.extend(
tabular_file_to_sections(
file=tabular_source,
file_name=file_name,
link=link or "",
)
)
except Exception as e:
logger.error(f"Failed to process tabular file {file_name}: {e}")
return []
if not sections:
logger.warning(f"No content extracted from tabular file {file_name}")
return []
elif extraction_result.text_content.strip():
logger.debug(f"Creating TextSection for {file_name} with link: {link}")
sections.append(
TextSection(link=link, text=extraction_result.text_content.strip())

View File

@@ -22,6 +22,7 @@ from typing_extensions import override
from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.constants import DocumentSource
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
@@ -35,10 +36,16 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import IndexingHeartbeatInterface
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
@@ -427,7 +434,11 @@ def make_cursor_url_callback(
return cursor_url_callback
class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint]):
class GithubConnector(
CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint],
SlimConnector,
SlimConnectorWithPermSync,
):
def __init__(
self,
repo_owner: str,
@@ -559,6 +570,7 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
start: datetime | None = None,
end: datetime | None = None,
include_permissions: bool = False,
is_slim: bool = False,
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
@@ -614,36 +626,46 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
for pr in pr_batch:
num_prs += 1
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) < start
):
done_with_prs = True
break
# Skip PRs updated after the end date
if (
end is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
try:
yield _convert_pr_to_document(
cast(PullRequest, pr), repo_external_access
if is_slim:
yield Document(
id=pr.html_url,
sections=[],
external_access=repo_external_access,
source=DocumentSource.GITHUB,
semantic_identifier="",
metadata={},
)
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(pr.id), document_link=pr.html_url
),
failure_message=error_msg,
exception=e,
)
continue
else:
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) < start
):
done_with_prs = True
break
# Skip PRs updated after the end date
if (
end is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
try:
yield _convert_pr_to_document(
cast(PullRequest, pr), repo_external_access
)
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(pr.id), document_link=pr.html_url
),
failure_message=error_msg,
exception=e,
)
continue
# If we reach this point with a cursor url in the checkpoint, we were using
# the fallback cursor-based pagination strategy. That strategy tries to get all
@@ -689,38 +711,47 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
for issue in issue_batch:
num_issues += 1
issue = cast(Issue, issue)
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) < start
):
done_with_issues = True
break
# Skip PRs updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
try:
yield _convert_issue_to_document(issue, repo_external_access)
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(issue.id),
document_link=issue.html_url,
),
failure_message=error_msg,
exception=e,
if is_slim:
yield Document(
id=issue.html_url,
sections=[],
external_access=repo_external_access,
source=DocumentSource.GITHUB,
semantic_identifier="",
metadata={},
)
continue
else:
# we iterate backwards in time, so at this point we stop processing issues
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) < start
):
done_with_issues = True
break
# Skip issues updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
try:
yield _convert_issue_to_document(issue, repo_external_access)
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(issue.id),
document_link=issue.html_url,
),
failure_message=error_msg,
exception=e,
)
continue
logger.info(f"Fetched {num_issues} issues for repo: {repo.name}")
# if we found any issues on the page, and we're not done, return the checkpoint.
@@ -803,6 +834,60 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
start, end, checkpoint, include_permissions=True
)
def _retrieve_slim_docs(
self,
include_permissions: bool,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
"""Iterate all PRs and issues across all configured repos as SlimDocuments.
Drives _fetch_from_github in a checkpoint loop — each call processes one
page and returns an updated checkpoint. CheckpointOutputWrapper handles
draining the generator and extracting the returned checkpoint. Rate
limiting and pagination are handled centrally by _fetch_from_github via
_get_batch_rate_limited.
"""
checkpoint = self.build_dummy_checkpoint()
while checkpoint.has_more:
batch: list[SlimDocument | HierarchyNode] = []
gen = self._fetch_from_github(
checkpoint, include_permissions=include_permissions, is_slim=True
)
wrapper: CheckpointOutputWrapper[GithubConnectorCheckpoint] = (
CheckpointOutputWrapper()
)
for document, _, _, next_checkpoint in wrapper(gen):
if document is not None:
batch.append(
SlimDocument(
id=document.id, external_access=document.external_access
)
)
if next_checkpoint is not None:
checkpoint = next_checkpoint
if batch:
yield batch
if callback and callback.should_stop():
raise RuntimeError("github_slim_docs: Stop signal detected")
@override
def retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
return self._retrieve_slim_docs(include_permissions=False, callback=callback)
@override
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
return self._retrieve_slim_docs(include_permissions=True, callback=callback)
def validate_connector_settings(self) -> None:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")

View File

@@ -75,6 +75,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import NormalizationResult
from onyx.connectors.interfaces import Resolver
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -207,6 +208,7 @@ class DriveIdStatus(Enum):
class GoogleDriveConnector(
SlimConnector,
SlimConnectorWithPermSync,
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
Resolver,
@@ -1754,6 +1756,7 @@ class GoogleDriveConnector(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
files_batch: list[RetrievedDriveFile] = []
slim_batch: list[SlimDocument | HierarchyNode] = []
@@ -1763,9 +1766,13 @@ class GoogleDriveConnector(
nonlocal files_batch, slim_batch
# Get new ancestor hierarchy nodes first
permission_sync_context = PermissionSyncContext(
primary_admin_email=self.primary_admin_email,
google_domain=self.google_domain,
permission_sync_context = (
PermissionSyncContext(
primary_admin_email=self.primary_admin_email,
google_domain=self.google_domain,
)
if include_permissions
else None
)
new_ancestors = self._get_new_ancestors_for_files(
files=files_batch,
@@ -1779,10 +1786,7 @@ class GoogleDriveConnector(
if doc := build_slim_document(
self.creds,
file.drive_file,
PermissionSyncContext(
primary_admin_email=self.primary_admin_email,
google_domain=self.google_domain,
),
permission_sync_context,
retriever_email=file.user_email,
):
slim_batch.append(doc)
@@ -1822,11 +1826,12 @@ class GoogleDriveConnector(
if files_batch:
yield _yield_slim_batch()
def retrieve_all_slim_docs_perm_sync(
def _retrieve_all_slim_docs_impl(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
try:
checkpoint = self.build_dummy_checkpoint()
@@ -1836,13 +1841,34 @@ class GoogleDriveConnector(
start=start,
end=end,
callback=callback,
include_permissions=include_permissions,
)
logger.info("Drive perm sync: Slim doc retrieval complete")
logger.info("Drive slim doc retrieval complete")
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
raise
@override
def retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
return self._retrieve_all_slim_docs_impl(
start=start, end=end, callback=callback, include_permissions=False
)
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
return self._retrieve_all_slim_docs_impl(
start=start, end=end, callback=callback, include_permissions=True
)
def validate_connector_settings(self) -> None:
if self._creds is None:

View File

@@ -13,6 +13,10 @@ from pydantic import BaseModel
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from onyx.connectors.google_drive.models import GDriveMimeType
@@ -28,15 +32,16 @@ from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import pptx_to_text
from onyx.file_processing.extract_file_text import read_docx_file
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import xlsx_to_text
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.file_types import SPREADSHEET_MIME_TYPE
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
@@ -289,7 +294,7 @@ def _download_and_extract_sections_basic(
service: GoogleDriveService,
allow_images: bool,
size_threshold: int,
) -> list[TextSection | ImageSection]:
) -> list[TextSection | ImageSection | TabularSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
file_name = file["name"]
@@ -308,7 +313,7 @@ def _download_and_extract_sections_basic(
return []
# Store images for later processing
sections: list[TextSection | ImageSection] = []
sections: list[TextSection | ImageSection | TabularSection] = []
try:
section, embedded_id = store_image_and_create_section(
image_data=response_call(),
@@ -323,10 +328,9 @@ def _download_and_extract_sections_basic(
logger.error(f"Failed to process image {file_name}: {e}")
return sections
# For Google Docs, Sheets, and Slides, export as plain text
# For Google Docs, Sheets, and Slides, export via the Drive API
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
@@ -335,6 +339,17 @@ def _download_and_extract_sections_basic(
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
if export_mime_type in OnyxMimeTypes.TABULAR_MIME_TYPES:
# Synthesize an extension on the filename
ext = ".xlsx" if export_mime_type == SPREADSHEET_MIME_TYPE else ".csv"
return list(
tabular_file_to_sections(
io.BytesIO(response),
file_name=f"{file_name}{ext}",
link=link,
)
)
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
@@ -356,9 +371,15 @@ def _download_and_extract_sections_basic(
elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
or is_tabular_file(file_name)
):
text = xlsx_to_text(io.BytesIO(response_call()), file_name=file_name)
return [TextSection(link=link, text=text)] if text else []
return list(
tabular_file_to_sections(
io.BytesIO(response_call()),
file_name=file_name,
link=link,
)
)
elif (
mime_type
@@ -369,7 +390,7 @@ def _download_and_extract_sections_basic(
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
pdf_sections: list[TextSection | ImageSection] = [
pdf_sections: list[TextSection | ImageSection | TabularSection] = [
TextSection(link=link, text=text)
]
@@ -410,8 +431,9 @@ def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int:
def align_basic_advanced(
basic_sections: list[TextSection | ImageSection], adv_sections: list[TextSection]
) -> list[TextSection | ImageSection]:
basic_sections: list[TextSection | ImageSection | TabularSection],
adv_sections: list[TextSection],
) -> list[TextSection | ImageSection | TabularSection]:
"""Align the basic sections with the advanced sections.
In particular, the basic sections contain all content of the file,
including smart chips like dates and doc links. The advanced sections
@@ -428,7 +450,7 @@ def align_basic_advanced(
basic_full_text = "".join(
[section.text for section in basic_sections if isinstance(section, TextSection)]
)
new_sections: list[TextSection | ImageSection] = []
new_sections: list[TextSection | ImageSection | TabularSection] = []
heading_start = 0
for adv_ind in range(1, len(adv_sections)):
heading = adv_sections[adv_ind].text.split(HEADING_DELIMITER)[0]
@@ -599,7 +621,7 @@ def _convert_drive_item_to_document(
"""
Main entry point for converting a Google Drive file => Document object.
"""
sections: list[TextSection | ImageSection] = []
sections: list[TextSection | ImageSection | TabularSection] = []
# Only construct these services when needed
def _get_drive_service() -> GoogleDriveService:
@@ -639,7 +661,9 @@ def _convert_drive_item_to_document(
doc_id=file.get("id", ""),
)
if doc_sections:
sections = cast(list[TextSection | ImageSection], doc_sections)
sections = cast(
list[TextSection | ImageSection | TabularSection], doc_sections
)
if any(SMART_CHIP_CHAR in section.text for section in doc_sections):
logger.debug(
f"found smart chips in {file.get('name')}, aligning with basic sections"

View File

@@ -1,4 +1,5 @@
import json
from typing import Any
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import ParseResult
@@ -53,6 +54,21 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def _load_google_json(raw: object) -> dict[str, Any]:
"""Accept both the current (dict) and legacy (JSON string) KV payload shapes.
Payloads written before the fix for serializing Google credentials into
``EncryptedJson`` columns are stored as JSON strings; new writes store dicts.
Once every install has re-uploaded their Google credentials the legacy
``str`` branch can be removed.
"""
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
return json.loads(raw)
raise ValueError(f"Unexpected Google credential payload type: {type(raw)!r}")
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
@@ -162,12 +178,13 @@ def build_service_account_creds(
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
credential_json = _load_google_json(
get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)
)
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
credential_json = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=GOOGLE_SCOPES[source],
@@ -188,12 +205,12 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
creds = _load_google_json(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
creds = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
return GoogleAppCredentials(**json.loads(creds_str))
return GoogleAppCredentials(**creds)
def upsert_google_app_cred(
@@ -201,10 +218,14 @@ def upsert_google_app_cred(
) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().store(
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
KV_GOOGLE_DRIVE_CRED_KEY,
app_credentials.model_dump(mode="json"),
encrypt=True,
)
elif source == DocumentSource.GMAIL:
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
get_kv_store().store(
KV_GMAIL_CRED_KEY, app_credentials.model_dump(mode="json"), encrypt=True
)
else:
raise ValueError(f"Unsupported source: {source}")
@@ -220,12 +241,14 @@ def delete_google_app_cred(source: DocumentSource) -> None:
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
creds = _load_google_json(
get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
)
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
creds = _load_google_json(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
return GoogleServiceAccountKey(**json.loads(creds_str))
return GoogleServiceAccountKey(**creds)
def upsert_service_account_key(
@@ -234,12 +257,14 @@ def upsert_service_account_key(
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().store(
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
service_account_key.json(),
service_account_key.model_dump(mode="json"),
encrypt=True,
)
elif source == DocumentSource.GMAIL:
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
KV_GMAIL_SERVICE_ACCOUNT_KEY,
service_account_key.model_dump(mode="json"),
encrypt=True,
)
else:
raise ValueError(f"Unsupported source: {source}")

View File

@@ -123,6 +123,9 @@ class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
raise NotImplementedError

View File

@@ -1,4 +1,5 @@
import sys
from collections.abc import Sequence
from datetime import datetime
from enum import Enum
from typing import Any
@@ -39,6 +40,7 @@ class SectionType(str, Enum):
TEXT = "text"
IMAGE = "image"
TABULAR = "tabular"
class Section(BaseModel):
@@ -48,6 +50,7 @@ class Section(BaseModel):
link: str | None = None
text: str | None = None
image_file_id: str | None = None
heading: str | None = None
class TextSection(Section):
@@ -70,6 +73,18 @@ class ImageSection(Section):
return sys.getsizeof(self.image_file_id) + sys.getsizeof(self.link)
class TabularSection(Section):
"""Section containing tabular data (csv/tsv content, or one sheet of
an xlsx workbook rendered as CSV)."""
type: Literal[SectionType.TABULAR] = SectionType.TABULAR
text: str # CSV representation in a string
link: str
def __sizeof__(self) -> int:
return sys.getsizeof(self.text) + sys.getsizeof(self.link)
class BasicExpertInfo(BaseModel):
"""Basic Information for the owner of a document, any of the fields can be left as None
Display fallback goes as follows:
@@ -171,7 +186,7 @@ class DocumentBase(BaseModel):
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
id: str | None = None
sections: list[TextSection | ImageSection]
sections: Sequence[TextSection | ImageSection | TabularSection]
source: DocumentSource | None = None
semantic_identifier: str # displayed in the UI as the main identifier for the doc
# TODO(andrei): Ideally we could improve this to where each value is just a
@@ -381,12 +396,9 @@ class IndexingDocument(Document):
)
else:
section_len = sum(
(
len(section.text)
if isinstance(section, TextSection) and section.text is not None
else 0
)
len(section.text) if section.text is not None else 0
for section in self.sections
if isinstance(section, (TextSection, TabularSection))
)
return title_len + section_len

View File

@@ -41,6 +41,10 @@ from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.app_configs import SHAREPOINT_CONNECTOR_SIZE_THRESHOLD
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
@@ -60,6 +64,7 @@ from onyx.connectors.models import ExternalAccess
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.connectors.sharepoint.connector_utils import get_sharepoint_external_access
from onyx.db.enums import HierarchyNodeType
@@ -586,7 +591,7 @@ def _convert_driveitem_to_document_with_permissions(
driveitem, f"Failed to download via graph api: {e}", e
)
sections: list[TextSection | ImageSection] = []
sections: list[TextSection | ImageSection | TabularSection] = []
file_ext = get_file_ext(driveitem.name)
if not content_bytes:
@@ -602,6 +607,19 @@ def _convert_driveitem_to_document_with_permissions(
)
image_section.link = driveitem.web_url
sections.append(image_section)
elif is_tabular_file(driveitem.name):
try:
sections.extend(
tabular_file_to_sections(
file=io.BytesIO(content_bytes),
file_name=driveitem.name,
link=driveitem.web_url or "",
)
)
except Exception as e:
logger.warning(
f"Failed to extract tabular sections for '{driveitem.name}': {e}"
)
else:
def _store_embedded_image(img_data: bytes, img_name: str) -> None:

View File

@@ -750,31 +750,3 @@ def resync_cc_pair(
)
db_session.commit()
# ── Metrics query helpers ──────────────────────────────────────────────
def get_connector_health_for_metrics(
db_session: Session,
) -> list: # Returns list of Row tuples
"""Return connector health data for Prometheus metrics.
Each row is (cc_pair_id, status, in_repeated_error_state,
last_successful_index_time, name, source).
"""
return (
db_session.query(
ConnectorCredentialPair.id,
ConnectorCredentialPair.status,
ConnectorCredentialPair.in_repeated_error_state,
ConnectorCredentialPair.last_successful_index_time,
ConnectorCredentialPair.name,
Connector.source,
)
.join(
Connector,
ConnectorCredentialPair.connector_id == Connector.id,
)
.all()
)

View File

@@ -335,6 +335,7 @@ def update_document_set(
"Cannot update document set while it is syncing. Please wait for it to finish syncing, and then try again."
)
document_set_row.name = document_set_update_request.name
document_set_row.description = document_set_update_request.description
if not DISABLE_VECTOR_DB:
document_set_row.is_up_to_date = False

View File

@@ -11,6 +11,7 @@ from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DB_READONLY_PASSWORD
@@ -346,6 +347,25 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _safe_close_session(session: Session) -> None:
"""Close a session, catching connection-closed errors during cleanup.
Long-running operations (e.g. multi-model LLM loops) can hold a session
open for minutes. If the underlying connection is dropped by cloud
infrastructure (load-balancer timeouts, PgBouncer, idle-in-transaction
timeouts, etc.), the implicit rollback in Session.close() raises
OperationalError or InterfaceError. Since the work is already complete,
we log and move on — SQLAlchemy internally invalidates the connection
for pool recycling.
"""
try:
session.close()
except DBAPIError:
logger.warning(
"DB connection lost during session cleanup — the connection will be invalidated and recycled by the pool."
)
@contextmanager
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
"""
@@ -358,8 +378,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
# no need to use the schema translation map for self-hosted + default schema
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
with Session(bind=engine, expire_on_commit=False) as session:
session = Session(bind=engine, expire_on_commit=False)
try:
yield session
finally:
_safe_close_session(session)
return
# Create connection with schema translation to handle querying the right schema
@@ -367,8 +390,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
with engine.connect().execution_options(
schema_translate_map=schema_translate_map
) as connection:
with Session(bind=connection, expire_on_commit=False) as session:
session = Session(bind=connection, expire_on_commit=False)
try:
yield session
finally:
_safe_close_session(session)
def get_session() -> Generator[Session, None, None]:

View File

@@ -2,8 +2,6 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import NamedTuple
from typing import TYPE_CHECKING
from typing import TypeVarTuple
from sqlalchemy import and_
@@ -30,17 +28,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
if TYPE_CHECKING:
from onyx.configs.constants import DocumentSource
# from sqlalchemy.sql.selectable import Select
# Comment out unused imports that cause mypy errors
# from onyx.auth.models import UserRole
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
# from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier
# from onyx.db.engine import async_query_for_dms
logger = setup_logger()
@@ -981,104 +968,48 @@ def get_index_attempt_errors_for_cc_pair(
return list(db_session.scalars(stmt).all())
# ── Metrics query helpers ──────────────────────────────────────────────
class ActiveIndexAttemptMetric(NamedTuple):
"""Row returned by get_active_index_attempts_for_metrics."""
status: IndexingStatus
source: "DocumentSource"
cc_pair_id: int
cc_pair_name: str | None
attempt_count: int
def get_active_index_attempts_for_metrics(
def get_index_attempt_errors_across_connectors(
db_session: Session,
) -> list[ActiveIndexAttemptMetric]:
"""Return non-terminal index attempts grouped by status, source, and connector.
cc_pair_id: int | None = None,
error_type: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
unresolved_only: bool = True,
page: int = 0,
page_size: int = 25,
) -> tuple[list[IndexAttemptError], int]:
"""Query index attempt errors across all connectors with optional filters.
Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count).
Returns (errors, total_count) for pagination.
"""
from onyx.db.models import Connector
stmt = select(IndexAttemptError)
count_stmt = select(func.count()).select_from(IndexAttemptError)
terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
rows = (
db_session.query(
IndexAttempt.status,
Connector.source,
ConnectorCredentialPair.id,
ConnectorCredentialPair.name,
func.count(),
if cc_pair_id is not None:
stmt = stmt.where(IndexAttemptError.connector_credential_pair_id == cc_pair_id)
count_stmt = count_stmt.where(
IndexAttemptError.connector_credential_pair_id == cc_pair_id
)
.join(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.join(
Connector,
ConnectorCredentialPair.connector_id == Connector.id,
)
.filter(IndexAttempt.status.notin_(terminal_statuses))
.group_by(
IndexAttempt.status,
Connector.source,
ConnectorCredentialPair.id,
ConnectorCredentialPair.name,
)
.all()
)
return [ActiveIndexAttemptMetric(*row) for row in rows]
if error_type is not None:
stmt = stmt.where(IndexAttemptError.error_type == error_type)
count_stmt = count_stmt.where(IndexAttemptError.error_type == error_type)
def get_failed_attempt_counts_by_cc_pair(
db_session: Session,
since: datetime | None = None,
) -> dict[int, int]:
"""Return {cc_pair_id: failed_attempt_count} for all connectors.
if unresolved_only:
stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
count_stmt = count_stmt.where(IndexAttemptError.is_resolved.is_(False))
When ``since`` is provided, only attempts created after that timestamp
are counted. Defaults to the last 90 days to avoid unbounded historical
aggregation.
"""
if since is None:
since = datetime.now(timezone.utc) - timedelta(days=90)
if start_time is not None:
stmt = stmt.where(IndexAttemptError.time_created >= start_time)
count_stmt = count_stmt.where(IndexAttemptError.time_created >= start_time)
rows = (
db_session.query(
IndexAttempt.connector_credential_pair_id,
func.count(),
)
.filter(IndexAttempt.status == IndexingStatus.FAILED)
.filter(IndexAttempt.time_created >= since)
.group_by(IndexAttempt.connector_credential_pair_id)
.all()
)
return {cc_id: count for cc_id, count in rows}
if end_time is not None:
stmt = stmt.where(IndexAttemptError.time_created <= end_time)
count_stmt = count_stmt.where(IndexAttemptError.time_created <= end_time)
stmt = stmt.order_by(desc(IndexAttemptError.time_created))
stmt = stmt.offset(page * page_size).limit(page_size)
def get_docs_indexed_by_cc_pair(
db_session: Session,
since: datetime | None = None,
) -> dict[int, int]:
"""Return {cc_pair_id: total_new_docs_indexed} across successful attempts.
Only counts attempts with status SUCCESS to avoid inflating counts with
partial results from failed attempts. When ``since`` is provided, only
attempts created after that timestamp are included.
"""
if since is None:
since = datetime.now(timezone.utc) - timedelta(days=90)
query = (
db_session.query(
IndexAttempt.connector_credential_pair_id,
func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)),
)
.filter(IndexAttempt.status == IndexingStatus.SUCCESS)
.filter(IndexAttempt.time_created >= since)
.group_by(IndexAttempt.connector_credential_pair_id)
)
rows = query.all()
return {cc_id: int(total or 0) for cc_id, total in rows}
total = db_session.scalar(count_stmt) or 0
errors = list(db_session.scalars(stmt).all())
return errors, total

View File

@@ -5,6 +5,7 @@ from pydantic import ConfigDict
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
from onyx.db.models import Memory
from onyx.db.models import User
@@ -83,47 +84,51 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
def add_memory(
user_id: UUID,
memory_text: str,
db_session: Session,
) -> Memory:
db_session: Session | None = None,
) -> int:
"""Insert a new Memory row for the given user.
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
one (lowest id) is deleted before inserting the new one.
Returns the id of the newly created Memory row.
"""
existing = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
with get_session_with_current_tenant_if_none(db_session) as db_session:
existing = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
if len(existing) >= MAX_MEMORIES_PER_USER:
db_session.delete(existing[0])
if len(existing) >= MAX_MEMORIES_PER_USER:
db_session.delete(existing[0])
memory = Memory(
user_id=user_id,
memory_text=memory_text,
)
db_session.add(memory)
db_session.commit()
return memory
memory = Memory(
user_id=user_id,
memory_text=memory_text,
)
db_session.add(memory)
db_session.commit()
return memory.id
def update_memory_at_index(
user_id: UUID,
index: int,
new_text: str,
db_session: Session,
) -> Memory | None:
db_session: Session | None = None,
) -> int | None:
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
Returns the updated Memory row, or None if the index is out of range.
Returns the id of the updated Memory row, or None if the index is out of range.
"""
memory_rows = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
with get_session_with_current_tenant_if_none(db_session) as db_session:
memory_rows = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
if index < 0 or index >= len(memory_rows):
return None
if index < 0 or index >= len(memory_rows):
return None
memory = memory_rows[index]
memory.memory_text = new_text
db_session.commit()
return memory
memory = memory_rows[index]
memory.memory_text = new_text
db_session.commit()
return memory.id
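A hedged usage sketch (ids invented, and assuming the user row already exists) of the new session-optional call pattern, where omitting db_session lets the helper open its own tenant session:
# Hedged usage sketch for the session-optional memory helpers.
from uuid import UUID

user_id = UUID("00000000-0000-0000-0000-000000000000")  # placeholder user id

# No db_session passed: the helper opens and commits its own tenant session
# and returns just the new row id.
memory_id = add_memory(user_id=user_id, memory_text="Prefers metric units")

# Returns the updated row id, or None when the 0-based index is out of range.
updated_id = update_memory_at_index(user_id=user_id, index=0, new_text="Prefers SI units")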

View File

@@ -7,8 +7,6 @@ import time
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.citation_processor import DynamicCitationProcessor
@@ -22,6 +20,7 @@ from onyx.chat.models import LlmStepResult
from onyx.chat.models import ToolCallSimple
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
from onyx.configs.constants import MessageType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tools import get_tool_by_name
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
@@ -184,6 +183,14 @@ def generate_final_report(
return has_reasoned
def _get_research_agent_tool_id() -> int:
with get_session_with_current_tenant() as db_session:
return get_tool_by_name(
tool_name=RESEARCH_AGENT_TOOL_NAME,
db_session=db_session,
).id
@log_function_time(print_only=True)
def run_deep_research_llm_loop(
emitter: Emitter,
@@ -193,7 +200,6 @@ def run_deep_research_llm_loop(
custom_agent_prompt: str | None, # noqa: ARG001
llm: LLM,
token_counter: Callable[[str], int],
db_session: Session,
skip_clarification: bool = False,
user_identity: LLMUserIdentity | None = None,
chat_session_id: str | None = None,
@@ -717,6 +723,7 @@ def run_deep_research_llm_loop(
simple_chat_history.append(assistant_with_tools)
# Now add TOOL_CALL_RESPONSE messages and tool call info for each result
research_agent_tool_id = _get_research_agent_tool_id()
for tab_index, report in enumerate(
research_results.intermediate_reports
):
@@ -737,10 +744,7 @@ def run_deep_research_llm_loop(
tab_index=tab_index,
tool_name=current_tool_call.tool_name,
tool_call_id=current_tool_call.tool_call_id,
tool_id=get_tool_by_name(
tool_name=RESEARCH_AGENT_TOOL_NAME,
db_session=db_session,
).id,
tool_id=research_agent_tool_id,
reasoning_tokens=llm_step_result.reasoning
or most_recent_reasoning,
tool_call_arguments=current_tool_call.tool_args,

View File

@@ -379,13 +379,25 @@ def _worksheet_to_matrix(
worksheet: Worksheet,
) -> list[list[str]]:
"""
Converts a singular worksheet to a matrix of values
Converts a singular worksheet to a matrix of values.
Rows are padded to a uniform width. In openpyxl's read_only mode,
iter_rows can yield rows of differing lengths (trailing empty cells
are sometimes omitted), and downstream column cleanup assumes a
rectangular matrix.
"""
rows: list[list[str]] = []
max_len = 0
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
row = ["" if cell is None else str(cell) for cell in worksheet_row]
if len(row) > max_len:
max_len = len(row)
rows.append(row)
for row in rows:
if len(row) < max_len:
row.extend([""] * (max_len - len(row)))
return rows
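The padding described above, in miniature (standalone sketch, not from the diff):
# Hedged mini-example: pad ragged rows to a rectangular matrix.
ragged = [["a", "b", "c"], ["d"]]   # read_only mode may omit trailing empties
width = max(len(row) for row in ragged)
padded = [row + [""] * (width - len(row)) for row in ragged]
assert padded == [["a", "b", "c"], ["d", "", ""]]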
@@ -463,29 +475,13 @@ def _remove_empty_runs(
return result
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
# TODO: switch back to this approach in a few months when markitdown
# fixes their handling of excel files
def xlsx_sheet_extraction(file: IO[Any], file_name: str = "") -> list[tuple[str, str]]:
"""
Converts each sheet in the excel file to a csv condensed string.
Returns a string and the worksheet title for each worksheet
# md = get_markitdown_converter()
# stream_info = StreamInfo(
# mimetype=SPREADSHEET_MIME_TYPE, filename=file_name or None, extension=".xlsx"
# )
# try:
# workbook = md.convert(to_bytesio(file), stream_info=stream_info)
# except (
# BadZipFile,
# ValueError,
# FileConversionException,
# UnsupportedFormatException,
# ) as e:
# error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
# if file_name.startswith("~"):
# logger.debug(error_str + " (this is expected for files with ~)")
# else:
# logger.warning(error_str)
# return ""
# return workbook.markdown
Returns a list of (csv_text, sheet_title) tuples, one per non-empty sheet.
"""
try:
workbook = openpyxl.load_workbook(file, read_only=True)
except BadZipFile as e:
@@ -494,23 +490,30 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
logger.debug(error_str + " (this is expected for files with ~)")
else:
logger.warning(error_str)
return ""
return []
except Exception as e:
if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS):
logger.error(
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
)
return ""
return []
raise
text_content = []
sheets: list[tuple[str, str]] = []
for sheet in workbook.worksheets:
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
buf = io.StringIO()
writer = csv.writer(buf, lineterminator="\n")
writer.writerows(sheet_matrix)
text_content.append(buf.getvalue().rstrip("\n"))
return TEXT_SECTION_SEPARATOR.join(text_content)
csv_text = buf.getvalue().rstrip("\n")
if csv_text.strip():
sheets.append((csv_text, sheet.title))
return sheets
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
sheets = xlsx_sheet_extraction(file, file_name)
return TEXT_SECTION_SEPARATOR.join(csv_text for csv_text, _title in sheets)
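A hedged usage sketch (file path invented) contrasting the new per-sheet API with the flattened wrapper:
# Hedged usage sketch: per-sheet extraction vs. the joined string.
with open("workbook.xlsx", "rb") as f:   # placeholder path
    for csv_text, sheet_title in xlsx_sheet_extraction(f, file_name="workbook.xlsx"):
        # One (csv_text, sheet_title) pair per non-empty sheet; callers like
        # tabular_file_to_sections turn each pair into its own TabularSection.
        print(sheet_title, csv_text.splitlines()[0])

with open("workbook.xlsx", "rb") as f:
    flat_text = xlsx_to_text(f, file_name="workbook.xlsx")  # all sheets joined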
def eml_to_text(file: IO[Any]) -> str:

View File

@@ -7,6 +7,7 @@ from onyx.indexing.chunking.image_section_chunker import ImageChunker
from onyx.indexing.chunking.section_chunker import AccumulatorState
from onyx.indexing.chunking.section_chunker import ChunkPayload
from onyx.indexing.chunking.section_chunker import SectionChunker
from onyx.indexing.chunking.tabular_section_chunker import TabularChunker
from onyx.indexing.chunking.text_section_chunker import TextChunker
from onyx.indexing.models import DocAwareChunk
from onyx.natural_language_processing.utils import BaseTokenizer
@@ -38,6 +39,7 @@ class DocumentChunker:
chunk_splitter=chunk_splitter,
),
SectionType.IMAGE: ImageChunker(),
SectionType.TABULAR: TabularChunker(tokenizer=tokenizer),
}
def chunk(
@@ -99,7 +101,9 @@ class DocumentChunker:
payloads.extend(result.payloads)
accumulator = result.accumulator
# Final flush — any leftover buffered text becomes one last payload.
payloads.extend(accumulator.flush_to_list())
return payloads
def _select_chunker(self, section: Section) -> SectionChunker:

View File

@@ -0,0 +1,272 @@
import csv
import io
from collections.abc import Iterable
from pydantic import BaseModel
from onyx.connectors.models import Section
from onyx.indexing.chunking.section_chunker import AccumulatorState
from onyx.indexing.chunking.section_chunker import ChunkPayload
from onyx.indexing.chunking.section_chunker import SectionChunker
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import split_text_by_tokens
from onyx.utils.logger import setup_logger
logger = setup_logger()
COLUMNS_MARKER = "Columns:"
FIELD_VALUE_SEPARATOR = ", "
ROW_JOIN = "\n"
NEWLINE_TOKENS = 1
class _ParsedRow(BaseModel):
header: list[str]
row: list[str]
class _TokenizedText(BaseModel):
text: str
token_count: int
def format_row(header: list[str], row: list[str]) -> str:
"""
A header-row combination is formatted like this:
field1=value1, field2=value2, field3=value3
"""
pairs = _row_to_pairs(header, row)
formatted = FIELD_VALUE_SEPARATOR.join(f"{h}={v}" for h, v in pairs)
return formatted
def format_columns_header(headers: list[str]) -> str:
"""
Format the column header line. Underscored headers get a
space-substituted friendly alias in parens.
Example:
headers = ["id", "MTTR_hours"]
=> "Columns: id, MTTR_hours (MTTR hours)"
"""
parts: list[str] = []
for header in headers:
friendly = header
if "_" in header:
friendly = f'{header} ({header.replace("_", " ")})'
parts.append(friendly)
return f"{COLUMNS_MARKER} " + FIELD_VALUE_SEPARATOR.join(parts)
def parse_section(section: Section) -> list[_ParsedRow]:
"""Parse CSV into headers + rows. First non-empty row is the header;
blank rows are skipped."""
section_text = section.text or ""
if not section_text.strip():
return []
reader = csv.reader(io.StringIO(section_text))
non_empty_rows = [row for row in reader if any(cell.strip() for cell in row)]
if not non_empty_rows:
return []
header, *data_rows = non_empty_rows
return [_ParsedRow(header=header, row=row) for row in data_rows]
def _row_to_pairs(headers: list[str], row: list[str]) -> list[tuple[str, str]]:
return [(h, v) for h, v in zip(headers, row) if v.strip()]
def pack_chunk(chunk: str, new_row: str) -> str:
return chunk + "\n" + new_row
def _split_row_by_pairs(
pairs: list[tuple[str, str]],
tokenizer: BaseTokenizer,
max_tokens: int,
) -> list[_TokenizedText]:
"""Greedily pack pairs into max-sized pieces. Any single pair that
itself exceeds ``max_tokens`` is token-split at id boundaries.
No headers."""
separator_tokens = count_tokens(FIELD_VALUE_SEPARATOR, tokenizer)
pieces: list[_TokenizedText] = []
current_parts: list[str] = []
current_tokens = 0
for pair in pairs:
pair_str = f"{pair[0]}={pair[1]}"
pair_tokens = count_tokens(pair_str, tokenizer)
increment = pair_tokens if not current_parts else separator_tokens + pair_tokens
if current_tokens + increment <= max_tokens:
current_parts.append(pair_str)
current_tokens += increment
continue
if current_parts:
pieces.append(
_TokenizedText(
text=FIELD_VALUE_SEPARATOR.join(current_parts),
token_count=current_tokens,
)
)
current_parts = []
current_tokens = 0
if pair_tokens > max_tokens:
for split_text in split_text_by_tokens(pair_str, tokenizer, max_tokens):
pieces.append(
_TokenizedText(
text=split_text,
token_count=count_tokens(split_text, tokenizer),
)
)
else:
current_parts = [pair_str]
current_tokens = pair_tokens
if current_parts:
pieces.append(
_TokenizedText(
text=FIELD_VALUE_SEPARATOR.join(current_parts),
token_count=current_tokens,
)
)
return pieces
def _build_chunk_from_scratch(
pairs: list[tuple[str, str]],
formatted_row: str,
row_tokens: int,
column_header: str,
column_header_tokens: int,
sheet_header: str,
sheet_header_tokens: int,
tokenizer: BaseTokenizer,
max_tokens: int,
) -> list[_TokenizedText]:
# 1. Row alone is too large — split by pairs, no headers.
if row_tokens > max_tokens:
return _split_row_by_pairs(pairs, tokenizer, max_tokens)
chunk = formatted_row
chunk_tokens = row_tokens
# 2. Attempt to add column header
candidate_tokens = column_header_tokens + NEWLINE_TOKENS + chunk_tokens
if candidate_tokens <= max_tokens:
chunk = column_header + ROW_JOIN + chunk
chunk_tokens = candidate_tokens
# 3. Attempt to add sheet header
if sheet_header:
candidate_tokens = sheet_header_tokens + NEWLINE_TOKENS + chunk_tokens
if candidate_tokens <= max_tokens:
chunk = sheet_header + ROW_JOIN + chunk
chunk_tokens = candidate_tokens
return [_TokenizedText(text=chunk, token_count=chunk_tokens)]
def parse_to_chunks(
rows: Iterable[_ParsedRow],
sheet_header: str,
tokenizer: BaseTokenizer,
max_tokens: int,
) -> list[str]:
rows_list = list(rows)
if not rows_list:
return []
column_header = format_columns_header(rows_list[0].header)
column_header_tokens = count_tokens(column_header, tokenizer)
sheet_header_tokens = count_tokens(sheet_header, tokenizer) if sheet_header else 0
chunks: list[str] = []
current_chunk = ""
current_chunk_tokens = 0
for row in rows_list:
pairs: list[tuple[str, str]] = _row_to_pairs(row.header, row.row)
formatted = format_row(row.header, row.row)
row_tokens = count_tokens(formatted, tokenizer)
if current_chunk:
# Attempt to pack it in (additive approximation)
if current_chunk_tokens + NEWLINE_TOKENS + row_tokens <= max_tokens:
current_chunk = pack_chunk(current_chunk, formatted)
current_chunk_tokens += NEWLINE_TOKENS + row_tokens
continue
# Doesn't fit — flush and start new
chunks.append(current_chunk)
current_chunk = ""
current_chunk_tokens = 0
# Build chunk from scratch
for piece in _build_chunk_from_scratch(
pairs=pairs,
formatted_row=formatted,
row_tokens=row_tokens,
column_header=column_header,
column_header_tokens=column_header_tokens,
sheet_header=sheet_header,
sheet_header_tokens=sheet_header_tokens,
tokenizer=tokenizer,
max_tokens=max_tokens,
):
if current_chunk:
chunks.append(current_chunk)
current_chunk = piece.text
current_chunk_tokens = piece.token_count
# Flush remaining
if current_chunk:
chunks.append(current_chunk)
return chunks
class TabularChunker(SectionChunker):
def __init__(self, tokenizer: BaseTokenizer) -> None:
self.tokenizer = tokenizer
def chunk_section(
self,
section: Section,
accumulator: AccumulatorState,
content_token_limit: int,
) -> SectionChunkerOutput:
payloads = accumulator.flush_to_list()
parsed_rows = parse_section(section)
if not parsed_rows:
logger.warning(
f"TabularChunker: skipping unparseable section (link={section.link})"
)
return SectionChunkerOutput(
payloads=payloads, accumulator=AccumulatorState()
)
sheet_header = section.heading or ""
chunk_texts = parse_to_chunks(
rows=parsed_rows,
sheet_header=sheet_header,
tokenizer=self.tokenizer,
max_tokens=content_token_limit,
)
for i, text in enumerate(chunk_texts):
payloads.append(
ChunkPayload(
text=text,
links={0: section.link or ""},
is_continuation=(i > 0),
)
)
return SectionChunkerOutput(payloads=payloads, accumulator=AccumulatorState())

View File
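
To make the field-value formatting concrete, a small illustrative example of the two helpers defined above; the column and row values are made up, but the module path matches the import added to DocumentChunker.

from onyx.indexing.chunking.tabular_section_chunker import format_columns_header
from onyx.indexing.chunking.tabular_section_chunker import format_row

header = ["id", "MTTR_hours", "owner"]
row = ["42", "3.5", ""]  # empty cells are dropped by _row_to_pairs

# Underscored headers get a space-substituted alias in parentheses.
print(format_columns_header(header))
# Columns: id, MTTR_hours (MTTR hours), owner

# Only non-empty cells are emitted as field=value pairs.
print(format_row(header, row))
# id=42, MTTR_hours=3.5

A chunk is then the optional sheet header (Section.heading), the Columns: line, and as many formatted rows as fit within content_token_limit, joined by newlines.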

@@ -10,6 +10,7 @@ from onyx.indexing.chunking.section_chunker import SectionChunker
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import split_text_by_tokens
from onyx.utils.text_processing import clean_text
from onyx.utils.text_processing import shared_precompare_cleanup
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
@@ -90,8 +91,8 @@ class TextChunker(SectionChunker):
STRICT_CHUNK_TOKEN_LIMIT
and count_tokens(split_text, self.tokenizer) > content_token_limit
):
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
smaller_chunks = split_text_by_tokens(
split_text, self.tokenizer, content_token_limit
)
for j, small_chunk in enumerate(smaller_chunks):
payloads.append(
@@ -114,16 +115,3 @@ class TextChunker(SectionChunker):
payloads=payloads,
accumulator=AccumulatorState(),
)
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
tokens = self.tokenizer.tokenize(text)
chunks: list[str] = []
start = 0
total_tokens = len(tokens)
while start < total_tokens:
end = min(start + content_token_limit, total_tokens)
token_chunk = tokens[start:end]
chunk_text = " ".join(token_chunk)
chunks.append(chunk_text)
start = end
return chunks

View File

@@ -3,6 +3,8 @@ from abc import ABC
from abc import abstractmethod
from collections import defaultdict
import sentry_sdk
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import DocumentFailure
@@ -291,6 +293,13 @@ def embed_chunks_with_failure_handling(
)
embedded_chunks.extend(doc_embedded_chunks)
except Exception as e:
with sentry_sdk.new_scope() as scope:
scope.set_tag("stage", "embedding")
scope.set_tag("doc_id", doc_id)
if tenant_id:
scope.set_tag("tenant_id", tenant_id)
scope.fingerprint = ["embedding-failure", type(e).__name__]
sentry_sdk.capture_exception(e)
logger.exception(f"Failed to embed chunks for document '{doc_id}'")
failures.append(
ConnectorFailure(

View File

@@ -5,6 +5,7 @@ from collections.abc import Iterator
from contextlib import contextmanager
from typing import Protocol
import sentry_sdk
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session
@@ -332,6 +333,13 @@ def index_doc_batch_with_handler(
except Exception as e:
# don't log the batch directly, it's too much text
document_ids = [doc.id for doc in document_batch]
with sentry_sdk.new_scope() as scope:
scope.set_tag("stage", "indexing_pipeline")
scope.set_tag("tenant_id", tenant_id)
scope.set_tag("batch_size", str(len(document_batch)))
scope.set_extra("document_ids", document_ids)
scope.fingerprint = ["indexing-pipeline-failure", type(e).__name__]
sentry_sdk.capture_exception(e)
logger.exception(f"Failed to index document batch: {document_ids}")
index_pipeline_result = IndexingPipelineResult(
@@ -543,7 +551,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
processed_sections=[
Section(
type=section.type,
text=section.text if isinstance(section, TextSection) else "",
text="" if isinstance(section, ImageSection) else section.text,
link=section.link,
image_file_id=(
section.image_file_id
@@ -609,7 +617,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
processed_sections.append(processed_section)
# For TextSection, create a base Section with text and link
elif isinstance(section, TextSection):
else:
processed_section = Section(
type=section.type,
text=section.text or "", # Ensure text is always a string, not None

View File

@@ -6,6 +6,7 @@ from itertools import chain
from itertools import groupby
import httpx
import sentry_sdk
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import DocumentFailure
@@ -88,6 +89,12 @@ def write_chunks_to_vector_db_with_backoff(
)
)
except Exception as e:
with sentry_sdk.new_scope() as scope:
scope.set_tag("stage", "vector_db_write")
scope.set_tag("doc_id", doc_id)
scope.set_tag("tenant_id", index_batch_params.tenant_id)
scope.fingerprint = ["vector-db-write-failure", type(e).__name__]
sentry_sdk.capture_exception(e)
logger.exception(
f"Failed to write document chunks for '{doc_id}' to vector db"
)

View File

@@ -434,11 +434,14 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
lifespan=lifespan_override or lifespan,
)
if SENTRY_DSN:
from onyx.configs.sentry import _add_instance_tags
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
before_send=_add_instance_tags,
)
logger.info("Sentry initialized")
else:

View File
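
The imported _add_instance_tags hook is not shown in this diff. As a general illustration only, a Sentry before_send hook of this kind typically stamps each event with deployment metadata before it is sent; the function name and environment variables below are hypothetical.

import os
from typing import Any


def add_instance_tags(event: dict[str, Any], hint: dict[str, Any]) -> dict[str, Any]:
    # Illustrative sketch; the real hook lives in onyx/configs/sentry.py and
    # its exact tags are not visible in this excerpt.
    tags = event.setdefault("tags", {})
    tags.setdefault("instance", os.environ.get("HOSTNAME", "unknown"))
    tags.setdefault("deployment", os.environ.get("ONYX_DEPLOYMENT", "unknown"))
    return event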

@@ -201,6 +201,33 @@ def count_tokens(
return total
def split_text_by_tokens(
text: str,
tokenizer: BaseTokenizer,
max_tokens: int,
) -> list[str]:
"""Split ``text`` into pieces of ≤ ``max_tokens`` tokens each, via
encode/decode at token-id boundaries.
Note: the returned pieces are not strictly guaranteed to re-tokenize to
≤ max_tokens. BPE merges at window boundaries may drift by a few tokens,
and cuts landing mid-multi-byte-UTF-8-character produce replacement
characters on decode. Good enough for "best-effort" splitting of
oversized content, not for hard limit enforcement.
"""
if not text:
return []
token_ids: list[int] = []
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
token_ids.extend(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
return [
tokenizer.decode(token_ids[start : start + max_tokens])
for start in range(0, len(token_ids), max_tokens)
]
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:

View File
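
A sketch of the splitting behaviour described in the docstring above, using a deliberately toy word-level tokenizer so the example stays self-contained; real BaseTokenizer implementations are BPE-based and can drift by a few tokens at the cut points, as the docstring notes.

# Toy stand-in for BaseTokenizer: one "token id" per whitespace-separated word.
class WordTokenizer:
    def __init__(self) -> None:
        self._vocab: dict[int, str] = {}

    def encode(self, text: str) -> list[int]:
        ids = []
        for word in text.split():
            token_id = len(self._vocab)
            self._vocab[token_id] = word
            ids.append(token_id)
        return ids

    def decode(self, token_ids: list[int]) -> str:
        return " ".join(self._vocab[i] for i in token_ids)


tokenizer = WordTokenizer()
token_ids = tokenizer.encode("alpha bravo charlie delta echo foxtrot golf")

# Same slicing pattern as split_text_by_tokens: fixed windows of max_tokens ids.
max_tokens = 3
pieces = [
    tokenizer.decode(token_ids[start : start + max_tokens])
    for start in range(0, len(token_ids), max_tokens)
]
print(pieces)  # ['alpha bravo charlie', 'delta echo foxtrot', 'golf']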

@@ -125,6 +125,11 @@ class TenantRedis(redis.Redis):
"sadd",
"srem",
"scard",
"zadd",
"zrangebyscore",
"zremrangebyscore",
"zscore",
"zcard",
"hexists",
"hset",
"hdel",

View File

@@ -0,0 +1,104 @@
"""Redis helpers for the tenant work-gating feature.
One sorted set `active_tenants` under the cloud Redis tenant tracks the last
time each tenant was observed doing work. The fanout generator reads the set
(filtered to entries within a TTL window) and skips tenants that haven't been
active recently.
All public functions no-op in single-tenant mode (`MULTI_TENANT=False`).
"""
import time
from typing import cast
from redis.client import Redis
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
# Unprefixed key. `TenantRedis._prefixed` prepends `cloud:` at call time so
# the full rendered key is `cloud:active_tenants`.
_SET_KEY = "active_tenants"
def _now_ms() -> int:
return int(time.time() * 1000)
def _client() -> Redis:
return get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
def mark_tenant_active(tenant_id: str) -> None:
"""Record that `tenant_id` was just observed doing work (ZADD with the
current timestamp as the score). Best-effort — a Redis failure is logged
and swallowed so it never breaks a writer path.
Call sites:
- Top of each gated beat-task consumer when its "is there work?" query
returns a non-empty result.
- cc_pair create lifecycle hook.
"""
if not MULTI_TENANT:
return
try:
# `mapping={member: score}` syntax; ZADD overwrites the score on
# existing members, which is exactly the refresh semantics we want.
_client().zadd(_SET_KEY, mapping={tenant_id: _now_ms()})
except Exception:
logger.exception(f"mark_tenant_active failed: tenant_id={tenant_id}")
def get_active_tenants(ttl_seconds: int) -> set[str] | None:
"""Return tenants whose last-seen timestamp is within `ttl_seconds` of
now.
Return values:
- `set[str]` (possibly empty) — Redis read succeeded. Empty set means
no tenants are currently marked active; callers should *skip* all
tenants if the gate is enforcing.
- `None` — Redis read failed *or* we are in single-tenant mode. Callers
should fail open (dispatch to every tenant this cycle). Distinguishing
failure from "genuinely empty" prevents a Redis outage from silently
starving every tenant on every enforced cycle.
"""
if not MULTI_TENANT:
return None
cutoff_ms = _now_ms() - (ttl_seconds * 1000)
try:
raw = cast(
list[bytes],
_client().zrangebyscore(_SET_KEY, min=cutoff_ms, max="+inf"),
)
except Exception:
logger.exception("get_active_tenants failed")
return None
return {m.decode() if isinstance(m, bytes) else m for m in raw}
def cleanup_expired(ttl_seconds: int) -> int:
"""Remove members older than `ttl_seconds` from the set. Optional
memory-hygiene helper — correctness does not depend on calling this, but
without it the set grows unboundedly as old tenants accumulate. Returns
the number of members removed."""
if not MULTI_TENANT:
return 0
cutoff_ms = _now_ms() - (ttl_seconds * 1000)
try:
removed = cast(
int,
_client().zremrangebyscore(_SET_KEY, min="-inf", max=f"({cutoff_ms}"),
)
return removed
except Exception:
logger.exception("cleanup_expired failed")
return 0

View File
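
A minimal sketch of how the gate's two sides could fit together in a fanout path. The TTL value and tenant ids are placeholders, and the import of the helpers is omitted because the new module's path is not visible in this excerpt.

TTL_SECONDS = 600  # placeholder; the real value comes from OnyxRuntime


def tenants_to_dispatch(all_tenant_ids: list[str], enforce: bool) -> list[str]:
    """Illustrative generator-side filter that fails open on Redis errors."""
    active = get_active_tenants(ttl_seconds=TTL_SECONDS)
    if active is None or not enforce:
        # Redis failure, single-tenant mode, or shadow mode: dispatch to all.
        return all_tenant_ids
    return [tid for tid in all_tenant_ids if tid in active]


def on_beat_task_found_work(tenant_id: str) -> None:
    """Consumer-side refresh: called when a gated beat task finds real work."""
    mark_tenant_active(tenant_id)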

@@ -63,6 +63,7 @@ class DocumentSetCreationRequest(BaseModel):
class DocumentSetUpdateRequest(BaseModel):
id: int
name: str
description: str
cc_pair_ids: list[int]
is_public: bool

View File

@@ -11,6 +11,9 @@ from onyx.db.notification import dismiss_notification
from onyx.db.notification import get_notification_by_id
from onyx.db.notification import get_notifications
from onyx.server.features.build.utils import ensure_build_mode_intro_notification
from onyx.server.features.notifications.utils import (
ensure_permissions_migration_notification,
)
from onyx.server.features.release_notes.utils import (
ensure_release_notes_fresh_and_notify,
)
@@ -49,6 +52,13 @@ def get_notifications_api(
except Exception:
logger.exception("Failed to check for release notes in notifications endpoint")
try:
ensure_permissions_migration_notification(user, db_session)
except Exception:
logger.exception(
"Failed to create permissions_migration_v1 announcement in notifications endpoint"
)
notifications = [
NotificationModel.from_model(notif)
for notif in get_notifications(user, db_session, include_dismissed=True)

View File

@@ -0,0 +1,21 @@
from sqlalchemy.orm import Session
from onyx.configs.constants import NotificationType
from onyx.db.models import User
from onyx.db.notification import create_notification
def ensure_permissions_migration_notification(user: User, db_session: Session) -> None:
# Feature id "permissions_migration_v1" must not change after shipping —
# it is the dedup key on (user_id, notif_type, additional_data).
create_notification(
user_id=user.id,
notif_type=NotificationType.FEATURE_ANNOUNCEMENT,
db_session=db_session,
title="Permissions are changing in Onyx",
description="Roles are moving to group-based permissions. Click for details.",
additional_data={
"feature": "permissions_migration_v1",
"link": "https://docs.onyx.app/admins/permissions/whats_changing",
},
)

View File

@@ -185,6 +185,10 @@ class MinimalPersonaSnapshot(BaseModel):
for doc_set in persona.document_sets:
for cc_pair in doc_set.connector_credential_pairs:
sources.add(cc_pair.connector.source)
for fed_ds in doc_set.federated_connectors:
non_fed = fed_ds.federated_connector.source.to_non_federated_source()
if non_fed is not None:
sources.add(non_fed)
# Sources from hierarchy nodes
for node in persona.hierarchy_nodes:
@@ -195,6 +199,9 @@ class MinimalPersonaSnapshot(BaseModel):
if doc.parent_hierarchy_node:
sources.add(doc.parent_hierarchy_node.source)
if persona.user_files:
sources.add(DocumentSource.USER_FILE)
return MinimalPersonaSnapshot(
# Core fields actually used by ChatPage
id=persona.id,

View File

@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from onyx.auth.permissions import require_permission
from onyx.auth.users import current_curator_or_admin_user
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
@@ -28,6 +29,7 @@ from onyx.db.feedback import fetch_docs_ranked_by_boost_for_user
from onyx.db.feedback import update_document_boost_for_user
from onyx.db.feedback import update_document_hidden_for_user
from onyx.db.index_attempt import cancel_indexing_attempts_for_ccpair
from onyx.db.index_attempt import get_index_attempt_errors_across_connectors
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.key_value_store.factory import get_kv_store
@@ -35,6 +37,7 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.llm.factory import get_default_llm
from onyx.llm.utils import test_llm
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import BoostDoc
from onyx.server.manage.models import BoostUpdateRequest
from onyx.server.manage.models import HiddenUpdateRequest
@@ -206,3 +209,40 @@ def create_deletion_attempt_for_connector_id(
file_store = get_default_file_store()
for file_id in connector.connector_specific_config.get("file_locations", []):
file_store.delete_file(file_id)
@router.get("/admin/indexing/failed-documents")
def get_failed_documents(
cc_pair_id: int | None = None,
error_type: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
include_resolved: bool = False,
page_num: int = 0,
page_size: int = 25,
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[IndexAttemptErrorPydantic]:
"""Get indexing errors across all connectors with optional filters.
Provides a cross-connector view of document indexing failures.
Defaults to last 30 days if no start_time is provided to avoid
unbounded count queries.
"""
if start_time is None:
start_time = datetime.now(tz=timezone.utc) - timedelta(days=30)
errors, total = get_index_attempt_errors_across_connectors(
db_session=db_session,
cc_pair_id=cc_pair_id,
error_type=error_type,
start_time=start_time,
end_time=end_time,
unresolved_only=not include_resolved,
page=page_num,
page_size=page_size,
)
return PaginatedReturn(
items=[IndexAttemptErrorPydantic.from_model(e) for e in errors],
total_items=total,
)

View File
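
A hedged example of calling the new endpoint. The base URL, auth header, and HTTP client are placeholders, and the router's path prefix is not shown in this excerpt, so the full URL may differ in a real deployment.

import requests  # illustrative; any HTTP client works

BASE_URL = "http://localhost:8080"                    # placeholder
HEADERS = {"Authorization": "Bearer <admin-token>"}   # placeholder

resp = requests.get(
    f"{BASE_URL}/admin/indexing/failed-documents",
    headers=HEADERS,
    params={
        "cc_pair_id": 7,            # optional: restrict to one connector pair
        "include_resolved": False,  # default: unresolved errors only
        "page_num": 0,
        "page_size": 25,
        # start_time omitted: the server defaults to the last 30 days
    },
    timeout=30,
)
resp.raise_for_status()
payload = resp.json()
print(payload["total_items"], "failed documents")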

@@ -183,6 +183,9 @@ def generate_ollama_display_name(model_name: str) -> str:
"qwen2.5:7b""Qwen 2.5 7B"
"mistral:latest""Mistral"
"deepseek-r1:14b""DeepSeek R1 14B"
"gemma4:e4b""Gemma 4 E4B"
"deepseek-v3.1:671b-cloud""DeepSeek V3.1 671B Cloud"
"qwen3-vl:235b-instruct-cloud""Qwen 3-vl 235B Instruct Cloud"
"""
# Split into base name and tag
if ":" in model_name:
@@ -209,13 +212,24 @@ def generate_ollama_display_name(model_name: str) -> str:
# Default: Title case with dashes converted to spaces
display_name = base.replace("-", " ").title()
# Process tag to extract size info (skip "latest")
# Process tag (skip "latest")
if tag and tag.lower() != "latest":
# Extract size like "7b", "70b", "14b"
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])", tag)
# Check for size prefix like "7b", "70b", optionally followed by modifiers
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])(-.+)?$", tag)
if size_match:
size = size_match.group(1).upper()
display_name = f"{display_name} {size}"
remainder = size_match.group(2)
if remainder:
# Format modifiers like "-cloud", "-instruct-cloud"
modifiers = " ".join(
p.title() for p in remainder.strip("-").split("-") if p
)
display_name = f"{display_name} {size} {modifiers}"
else:
display_name = f"{display_name} {size}"
else:
# Non-size tags like "e4b", "q4_0", "fp16", "cloud"
display_name = f"{display_name} {tag.upper()}"
return display_name

View File

@@ -1,13 +1,14 @@
import json
import secrets
from collections.abc import AsyncIterator
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import Query
from fastapi import UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.auth.permissions import require_permission
@@ -113,28 +114,47 @@ async def transcribe_audio(
) from exc
def _extract_provider_error(exc: Exception) -> str:
"""Extract a human-readable message from a provider exception.
Provider errors often embed JSON from upstream APIs (e.g. ElevenLabs).
This tries to parse a readable ``message`` field out of common JSON
error shapes; falls back to ``str(exc)`` if nothing better is found.
"""
raw = str(exc)
try:
# Many providers embed JSON after a prefix like "ElevenLabs TTS failed: {...}"
json_start = raw.find("{")
if json_start == -1:
return raw
parsed = json.loads(raw[json_start:])
# Shape: {"detail": {"message": "..."}} (ElevenLabs)
detail = parsed.get("detail", parsed)
if isinstance(detail, dict):
return detail.get("message") or detail.get("error") or raw
if isinstance(detail, str):
return detail
except (json.JSONDecodeError, AttributeError, TypeError):
pass
return raw
class SynthesizeRequest(BaseModel):
text: str = Field(..., min_length=1)
voice: str | None = None
speed: float | None = Field(default=None, ge=0.5, le=2.0)
@router.post("/synthesize")
async def synthesize_speech(
text: str | None = Query(
default=None, description="Text to synthesize", max_length=4096
),
voice: str | None = Query(default=None, description="Voice ID to use"),
speed: float | None = Query(
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
),
body: SynthesizeRequest,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
) -> StreamingResponse:
"""
Synthesize text to speech using the default TTS provider.
Accepts parameters via query string for streaming compatibility.
"""
logger.info(
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
)
if not text:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
"""Synthesize text to speech using the default TTS provider."""
text = body.text
voice = body.voice
speed = body.speed
logger.info(f"TTS request: text length={len(text)}, voice={voice}, speed={speed}")
# Use short-lived session to fetch provider config, then release connection
# before starting the long-running streaming response
@@ -177,31 +197,36 @@ async def synthesize_speech(
logger.error(f"Failed to get voice provider: {exc}")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
# Session is now closed - streaming response won't hold DB connection
# Pull the first chunk before returning the StreamingResponse. If the
# provider rejects the request (e.g. text too long), the error surfaces
# as a proper HTTP error instead of a broken audio stream.
stream_iter = provider.synthesize_stream(
text=text, voice=final_voice, speed=final_speed
)
try:
first_chunk = await stream_iter.__anext__()
except StopAsyncIteration:
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "TTS provider returned no audio")
except Exception as exc:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY, _extract_provider_error(exc)
) from exc
async def audio_stream() -> AsyncIterator[bytes]:
try:
chunk_count = 0
async for chunk in provider.synthesize_stream(
text=text, voice=final_voice, speed=final_speed
):
chunk_count += 1
yield chunk
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
except NotImplementedError as exc:
logger.error(f"TTS not implemented: {exc}")
raise
except Exception as exc:
logger.error(f"Synthesis failed: {exc}")
raise
yield first_chunk
chunk_count = 1
async for chunk in stream_iter:
chunk_count += 1
yield chunk
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
return StreamingResponse(
audio_stream(),
media_type="audio/mpeg",
headers={
"Content-Disposition": "inline; filename=speech.mp3",
# Allow streaming by not setting content-length
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no", # Disable nginx buffering
"X-Accel-Buffering": "no",
},
)

View File
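
With the query-string parameters removed, clients now send the text in a JSON body matching SynthesizeRequest. A minimal client sketch; the host, auth, and router prefix are placeholders.

import requests  # illustrative; any streaming-capable HTTP client works

BASE_URL = "http://localhost:8080"             # placeholder
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder

# text is required and non-empty; speed, when given, must be in [0.5, 2.0].
body = {"text": "Hello from Onyx", "speed": 1.25}

with requests.post(
    f"{BASE_URL}/synthesize",  # actual path depends on the router prefix
    json=body,
    headers=HEADERS,
    stream=True,
    timeout=60,
) as resp:
    resp.raise_for_status()
    with open("speech.mp3", "wb") as f:
        for chunk in resp.iter_content(chunk_size=8192):
            f.write(chunk)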

@@ -1,7 +1,8 @@
"""Generic Celery task lifecycle Prometheus metrics.
Provides signal handlers that track task started/completed/failed counts,
active task gauge, task duration histograms, and retry/reject/revoke counts.
active task gauge, task duration histograms, queue wait time histograms,
and retry/reject/revoke counts.
These fire for ALL tasks on the worker — no per-connector enrichment
(see indexing_task_metrics.py for that).
@@ -71,6 +72,32 @@ TASK_REJECTED = Counter(
["task_name"],
)
TASK_QUEUE_WAIT = Histogram(
"onyx_celery_task_queue_wait_seconds",
"Time a Celery task spent waiting in the queue before execution started",
["task_name", "queue"],
buckets=[
0.1,
0.5,
1,
5,
30,
60,
300,
600,
1800,
3600,
7200,
14400,
28800,
43200,
86400,
172800,
432000,
864000,
],
)
# task_id → (monotonic start time, metric labels)
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
@@ -133,6 +160,13 @@ def on_celery_task_prerun(
with _task_start_times_lock:
_evict_stale_start_times()
_task_start_times[task_id] = (time.monotonic(), labels)
headers = getattr(task.request, "headers", None) or {}
enqueued_at = headers.get("enqueued_at")
if isinstance(enqueued_at, (int, float)):
TASK_QUEUE_WAIT.labels(**labels).observe(
max(0.0, time.time() - enqueued_at)
)
except Exception:
logger.debug("Failed to record celery task prerun metrics", exc_info=True)

View File
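
The queue-wait histogram relies on an enqueued_at timestamp in the task headers. The producer-side stamping is not part of this excerpt; one common way to add it is Celery's before_task_publish signal, sketched below as an assumption rather than as what this PR necessarily does.

import time

from celery.signals import before_task_publish


@before_task_publish.connect
def stamp_enqueued_at(headers: dict | None = None, **kwargs: object) -> None:
    # Record wall-clock publish time so the prerun handler above can compute
    # queue wait as time.time() - enqueued_at.
    if headers is not None:
        headers["enqueued_at"] = time.time()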

@@ -0,0 +1,123 @@
"""Prometheus metrics for connector health and index attempts.
Emitted by docfetching and docprocessing workers when connector or
index attempt state changes. All functions silently catch exceptions
to avoid disrupting the caller's business logic.
Gauge metrics (error state, last success timestamp) are per-process.
With multiple worker pods, use max() aggregation in PromQL to get the
correct value across instances, e.g.:
max by (cc_pair_id, connector_name) (onyx_connector_in_error_state)
Unlike the per-task counters in indexing_task_metrics.py, these metrics
include connector_name because their cardinality is bounded by the number
of connectors (one series per connector), not by the number of task
executions.
"""
from prometheus_client import Counter
from prometheus_client import Gauge
from onyx.utils.logger import setup_logger
logger = setup_logger()
_CONNECTOR_LABELS = ["tenant_id", "source", "cc_pair_id", "connector_name"]
# --- Index attempt lifecycle ---
INDEX_ATTEMPT_STATUS = Counter(
"onyx_index_attempt_transitions_total",
"Index attempt status transitions",
[*_CONNECTOR_LABELS, "status"],
)
# --- Connector health ---
CONNECTOR_IN_ERROR_STATE = Gauge(
"onyx_connector_in_error_state",
"Whether the connector is in a repeated error state (1=yes, 0=no)",
_CONNECTOR_LABELS,
)
CONNECTOR_LAST_SUCCESS_TIMESTAMP = Gauge(
"onyx_connector_last_success_timestamp_seconds",
"Unix timestamp of last successful indexing for this connector",
_CONNECTOR_LABELS,
)
CONNECTOR_DOCS_INDEXED = Counter(
"onyx_connector_docs_indexed_total",
"Total documents indexed per connector (monotonic)",
_CONNECTOR_LABELS,
)
CONNECTOR_INDEXING_ERRORS = Counter(
"onyx_connector_indexing_errors_total",
"Total failed index attempts per connector (monotonic)",
_CONNECTOR_LABELS,
)
def on_index_attempt_status_change(
tenant_id: str,
source: str,
cc_pair_id: int,
connector_name: str,
status: str,
) -> None:
"""Called on any index attempt status transition."""
try:
labels = {
"tenant_id": tenant_id,
"source": source,
"cc_pair_id": str(cc_pair_id),
"connector_name": connector_name,
}
INDEX_ATTEMPT_STATUS.labels(**labels, status=status).inc()
if status == "failed":
CONNECTOR_INDEXING_ERRORS.labels(**labels).inc()
except Exception:
logger.debug("Failed to record index attempt status metric", exc_info=True)
def on_connector_error_state_change(
tenant_id: str,
source: str,
cc_pair_id: int,
connector_name: str,
in_error: bool,
) -> None:
"""Called when a connector's in_repeated_error_state changes."""
try:
CONNECTOR_IN_ERROR_STATE.labels(
tenant_id=tenant_id,
source=source,
cc_pair_id=str(cc_pair_id),
connector_name=connector_name,
).set(1.0 if in_error else 0.0)
except Exception:
logger.debug("Failed to record connector error state metric", exc_info=True)
def on_connector_indexing_success(
tenant_id: str,
source: str,
cc_pair_id: int,
connector_name: str,
docs_indexed: int,
success_timestamp: float,
) -> None:
"""Called when an indexing run completes successfully."""
try:
labels = {
"tenant_id": tenant_id,
"source": source,
"cc_pair_id": str(cc_pair_id),
"connector_name": connector_name,
}
CONNECTOR_LAST_SUCCESS_TIMESTAMP.labels(**labels).set(success_timestamp)
if docs_indexed > 0:
CONNECTOR_DOCS_INDEXED.labels(**labels).inc(docs_indexed)
except Exception:
logger.debug("Failed to record connector success metric", exc_info=True)

View File
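
For illustration, how a worker might report a successful run through the push-based helpers above. The values are placeholders, the real call sites live in the docfetching/docprocessing workers (not in this excerpt), and the module path in the import is assumed from the neighbouring metrics modules.

import time

from onyx.server.metrics.connector_health_metrics import (  # path assumed
    on_connector_indexing_success,
)

on_connector_indexing_success(
    tenant_id="tenant_abc",           # placeholder
    source="confluence",              # DocumentSource value as a string
    cc_pair_id=42,
    connector_name="Engineering Wiki",
    docs_indexed=137,
    success_timestamp=time.time(),
)

On the Prometheus side, the per-process gauges are then aggregated with max by (cc_pair_id, connector_name), as the module docstring recommends.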

@@ -0,0 +1,104 @@
"""Connector-deletion-specific Prometheus metrics.
Tracks the deletion lifecycle:
1. Deletions started (taskset generated)
2. Deletions completed (success or failure)
3. Taskset duration (from taskset generation to completion or failure).
Note: this measures the most recent taskset execution, NOT wall-clock
time since the user triggered the deletion. When deletion is blocked by
indexing/pruning/permissions, the fence is cleared and a fresh taskset
is generated on each retry, resetting this timer.
4. Deletion blocked by dependencies (indexing, pruning, permissions, etc.)
5. Fence resets (stuck deletion recovery)
All metrics are labeled by tenant_id. cc_pair_id is intentionally excluded
to avoid unbounded cardinality.
Usage:
from onyx.server.metrics.deletion_metrics import (
inc_deletion_started,
inc_deletion_completed,
observe_deletion_taskset_duration,
inc_deletion_blocked,
inc_deletion_fence_reset,
)
"""
from prometheus_client import Counter
from prometheus_client import Histogram
from onyx.utils.logger import setup_logger
logger = setup_logger()
DELETION_STARTED = Counter(
"onyx_deletion_started_total",
"Connector deletions initiated (taskset generated)",
["tenant_id"],
)
DELETION_COMPLETED = Counter(
"onyx_deletion_completed_total",
"Connector deletions completed",
["tenant_id", "outcome"],
)
DELETION_TASKSET_DURATION = Histogram(
"onyx_deletion_taskset_duration_seconds",
"Duration of a connector deletion taskset, from taskset generation "
"to completion or failure. Does not include time spent blocked on "
"indexing/pruning/permissions before the taskset was generated.",
["tenant_id", "outcome"],
buckets=[10, 30, 60, 120, 300, 600, 1800, 3600, 7200, 21600],
)
DELETION_BLOCKED = Counter(
"onyx_deletion_blocked_total",
"Times deletion was blocked by a dependency",
["tenant_id", "blocker"],
)
DELETION_FENCE_RESET = Counter(
"onyx_deletion_fence_reset_total",
"Deletion fences reset due to missing celery tasks",
["tenant_id"],
)
def inc_deletion_started(tenant_id: str) -> None:
try:
DELETION_STARTED.labels(tenant_id=tenant_id).inc()
except Exception:
logger.debug("Failed to record deletion started", exc_info=True)
def inc_deletion_completed(tenant_id: str, outcome: str) -> None:
try:
DELETION_COMPLETED.labels(tenant_id=tenant_id, outcome=outcome).inc()
except Exception:
logger.debug("Failed to record deletion completed", exc_info=True)
def observe_deletion_taskset_duration(
tenant_id: str, outcome: str, duration_seconds: float
) -> None:
try:
DELETION_TASKSET_DURATION.labels(tenant_id=tenant_id, outcome=outcome).observe(
duration_seconds
)
except Exception:
logger.debug("Failed to record deletion taskset duration", exc_info=True)
def inc_deletion_blocked(tenant_id: str, blocker: str) -> None:
try:
DELETION_BLOCKED.labels(tenant_id=tenant_id, blocker=blocker).inc()
except Exception:
logger.debug("Failed to record deletion blocked", exc_info=True)
def inc_deletion_fence_reset(tenant_id: str) -> None:
try:
DELETION_FENCE_RESET.labels(tenant_id=tenant_id).inc()
except Exception:
logger.debug("Failed to record deletion fence reset", exc_info=True)

View File

@@ -1,25 +1,30 @@
"""Prometheus collectors for Celery queue depths and indexing pipeline state.
"""Prometheus collectors for Celery queue depths and infrastructure health.
These collectors query Redis and Postgres at scrape time (the Collector pattern),
These collectors query Redis at scrape time (the Collector pattern),
so metrics are always fresh when Prometheus scrapes /metrics. They run inside the
monitoring celery worker which already has Redis and DB access.
monitoring celery worker which already has Redis access.
To avoid hammering Redis/Postgres on every 15s scrape, results are cached with
To avoid hammering Redis on every 15s scrape, results are cached with
a configurable TTL (default 30s). This means metrics may be up to TTL seconds
stale, which is fine for monitoring dashboards.
Note: connector health and index attempt metrics are push-based (emitted by
workers at state-change time) and live in connector_health_metrics.py.
"""
from __future__ import annotations
import concurrent.futures
import json
import threading
import time
from datetime import datetime
from datetime import timezone
from typing import Any
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.registry import Collector
from redis import Redis
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.constants import OnyxCeleryQueues
@@ -31,6 +36,11 @@ logger = setup_logger()
# the previous result without re-querying Redis/Postgres.
_DEFAULT_CACHE_TTL = 30.0
# Maximum time (seconds) a single _collect_fresh() call may take before
# the collector gives up and returns stale/empty results. Prevents the
# /metrics endpoint from hanging indefinitely when a DB or Redis query stalls.
_DEFAULT_COLLECT_TIMEOUT = 120.0
_QUEUE_LABEL_MAP: dict[str, str] = {
OnyxCeleryQueues.PRIMARY: "primary",
OnyxCeleryQueues.DOCPROCESSING: "docprocessing",
@@ -62,18 +72,32 @@ _UNACKED_QUEUES: list[str] = [
class _CachedCollector(Collector):
"""Base collector with TTL-based caching.
"""Base collector with TTL-based caching and timeout protection.
Subclasses implement ``_collect_fresh()`` to query the actual data source.
The base ``collect()`` returns cached results if the TTL hasn't expired,
avoiding repeated queries when Prometheus scrapes frequently.
A per-collection timeout prevents a slow DB or Redis query from blocking
the /metrics endpoint indefinitely. If _collect_fresh() exceeds the
timeout, stale cached results are returned instead.
"""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
def __init__(
self,
cache_ttl: float = _DEFAULT_CACHE_TTL,
collect_timeout: float = _DEFAULT_COLLECT_TIMEOUT,
) -> None:
self._cache_ttl = cache_ttl
self._collect_timeout = collect_timeout
self._cached_result: list[GaugeMetricFamily] | None = None
self._last_collect_time: float = 0.0
self._lock = threading.Lock()
self._executor = concurrent.futures.ThreadPoolExecutor(
max_workers=1,
thread_name_prefix=type(self).__name__,
)
self._inflight: concurrent.futures.Future | None = None
def collect(self) -> list[GaugeMetricFamily]:
with self._lock:
@@ -84,12 +108,28 @@ class _CachedCollector(Collector):
):
return self._cached_result
# If a previous _collect_fresh() is still running, wait on it
# rather than queuing another. This prevents unbounded task
# accumulation in the executor during extended DB outages.
if self._inflight is not None and not self._inflight.done():
future = self._inflight
else:
future = self._executor.submit(self._collect_fresh)
self._inflight = future
try:
result = self._collect_fresh()
result = future.result(timeout=self._collect_timeout)
self._inflight = None
self._cached_result = result
self._last_collect_time = now
return result
except concurrent.futures.TimeoutError:
logger.warning(
f"{type(self).__name__}._collect_fresh() timed out after {self._collect_timeout}s, returning stale cache"
)
return self._cached_result if self._cached_result is not None else []
except Exception:
self._inflight = None
logger.exception(f"Error in {type(self).__name__}.collect()")
# Return stale cache on error rather than nothing — avoids
# metrics disappearing during transient failures.
@@ -117,8 +157,6 @@ class QueueDepthCollector(_CachedCollector):
if self._celery_app is None:
return []
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
depth = GaugeMetricFamily(
@@ -194,208 +232,6 @@ class QueueDepthCollector(_CachedCollector):
return None
class IndexAttemptCollector(_CachedCollector):
"""Queries Postgres for index attempt state on each scrape."""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._configured: bool = False
self._terminal_statuses: list = []
def configure(self) -> None:
"""Call once DB engine is initialized."""
from onyx.db.enums import IndexingStatus
self._terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
self._configured = True
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if not self._configured:
return []
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.index_attempt import get_active_index_attempts_for_metrics
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
attempts_gauge = GaugeMetricFamily(
"onyx_index_attempts_active",
"Number of non-terminal index attempts",
labels=[
"status",
"source",
"tenant_id",
"connector_name",
"cc_pair_id",
],
)
tenant_ids = get_all_tenant_ids()
for tid in tenant_ids:
# Defensive guard — get_all_tenant_ids() should never yield None,
# but we guard here for API stability in case the contract changes.
if tid is None:
continue
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
try:
with get_session_with_current_tenant() as session:
rows = get_active_index_attempts_for_metrics(session)
for status, source, cc_id, cc_name, count in rows:
name_val = cc_name or f"cc_pair_{cc_id}"
attempts_gauge.add_metric(
[
status.value,
source.value,
tid,
name_val,
str(cc_id),
],
count,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return [attempts_gauge]
class ConnectorHealthCollector(_CachedCollector):
"""Queries Postgres for connector health state on each scrape."""
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
super().__init__(cache_ttl)
self._configured: bool = False
def configure(self) -> None:
"""Call once DB engine is initialized."""
self._configured = True
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if not self._configured:
return []
from onyx.db.connector_credential_pair import (
get_connector_health_for_metrics,
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.index_attempt import get_docs_indexed_by_cc_pair
from onyx.db.index_attempt import get_failed_attempt_counts_by_cc_pair
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
staleness_gauge = GaugeMetricFamily(
"onyx_connector_last_success_age_seconds",
"Seconds since last successful index for this connector",
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
)
error_state_gauge = GaugeMetricFamily(
"onyx_connector_in_error_state",
"Whether the connector is in a repeated error state (1=yes, 0=no)",
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
)
by_status_gauge = GaugeMetricFamily(
"onyx_connectors_by_status",
"Number of connectors grouped by status",
labels=["tenant_id", "status"],
)
error_total_gauge = GaugeMetricFamily(
"onyx_connectors_in_error_total",
"Total number of connectors in repeated error state",
labels=["tenant_id"],
)
per_connector_labels = [
"tenant_id",
"source",
"cc_pair_id",
"connector_name",
]
docs_success_gauge = GaugeMetricFamily(
"onyx_connector_docs_indexed",
"Total new documents indexed (90-day rolling sum) per connector",
labels=per_connector_labels,
)
docs_error_gauge = GaugeMetricFamily(
"onyx_connector_error_count",
"Total number of failed index attempts per connector",
labels=per_connector_labels,
)
now = datetime.now(tz=timezone.utc)
tenant_ids = get_all_tenant_ids()
for tid in tenant_ids:
# Defensive guard — get_all_tenant_ids() should never yield None,
# but we guard here for API stability in case the contract changes.
if tid is None:
continue
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
try:
with get_session_with_current_tenant() as session:
pairs = get_connector_health_for_metrics(session)
error_counts_by_cc = get_failed_attempt_counts_by_cc_pair(session)
docs_by_cc = get_docs_indexed_by_cc_pair(session)
status_counts: dict[str, int] = {}
error_count = 0
for (
cc_id,
status,
in_error,
last_success,
cc_name,
source,
) in pairs:
cc_id_str = str(cc_id)
source_val = source.value
name_val = cc_name or f"cc_pair_{cc_id}"
label_vals = [tid, source_val, cc_id_str, name_val]
if last_success is not None:
# Both `now` and `last_success` are timezone-aware
# (the DB column uses DateTime(timezone=True)),
# so subtraction is safe.
age = (now - last_success).total_seconds()
staleness_gauge.add_metric(label_vals, age)
error_state_gauge.add_metric(
label_vals,
1.0 if in_error else 0.0,
)
if in_error:
error_count += 1
docs_success_gauge.add_metric(
label_vals,
docs_by_cc.get(cc_id, 0),
)
docs_error_gauge.add_metric(
label_vals,
error_counts_by_cc.get(cc_id, 0),
)
status_val = status.value
status_counts[status_val] = status_counts.get(status_val, 0) + 1
for status_val, count in status_counts.items():
by_status_gauge.add_metric([tid, status_val], count)
error_total_gauge.add_metric([tid], error_count)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return [
staleness_gauge,
error_state_gauge,
by_status_gauge,
error_total_gauge,
docs_success_gauge,
docs_error_gauge,
]
class RedisHealthCollector(_CachedCollector):
"""Collects Redis server health metrics (memory, clients, etc.)."""
@@ -411,8 +247,6 @@ class RedisHealthCollector(_CachedCollector):
if self._celery_app is None:
return []
from onyx.background.celery.celery_redis import celery_get_broker_client
redis_client = celery_get_broker_client(self._celery_app)
memory_used = GaugeMetricFamily(
@@ -495,7 +329,9 @@ class WorkerHeartbeatMonitor:
},
)
recv.capture(
limit=None, timeout=self._HEARTBEAT_TIMEOUT_SECONDS, wakeup=True
limit=None,
timeout=self._HEARTBEAT_TIMEOUT_SECONDS,
wakeup=True,
)
except Exception:
if self._running:

View File
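
The timeout change follows a reusable pattern: run the refresh in a single-worker executor, reuse an in-flight future instead of queuing more work, and fall back to the stale cache on timeout or error. A stripped-down sketch of that pattern, independent of the Onyx collector classes:

from __future__ import annotations

import concurrent.futures
import threading
import time
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class TTLCachedCall(Generic[T]):
    """Sketch only: TTL cache + single in-flight refresh + stale-on-timeout."""

    def __init__(self, fn: Callable[[], T], ttl: float, timeout: float) -> None:
        self._fn = fn
        self._ttl = ttl
        self._timeout = timeout
        self._cached: T | None = None
        self._last = 0.0
        self._lock = threading.Lock()
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self._inflight: concurrent.futures.Future | None = None

    def get(self) -> T | None:
        with self._lock:
            now = time.monotonic()
            if self._cached is not None and now - self._last < self._ttl:
                return self._cached
            # Reuse an in-flight refresh rather than stacking up more of them.
            if self._inflight is None or self._inflight.done():
                self._inflight = self._executor.submit(self._fn)
            try:
                result = self._inflight.result(timeout=self._timeout)
            except concurrent.futures.TimeoutError:
                return self._cached  # stale, but better than nothing
            except Exception:
                self._inflight = None
                return self._cached
            self._inflight = None
            self._cached = result
            self._last = now
            return result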

@@ -6,8 +6,6 @@ Called once by the monitoring celery worker after Redis and DB are ready.
from celery import Celery
from prometheus_client.registry import REGISTRY
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
@@ -21,8 +19,6 @@ logger = setup_logger()
# module level ensures they survive the lifetime of the worker process and are
# only registered with the Prometheus registry once.
_queue_collector = QueueDepthCollector()
_attempt_collector = IndexAttemptCollector()
_connector_collector = ConnectorHealthCollector()
_redis_health_collector = RedisHealthCollector()
_worker_health_collector = WorkerHealthCollector()
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
@@ -34,6 +30,9 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
Args:
celery_app: The Celery application instance. Used to obtain a
broker Redis client on each scrape for queue depth metrics.
Note: connector health and index attempt metrics are push-based
(see connector_health_metrics.py) and do not use collectors.
"""
_queue_collector.set_celery_app(celery_app)
_redis_health_collector.set_celery_app(celery_app)
@@ -47,13 +46,8 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
_heartbeat_monitor.start()
_worker_health_collector.set_monitor(_heartbeat_monitor)
_attempt_collector.configure()
_connector_collector.configure()
for collector in (
_queue_collector,
_attempt_collector,
_connector_collector,
_redis_health_collector,
_worker_health_collector,
):

View File

@@ -7,6 +7,9 @@ from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEF
from onyx.background.celery.tasks.beat_schedule import (
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT,
)
from onyx.configs.app_configs import ENABLE_TENANT_WORK_GATING
from onyx.configs.app_configs import TENANT_WORK_GATING_FULL_FANOUT_INTERVAL_SECONDS
from onyx.configs.app_configs import TENANT_WORK_GATING_TTL_SECONDS
from onyx.configs.constants import CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
@@ -139,6 +142,87 @@ class OnyxRuntime:
return value
@staticmethod
def _read_tenant_work_gating_flag(axis: str, default: bool) -> bool:
"""Read `runtime:tenant_work_gating:{axis}` from Redis and interpret
it as a bool. Returns `default` if the key is absent or unparseable.
`axis` is either `enabled` (compute the gate) or `enforce` (actually
skip)."""
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:tenant_work_gating:{axis}")
if raw is None:
return default
try:
return cast(bytes, raw).decode().strip().lower() == "true"
except Exception:
return default
@staticmethod
def get_tenant_work_gating_enabled() -> bool:
"""Should we *compute* the work gate? (read the Redis set, log how
many tenants would be skipped). Env-var `ENABLE_TENANT_WORK_GATING`
is the fallback default when no Redis override is set — it acts as
the master switch that turns the feature on in shadow mode."""
return OnyxRuntime._read_tenant_work_gating_flag(
"enabled", default=ENABLE_TENANT_WORK_GATING
)
@staticmethod
def get_tenant_work_gating_enforce() -> bool:
"""Should we *actually skip* tenants not in the work set?
Deliberately Redis-only with a hard-coded default of False: the env
var `ENABLE_TENANT_WORK_GATING` only flips `enabled` (shadow mode),
never `enforce`. Enforcement has to be turned on by an explicit
`runtime:tenant_work_gating:enforce=true` write so ops can't
accidentally skip real tenant traffic by flipping an env flag. Only
meaningful when `get_tenant_work_gating_enabled()` is also True.
"""
return OnyxRuntime._read_tenant_work_gating_flag("enforce", default=False)
@staticmethod
def get_tenant_work_gating_ttl_seconds() -> int:
"""Membership TTL for the `active_tenants` sorted set. Members older
than this are treated as "no recent work" by the gate read path.
Must be > (full-fanout cadence × base task schedule) so self-healing
has time to refresh memberships before they expire."""
default = TENANT_WORK_GATING_TTL_SECONDS
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:tenant_work_gating:ttl_seconds")
if raw is None:
return default
try:
value = int(cast(bytes, raw).decode())
return value if value > 0 else default
except ValueError:
return default
@staticmethod
def get_tenant_work_gating_full_fanout_interval_seconds() -> int:
"""Minimum wall-clock interval between full-fanout cycles. When at
least this many seconds have elapsed since the last bypass, the
generator ignores the gate on its next invocation and dispatches to
every non-gated tenant, letting consumers re-populate the active
set. Schedule-independent so beat drift or backlog can't skew the
self-heal cadence."""
default = TENANT_WORK_GATING_FULL_FANOUT_INTERVAL_SECONDS
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
raw = r.get(
f"{ONYX_CLOUD_REDIS_RUNTIME}:tenant_work_gating:full_fanout_interval_seconds"
)
if raw is None:
return default
try:
value = int(cast(bytes, raw).decode())
return value if value > 0 else default
except ValueError:
return default
@staticmethod
def get_build_fence_lookup_table_interval() -> int:
"""We maintain an active fence table to make lookups of existing fences efficient.

View File
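
Operationally the gate is driven by plain Redis string keys under the cloud tenant, following the runtime:tenant_work_gating:{axis} pattern read above. A sketch of flipping them; the client setup is illustrative, and writes should go to the primary client rather than the replica client used on the read path.

from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.redis.redis_pool import get_redis_client

r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)

# Shadow mode: compute the gate and log what would be skipped.
r.set(f"{ONYX_CLOUD_REDIS_RUNTIME}:tenant_work_gating:enabled", "true")
# Enforcement requires its own explicit write; it is never implied by the env var.
r.set(f"{ONYX_CLOUD_REDIS_RUNTIME}:tenant_work_gating:enforce", "true")

# Optional tuning knobs; values must be positive integers (seconds).
r.set(f"{ONYX_CLOUD_REDIS_RUNTIME}:tenant_work_gating:ttl_seconds", "900")
r.set(
    f"{ONYX_CLOUD_REDIS_RUNTIME}:tenant_work_gating:full_fanout_interval_seconds",
    "3600",
)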

@@ -10,6 +10,7 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import PersonaSearchInfo
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.mcp import get_all_mcp_tools_for_server
@@ -113,10 +114,10 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
def construct_tools(
persona: Persona,
db_session: Session,
emitter: Emitter,
user: User,
llm: LLM,
db_session: Session | None = None,
search_tool_config: SearchToolConfig | None = None,
custom_tool_config: CustomToolConfig | None = None,
file_reader_tool_config: FileReaderToolConfig | None = None,
@@ -131,6 +132,33 @@ def construct_tools(
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
to avoid lazy SQL queries after the session may have been flushed."""
with get_session_with_current_tenant_if_none(db_session) as db_session:
return _construct_tools_impl(
persona=persona,
db_session=db_session,
emitter=emitter,
user=user,
llm=llm,
search_tool_config=search_tool_config,
custom_tool_config=custom_tool_config,
file_reader_tool_config=file_reader_tool_config,
allowed_tool_ids=allowed_tool_ids,
search_usage_forcing_setting=search_usage_forcing_setting,
)
def _construct_tools_impl(
persona: Persona,
db_session: Session,
emitter: Emitter,
user: User,
llm: LLM,
search_tool_config: SearchToolConfig | None = None,
custom_tool_config: CustomToolConfig | None = None,
file_reader_tool_config: FileReaderToolConfig | None = None,
allowed_tool_ids: list[int] | None = None,
search_usage_forcing_setting: SearchToolUsage = SearchToolUsage.AUTO,
) -> dict[int, list[Tool]]:
tool_dict: dict[int, list[Tool]] = {}
# Log which tools are attached to the persona for debugging

View File
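
The wrapper relies on get_session_with_current_tenant_if_none which, as the name suggests, reuses the caller's session when one is provided and otherwise opens a fresh tenant-scoped session for the duration of the call. A generic sketch of that contextmanager pattern (not the Onyx implementation; session_factory is a placeholder):

from collections.abc import Iterator
from contextlib import contextmanager
from typing import Callable

from sqlalchemy.orm import Session


@contextmanager
def session_if_none(
    existing: Session | None, session_factory: Callable[[], Session]
) -> Iterator[Session]:
    # Reuse a provided session untouched; otherwise create one and close it
    # when the block exits.
    if existing is not None:
        yield existing
        return
    session = session_factory()
    try:
        yield session
    finally:
        session.close()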

@@ -26,7 +26,7 @@ aiolimiter==1.2.1
# via voyageai
aiosignal==1.4.0
# via aiohttp
alembic==1.10.4
alembic==1.18.4
amqp==5.3.1
# via kombu
annotated-doc==0.0.4
@@ -299,7 +299,7 @@ h11==0.16.0
# uvicorn
h2==4.3.0
# via httpx
hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
# via huggingface-hub
hpack==4.1.0
# via h2
@@ -322,6 +322,7 @@ httpx==0.28.1
# fastmcp
# google-genai
# httpx-oauth
# huggingface-hub
# langfuse
# langsmith
# litellm
@@ -334,7 +335,7 @@ httpx-sse==0.4.3
# cohere
# mcp
hubspot-api-client==11.1.0
huggingface-hub==0.35.3
huggingface-hub==1.10.2
# via tokenizers
humanfriendly==10.0
# via coloredlogs
@@ -408,7 +409,7 @@ kombu==5.5.4
# via celery
kubernetes==31.0.0
# via onyx
langchain-core==1.2.22
langchain-core==1.2.28
langdetect==1.0.9
# via unstructured
langfuse==3.10.0
@@ -589,7 +590,7 @@ platformdirs==4.5.0
# via
# fastmcp
# zeep
playwright==1.55.0
playwright==1.58.0
# via pytest-playwright
pluggy==1.6.0
# via pytest
@@ -784,7 +785,6 @@ requests==2.33.0
# google-api-core
# google-genai
# hubspot-api-client
# huggingface-hub
# jira
# jsonschema-path
# kubernetes
@@ -911,7 +911,7 @@ tiktoken==0.7.0
timeago==1.0.16
tld==0.13.1
# via courlan
tokenizers==0.21.4
tokenizers==0.22.2
# via
# chonkie
# cohere
@@ -933,7 +933,9 @@ tqdm==4.67.1
# unstructured
trafilatura==1.12.2
typer==0.20.0
# via mcp
# via
# huggingface-hub
# mcp
types-awscrt==0.28.4
# via botocore-stubs
types-openpyxl==3.0.4.7

View File

@@ -22,7 +22,7 @@ aiolimiter==1.2.1
# via voyageai
aiosignal==1.4.0
# via aiohttp
alembic==1.10.4
alembic==1.18.4
# via pytest-alembic
annotated-doc==0.0.4
# via fastapi
@@ -82,6 +82,7 @@ click==8.3.1
# via
# black
# litellm
# typer
# uvicorn
cohere==5.6.1
# via onyx
@@ -153,7 +154,7 @@ h11==0.16.0
# httpcore
# uvicorn
hatchling==1.28.0
hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
# via huggingface-hub
httpcore==1.0.9
# via httpx
@@ -161,6 +162,7 @@ httpx==0.28.1
# via
# cohere
# google-genai
# huggingface-hub
# litellm
# mcp
# openai
@@ -168,7 +170,7 @@ httpx-sse==0.4.3
# via
# cohere
# mcp
huggingface-hub==0.35.3
huggingface-hub==1.10.2
# via tokenizers
identify==2.6.15
# via pre-commit
@@ -219,6 +221,8 @@ litellm==1.81.6
mako==1.2.4
# via alembic
manygo==0.2.0
markdown-it-py==4.0.0
# via rich
markupsafe==3.0.3
# via
# jinja2
@@ -230,6 +234,8 @@ matplotlib-inline==0.2.1
# ipython
mcp==1.26.0
# via claude-agent-sdk
mdurl==0.1.2
# via markdown-it-py
multidict==6.7.0
# via
# aiobotocore
@@ -340,6 +346,7 @@ pygments==2.20.0
# ipython
# ipython-pygments-lexers
# pytest
# rich
pyjwt==2.12.0
# via mcp
pyparsing==3.2.5
@@ -395,7 +402,6 @@ requests==2.33.0
# via
# cohere
# google-genai
# huggingface-hub
# kubernetes
# requests-oauthlib
# tiktoken
@@ -404,6 +410,8 @@ requests-oauthlib==1.3.1
# via kubernetes
retry==0.9.2
# via onyx
rich==14.2.0
# via typer
rpds-py==0.29.0
# via
# jsonschema
@@ -415,6 +423,8 @@ s3transfer==0.13.1
# via boto3
sentry-sdk==2.14.0
# via onyx
shellingham==1.5.4
# via typer
six==1.17.0
# via
# kubernetes
@@ -442,7 +452,7 @@ tenacity==9.1.2
# voyageai
tiktoken==0.7.0
# via litellm
tokenizers==0.21.4
tokenizers==0.22.2
# via
# cohere
# litellm
@@ -463,6 +473,8 @@ traitlets==5.14.3
# matplotlib-inline
trove-classifiers==2025.12.1.14
# via hatchling
typer==0.20.0
# via huggingface-hub
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
# via types-beautifulsoup4
@@ -500,6 +512,7 @@ typing-extensions==4.15.0
# referencing
# sqlalchemy
# starlette
# typer
# typing-inspection
typing-inspection==0.4.2
# via

View File

@@ -69,6 +69,7 @@ claude-agent-sdk==0.1.19
click==8.3.1
# via
# litellm
# typer
# uvicorn
cohere==5.6.1
# via onyx
@@ -112,7 +113,7 @@ h11==0.16.0
# via
# httpcore
# uvicorn
hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
# via huggingface-hub
httpcore==1.0.9
# via httpx
@@ -120,6 +121,7 @@ httpx==0.28.1
# via
# cohere
# google-genai
# huggingface-hub
# litellm
# mcp
# openai
@@ -127,7 +129,7 @@ httpx-sse==0.4.3
# via
# cohere
# mcp
huggingface-hub==0.35.3
huggingface-hub==1.10.2
# via tokenizers
idna==3.11
# via
@@ -156,10 +158,14 @@ kubernetes==31.0.0
# via onyx
litellm==1.81.6
# via onyx
markdown-it-py==4.0.0
# via rich
markupsafe==3.0.3
# via jinja2
mcp==1.26.0
# via claude-agent-sdk
mdurl==0.1.2
# via markdown-it-py
monotonic==1.6
# via posthog
multidict==6.7.0
@@ -217,6 +223,8 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pygments==2.20.0
# via rich
pyjwt==2.12.0
# via mcp
python-dateutil==2.8.2
@@ -247,7 +255,6 @@ requests==2.33.0
# via
# cohere
# google-genai
# huggingface-hub
# kubernetes
# posthog
# requests-oauthlib
@@ -257,6 +264,8 @@ requests-oauthlib==1.3.1
# via kubernetes
retry==0.9.2
# via onyx
rich==14.2.0
# via typer
rpds-py==0.29.0
# via
# jsonschema
@@ -267,6 +276,8 @@ s3transfer==0.13.1
# via boto3
sentry-sdk==2.14.0
# via onyx
shellingham==1.5.4
# via typer
six==1.17.0
# via
# kubernetes
@@ -289,7 +300,7 @@ tenacity==9.1.2
# voyageai
tiktoken==0.7.0
# via litellm
tokenizers==0.21.4
tokenizers==0.22.2
# via
# cohere
# litellm
@@ -297,6 +308,8 @@ tqdm==4.67.1
# via
# huggingface-hub
# openai
typer==0.20.0
# via huggingface-hub
types-requests==2.32.0.20250328
# via cohere
typing-extensions==4.15.0
@@ -313,6 +326,7 @@ typing-extensions==4.15.0
# pydantic-core
# referencing
# starlette
# typer
# typing-inspection
typing-inspection==0.4.2
# via

View File

@@ -78,6 +78,7 @@ click==8.3.1
# click-plugins
# click-repl
# litellm
# typer
# uvicorn
click-didyoumean==0.3.1
# via celery
@@ -116,7 +117,6 @@ filelock==3.20.3
# via
# huggingface-hub
# torch
# transformers
frozenlist==1.8.0
# via
# aiohttp
@@ -135,7 +135,7 @@ h11==0.16.0
# via
# httpcore
# uvicorn
hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
# via huggingface-hub
httpcore==1.0.9
# via httpx
@@ -143,6 +143,7 @@ httpx==0.28.1
# via
# cohere
# google-genai
# huggingface-hub
# litellm
# mcp
# openai
@@ -150,7 +151,7 @@ httpx-sse==0.4.3
# via
# cohere
# mcp
huggingface-hub==0.35.3
huggingface-hub==1.10.2
# via
# accelerate
# sentence-transformers
@@ -189,10 +190,14 @@ kubernetes==31.0.0
# via onyx
litellm==1.81.6
# via onyx
markdown-it-py==4.0.0
# via rich
markupsafe==3.0.3
# via jinja2
mcp==1.26.0
# via claude-agent-sdk
mdurl==0.1.2
# via markdown-it-py
mpmath==1.3.0
# via sympy
multidict==6.7.0
@@ -207,6 +212,7 @@ numpy==2.4.1
# accelerate
# scikit-learn
# scipy
# sentence-transformers
# transformers
# voyageai
nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
@@ -264,8 +270,6 @@ packaging==24.2
# transformers
parameterized==0.9.0
# via cohere
pillow==12.2.0
# via sentence-transformers
prometheus-client==0.23.1
# via
# onyx
@@ -305,6 +309,8 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pygments==2.20.0
# via rich
pyjwt==2.12.0
# via mcp
python-dateutil==2.8.2
@@ -339,16 +345,16 @@ requests==2.33.0
# via
# cohere
# google-genai
# huggingface-hub
# kubernetes
# requests-oauthlib
# tiktoken
# transformers
# voyageai
requests-oauthlib==1.3.1
# via kubernetes
retry==0.9.2
# via onyx
rich==14.2.0
# via typer
rpds-py==0.29.0
# via
# jsonschema
@@ -367,11 +373,13 @@ scipy==1.16.3
# via
# scikit-learn
# sentence-transformers
sentence-transformers==4.0.2
sentence-transformers==5.4.1
sentry-sdk==2.14.0
# via onyx
setuptools==80.9.0 ; python_full_version >= '3.12'
# via torch
shellingham==1.5.4
# via typer
six==1.17.0
# via
# kubernetes
@@ -398,7 +406,7 @@ threadpoolctl==3.6.0
# via scikit-learn
tiktoken==0.7.0
# via litellm
tokenizers==0.21.4
tokenizers==0.22.2
# via
# cohere
# litellm
@@ -413,10 +421,14 @@ tqdm==4.67.1
# openai
# sentence-transformers
# transformers
transformers==4.53.0
transformers==5.5.4
# via sentence-transformers
triton==3.5.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
typer==0.20.0
# via
# huggingface-hub
# transformers
types-requests==2.32.0.20250328
# via cohere
typing-extensions==4.15.0
@@ -435,6 +447,7 @@ typing-extensions==4.15.0
# sentence-transformers
# starlette
# torch
# typer
# typing-inspection
typing-inspection==0.4.2
# via

View File

@@ -4,6 +4,7 @@ from unittest.mock import patch
from urllib.parse import urlparse
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from tests.daily.connectors.google_drive.consts_and_utils import _pick
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
@@ -699,3 +700,43 @@ def test_specific_user_email_shared_with_me(
doc_titles = set(doc.semantic_identifier for doc in output.documents)
assert doc_titles == set(expected)
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_slim_retrieval_does_not_call_permissions_list(
mock_get_api_key: MagicMock, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""retrieve_all_slim_docs() must not call permissions().list for any file.
Pruning only needs file IDs — fetching permissions per file causes O(N) API
calls that time out for tenants with large numbers of externally-owned files.
"""
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
include_my_drives=True,
include_files_shared_with_me=False,
shared_folder_urls=None,
shared_drive_urls=None,
my_drive_emails=None,
)
with patch(
"onyx.connectors.google_drive.connector.execute_paginated_retrieval",
wraps=execute_paginated_retrieval,
) as mock_paginated:
for batch in connector.retrieve_all_slim_docs():
pass
permissions_calls = [
c
for c in mock_paginated.call_args_list
if "permissions" in str(c.kwargs.get("retrieval_function", ""))
]
assert (
len(permissions_calls) == 0
), f"permissions().list was called {len(permissions_calls)} time(s) during pruning"

View File

@@ -12,6 +12,7 @@ from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
_ITERATION_LIMIT = 100_000
@@ -141,13 +142,15 @@ def load_all_from_connector(
def to_sections(
documents: list[Document],
) -> Iterator[TextSection | ImageSection]:
) -> Iterator[TextSection | ImageSection | TabularSection]:
for doc in documents:
for section in doc.sections:
yield section
def to_text_sections(sections: Iterator[TextSection | ImageSection]) -> Iterator[str]:
def to_text_sections(
sections: Iterator[TextSection | ImageSection | TabularSection],
) -> Iterator[str]:
for section in sections:
if isinstance(section, TextSection):
yield section.text

View File

@@ -0,0 +1,159 @@
"""Tests for the tenant work-gating Redis helpers.
Requires a running Redis instance. Run with::
python -m dotenv -f .vscode/.env run -- pytest \
backend/tests/external_dependency_unit/tenant_work_gating/test_tenant_work_gating.py
"""
import time
from collections.abc import Generator
from unittest.mock import patch
import pytest
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.redis import redis_tenant_work_gating as twg
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_tenant_work_gating import _SET_KEY
from onyx.redis.redis_tenant_work_gating import cleanup_expired
from onyx.redis.redis_tenant_work_gating import get_active_tenants
from onyx.redis.redis_tenant_work_gating import mark_tenant_active
@pytest.fixture(autouse=True)
def _multi_tenant_true() -> Generator[None, None, None]:
"""Force MULTI_TENANT=True for the helper module so public functions are
not no-ops during tests."""
with patch.object(twg, "MULTI_TENANT", True):
yield
@pytest.fixture(autouse=True)
def _clean_set() -> Generator[None, None, None]:
"""Clear the active_tenants sorted set before and after each test."""
client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
client.delete(_SET_KEY)
yield
client.delete(_SET_KEY)
def test_mark_adds_tenant_to_set() -> None:
mark_tenant_active("tenant_a")
assert get_active_tenants(ttl_seconds=60) == {"tenant_a"}
def test_mark_refreshes_timestamp() -> None:
"""ZADD overwrites the score on existing members. Without a refresh,
reading with a TTL that excludes the first write should return empty;
after a second mark_tenant_active at a newer timestamp, the same TTL
read should include the tenant. Pins `_now_ms` so the test is
deterministic."""
base_ms = int(time.time() * 1000)
# First write at t=0.
with patch.object(twg, "_now_ms", return_value=base_ms):
mark_tenant_active("tenant_a")
# Read 5s later with a 1s TTL — first write is outside the window.
with patch.object(twg, "_now_ms", return_value=base_ms + 5000):
assert get_active_tenants(ttl_seconds=1) == set()
# Refresh at t=5s.
with patch.object(twg, "_now_ms", return_value=base_ms + 5000):
mark_tenant_active("tenant_a")
# Read at t=5s with a 1s TTL — refreshed write is inside the window.
with patch.object(twg, "_now_ms", return_value=base_ms + 5000):
assert get_active_tenants(ttl_seconds=1) == {"tenant_a"}
def test_get_active_tenants_filters_by_ttl() -> None:
"""Tenant marked in the past, read with a TTL short enough to exclude it."""
# Pin _now_ms so the write happens at t=0 and the read cutoff is
# well after that.
base_ms = int(time.time() * 1000)
with patch.object(twg, "_now_ms", return_value=base_ms):
mark_tenant_active("tenant_old")
# Read 5 seconds later with a 1-second TTL — tenant_old is outside.
with patch.object(twg, "_now_ms", return_value=base_ms + 5000):
assert get_active_tenants(ttl_seconds=1) == set()
# Read 5 seconds later with a 10-second TTL — tenant_old is inside.
with patch.object(twg, "_now_ms", return_value=base_ms + 5000):
assert get_active_tenants(ttl_seconds=10) == {"tenant_old"}
def test_get_active_tenants_multiple_members() -> None:
mark_tenant_active("tenant_a")
mark_tenant_active("tenant_b")
mark_tenant_active("tenant_c")
assert get_active_tenants(ttl_seconds=60) == {"tenant_a", "tenant_b", "tenant_c"}
def test_get_active_tenants_empty_set() -> None:
"""Genuinely-empty set returns an empty set (not None)."""
assert get_active_tenants(ttl_seconds=60) == set()
def test_get_active_tenants_returns_none_on_redis_error() -> None:
"""Callers need to distinguish Redis failure from "no tenants active" so
they can fail open. Simulate failure by patching the client to raise."""
from unittest.mock import MagicMock
failing_client = MagicMock()
failing_client.zrangebyscore.side_effect = RuntimeError("simulated outage")
with patch.object(twg, "_client", return_value=failing_client):
assert get_active_tenants(ttl_seconds=60) is None
def test_get_active_tenants_returns_none_in_single_tenant_mode() -> None:
"""Single-tenant mode returns None so callers can skip the gate entirely
(same fail-open handling as Redis unavailability)."""
with patch.object(twg, "MULTI_TENANT", False):
assert get_active_tenants(ttl_seconds=60) is None
def test_cleanup_expired_removes_only_stale_members() -> None:
"""Seed one stale and one fresh member directly; cleanup should drop only
the stale one."""
now_ms = int(time.time() * 1000)
client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
client.zadd(_SET_KEY, mapping={"tenant_old": now_ms - 10 * 60 * 1000})
client.zadd(_SET_KEY, mapping={"tenant_new": now_ms})
removed = cleanup_expired(ttl_seconds=60)
assert removed == 1
assert get_active_tenants(ttl_seconds=60 * 60) == {"tenant_new"}
def test_cleanup_expired_empty_set_noop() -> None:
assert cleanup_expired(ttl_seconds=60) == 0
def test_noop_when_multi_tenant_false() -> None:
with patch.object(twg, "MULTI_TENANT", False):
mark_tenant_active("tenant_a")
assert get_active_tenants(ttl_seconds=60) is None
assert cleanup_expired(ttl_seconds=60) == 0
# Verify nothing was written while MULTI_TENANT was False.
assert get_active_tenants(ttl_seconds=60) == set()
def test_rendered_key_is_cloud_prefixed() -> None:
"""Exercises TenantRedis auto-prefixing on sorted-set ops. The rendered
Redis key should be `cloud:active_tenants`, not bare `active_tenants`."""
mark_tenant_active("tenant_a")
from onyx.redis.redis_pool import RedisPool
raw = RedisPool().get_raw_client()
assert raw.zscore("cloud:active_tenants", "tenant_a") is not None
assert raw.zscore("active_tenants", "tenant_a") is None

View File

@@ -38,38 +38,41 @@ class TestAddMemory:
def test_add_memory_creates_row(self, db_session: Session, test_user: User) -> None:
"""Verify that add_memory inserts a new Memory row."""
user_id = test_user.id
memory = add_memory(
memory_id = add_memory(
user_id=user_id,
memory_text="User prefers dark mode",
db_session=db_session,
)
assert memory.id is not None
assert memory.user_id == user_id
assert memory.memory_text == "User prefers dark mode"
assert memory_id is not None
# Verify it persists
fetched = db_session.get(Memory, memory.id)
fetched = db_session.get(Memory, memory_id)
assert fetched is not None
assert fetched.user_id == user_id
assert fetched.memory_text == "User prefers dark mode"
def test_add_multiple_memories(self, db_session: Session, test_user: User) -> None:
"""Verify that multiple memories can be added for the same user."""
user_id = test_user.id
m1 = add_memory(
m1_id = add_memory(
user_id=user_id,
memory_text="Favorite color is blue",
db_session=db_session,
)
m2 = add_memory(
m2_id = add_memory(
user_id=user_id,
memory_text="Works in engineering",
db_session=db_session,
)
assert m1.id != m2.id
assert m1.memory_text == "Favorite color is blue"
assert m2.memory_text == "Works in engineering"
assert m1_id != m2_id
fetched_m1 = db_session.get(Memory, m1_id)
fetched_m2 = db_session.get(Memory, m2_id)
assert fetched_m1 is not None
assert fetched_m2 is not None
assert fetched_m1.memory_text == "Favorite color is blue"
assert fetched_m2.memory_text == "Works in engineering"
class TestUpdateMemoryAtIndex:
@@ -82,15 +85,17 @@ class TestUpdateMemoryAtIndex:
add_memory(user_id=user_id, memory_text="Memory 1", db_session=db_session)
add_memory(user_id=user_id, memory_text="Memory 2", db_session=db_session)
updated = update_memory_at_index(
updated_id = update_memory_at_index(
user_id=user_id,
index=1,
new_text="Updated Memory 1",
db_session=db_session,
)
assert updated is not None
assert updated.memory_text == "Updated Memory 1"
assert updated_id is not None
fetched = db_session.get(Memory, updated_id)
assert fetched is not None
assert fetched.memory_text == "Updated Memory 1"
def test_update_memory_at_out_of_range_index(
self, db_session: Session, test_user: User
@@ -167,7 +172,7 @@ class TestMemoryCap:
assert len(rows_before) == MAX_MEMORIES_PER_USER
# Add one more — should evict the oldest
new_memory = add_memory(
new_memory_id = add_memory(
user_id=user_id,
memory_text="New memory after cap",
db_session=db_session,
@@ -181,7 +186,7 @@ class TestMemoryCap:
# Oldest ("Memory 0") should be gone; "Memory 1" is now the oldest
assert rows_after[0].memory_text == "Memory 1"
# Newest should be the one we just added
assert rows_after[-1].id == new_memory.id
assert rows_after[-1].id == new_memory_id
assert rows_after[-1].memory_text == "New memory after cap"
@@ -221,22 +226,26 @@ class TestGetMemoriesWithUserId:
user_id = test_user_no_memories.id
# Add a memory
memory = add_memory(
memory_id = add_memory(
user_id=user_id,
memory_text="Memory with use_memories off",
db_session=db_session,
)
assert memory.memory_text == "Memory with use_memories off"
fetched = db_session.get(Memory, memory_id)
assert fetched is not None
assert fetched.memory_text == "Memory with use_memories off"
# Update that memory
updated = update_memory_at_index(
updated_id = update_memory_at_index(
user_id=user_id,
index=0,
new_text="Updated memory with use_memories off",
db_session=db_session,
)
assert updated is not None
assert updated.memory_text == "Updated memory with use_memories off"
assert updated_id is not None
fetched_updated = db_session.get(Memory, updated_id)
assert fetched_updated is not None
assert fetched_updated.memory_text == "Updated memory with use_memories off"
# Verify get_memories returns the updated memory
context = get_memories(test_user_no_memories, db_session)

View File

@@ -12,7 +12,7 @@ from onyx.db.models import DocumentByConnectorCredentialPair
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import DATestAPIKey
from tests.integration.common_utils.managers.cc_pair import DATestCCPair
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import SimpleTestDocument
from tests.integration.common_utils.vespa import vespa_fixture

View File

@@ -62,6 +62,7 @@ class DocumentSetManager:
) -> bool:
doc_set_update_request = {
"id": document_set.id,
"name": document_set.name,
"description": document_set.description,
"cc_pair_ids": document_set.cc_pair_ids,
"is_public": document_set.is_public,

View File

@@ -14,7 +14,6 @@ from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.image_generation import (
ImageGenerationConfigManager,
@@ -196,6 +195,9 @@ def image_generation_config(
@pytest.fixture
def document_builder(admin_user: DATestUser) -> DocumentBuilderType:
# HACK: Avoid importing generated OpenAPI client modules unless this fixture is used.
from tests.integration.common_utils.managers.cc_pair import CCPairManager
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)

View File

@@ -1,4 +1,4 @@
FROM python:3.11.7-slim-bookworm
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47
WORKDIR /app

View File

@@ -108,12 +108,12 @@ def current_head_rev() -> str:
["alembic", "heads", "--resolve-dependencies"],
cwd=_BACKEND_DIR,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stderr=subprocess.PIPE,
text=True,
)
assert (
result.returncode == 0
), f"alembic heads failed (exit {result.returncode}):\n{result.stdout}"
), f"alembic heads failed (exit {result.returncode}):\n{result.stdout}\n{result.stderr}"
# Output looks like "d5c86e2c6dc6 (head)\n"
rev = result.stdout.strip().split()[0]
assert len(rev) > 0

View File

@@ -1,3 +1,5 @@
from uuid import uuid4
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import APIKeyManager
@@ -159,3 +161,58 @@ def test_removing_connector(
doc_set_names=[],
doc_creating_user=admin_user,
)
def test_renaming_document_set(
reset: None, # noqa: ARG001
vespa_client: vespa_fixture,
) -> None:
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
cc_pair = CCPairManager.create_from_scratch(
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
cc_pair.documents = DocumentManager.seed_dummy_docs(
cc_pair=cc_pair,
num_docs=NUM_DOCS,
api_key=api_key,
)
original_name = f"original_doc_set_{uuid4()}"
doc_set = DocumentSetManager.create(
name=original_name,
cc_pair_ids=[cc_pair.id],
user_performing_action=admin_user,
)
DocumentSetManager.wait_for_sync(user_performing_action=admin_user)
DocumentSetManager.verify(
document_set=doc_set,
user_performing_action=admin_user,
)
new_name = f"renamed_doc_set_{uuid4()}"
doc_set.name = new_name
DocumentSetManager.edit(
doc_set,
user_performing_action=admin_user,
)
DocumentSetManager.wait_for_sync(user_performing_action=admin_user)
DocumentSetManager.verify(
document_set=doc_set,
user_performing_action=admin_user,
)
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair,
doc_set_names=[new_name],
doc_creating_user=admin_user,
)

View File

@@ -0,0 +1,83 @@
"""
Integration tests verifying the knowledge_sources field on MinimalPersonaSnapshot.
The GET /persona endpoint returns MinimalPersonaSnapshot, which includes a
knowledge_sources list derived from the persona's document sets, hierarchy
nodes, attached documents, and user files. These tests verify that the
field is populated correctly.
"""
import requests
from onyx.configs.constants import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.file import FileManager
from tests.integration.common_utils.managers.persona import PersonaManager
from tests.integration.common_utils.test_file_utils import create_test_text_file
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
def _get_minimal_persona(
persona_id: int,
user: DATestUser,
) -> dict:
"""Fetch personas from the list endpoint and find the one with the given id."""
response = requests.get(
f"{API_SERVER_URL}/persona",
params={"persona_ids": persona_id},
headers=user.headers,
)
response.raise_for_status()
personas = response.json()
matches = [p for p in personas if p["id"] == persona_id]
assert (
len(matches) == 1
), f"Expected 1 persona with id={persona_id}, got {len(matches)}"
return matches[0]
def test_persona_with_user_files_includes_user_file_source(
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
) -> None:
"""When a persona has user files attached, knowledge_sources includes 'user_file'."""
text_file = create_test_text_file("test content for knowledge source verification")
file_descriptors, error = FileManager.upload_files(
files=[("test_ks.txt", text_file)],
user_performing_action=admin_user,
)
assert not error, f"File upload failed: {error}"
user_file_id = file_descriptors[0]["user_file_id"] or ""
persona = PersonaManager.create(
user_performing_action=admin_user,
name="KS User File Agent",
description="Agent with user files for knowledge_sources test",
system_prompt="You are a helpful assistant.",
user_file_ids=[user_file_id],
)
minimal = _get_minimal_persona(persona.id, admin_user)
assert (
DocumentSource.USER_FILE.value in minimal["knowledge_sources"]
), f"Expected 'user_file' in knowledge_sources, got: {minimal['knowledge_sources']}"
def test_persona_without_user_files_excludes_user_file_source(
admin_user: DATestUser,
llm_provider: DATestLLMProvider, # noqa: ARG001
) -> None:
"""When a persona has no user files, knowledge_sources should not contain 'user_file'."""
persona = PersonaManager.create(
user_performing_action=admin_user,
name="KS No Files Agent",
description="Agent without files for knowledge_sources test",
system_prompt="You are a helpful assistant.",
)
minimal = _get_minimal_persona(persona.id, admin_user)
assert (
DocumentSource.USER_FILE.value not in minimal["knowledge_sources"]
), f"Unexpected 'user_file' in knowledge_sources: {minimal['knowledge_sources']}"

View File

@@ -301,7 +301,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
@@ -332,7 +331,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
@@ -363,7 +361,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
@@ -391,7 +388,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
@@ -423,7 +419,6 @@ class TestRunModels:
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
@@ -456,7 +451,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
@@ -497,7 +491,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
@@ -519,7 +512,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
@@ -542,7 +534,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
@@ -596,7 +587,6 @@ class TestRunModels:
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle",
side_effect=lambda *_, **__: completion_called.set(),
@@ -653,7 +643,6 @@ class TestRunModels:
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle",
side_effect=lambda *_, **__: completion_called.set(),
@@ -706,7 +695,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
@@ -736,7 +724,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",

View File

@@ -0,0 +1,88 @@
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
from sentry_sdk.types import Event
import onyx.configs.sentry as sentry_module
from onyx.configs.sentry import _add_instance_tags
def _event(data: dict) -> Event:
"""Helper to create a Sentry Event from a plain dict for testing."""
return cast(Event, data)
def _reset_state() -> None:
"""Reset the module-level resolved flag between tests."""
sentry_module._instance_id_resolved = False
class TestAddInstanceTags:
def setup_method(self) -> None:
_reset_state()
@patch("onyx.utils.telemetry.get_or_generate_uuid", return_value="test-uuid-1234")
@patch("sentry_sdk.set_tag")
def test_first_event_sets_instance_id(
self, mock_set_tag: MagicMock, mock_uuid: MagicMock
) -> None:
result = _add_instance_tags(_event({"message": "test error"}), {})
assert result is not None
assert result["tags"]["instance_id"] == "test-uuid-1234"
mock_set_tag.assert_called_once_with("instance_id", "test-uuid-1234")
mock_uuid.assert_called_once()
@patch("onyx.utils.telemetry.get_or_generate_uuid", return_value="test-uuid-1234")
@patch("sentry_sdk.set_tag")
def test_second_event_skips_resolution(
self, _mock_set_tag: MagicMock, mock_uuid: MagicMock
) -> None:
_add_instance_tags(_event({"message": "first"}), {})
result = _add_instance_tags(_event({"message": "second"}), {})
assert result is not None
assert "tags" not in result # second event not modified
mock_uuid.assert_called_once() # only resolved once
@patch(
"onyx.utils.telemetry.get_or_generate_uuid",
side_effect=Exception("DB unavailable"),
)
@patch("sentry_sdk.set_tag")
def test_resolution_failure_still_returns_event(
self, _mock_set_tag: MagicMock, _mock_uuid: MagicMock
) -> None:
result = _add_instance_tags(_event({"message": "test error"}), {})
assert result is not None
assert result["message"] == "test error"
assert "tags" not in result or "instance_id" not in result.get("tags", {})
@patch(
"onyx.utils.telemetry.get_or_generate_uuid",
side_effect=Exception("DB unavailable"),
)
@patch("sentry_sdk.set_tag")
def test_resolution_failure_retries_on_next_event(
self, _mock_set_tag: MagicMock, mock_uuid: MagicMock
) -> None:
"""If resolution fails (e.g. DB not ready), retry on the next event."""
_add_instance_tags(_event({"message": "first"}), {})
_add_instance_tags(_event({"message": "second"}), {})
assert mock_uuid.call_count == 2 # retried on second event
@patch("onyx.utils.telemetry.get_or_generate_uuid", return_value="test-uuid-1234")
@patch("sentry_sdk.set_tag")
def test_preserves_existing_tags(
self, _mock_set_tag: MagicMock, _mock_uuid: MagicMock
) -> None:
result = _add_instance_tags(
_event({"message": "test", "tags": {"existing": "tag"}}), {}
)
assert result is not None
assert result["tags"]["existing"] == "tag"
assert result["tags"]["instance_id"] == "test-uuid-1234"

View File

@@ -8,14 +8,23 @@ from unittest.mock import patch
import pytest
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.connectors.canvas.connector import _in_time_window
from onyx.connectors.canvas.connector import _parse_canvas_dt
from onyx.connectors.canvas.connector import _unix_to_canvas_time
from onyx.connectors.canvas.connector import CanvasConnector
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.canvas.connector import CanvasConnectorCheckpoint
from onyx.connectors.canvas.connector import CanvasStage
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
# ---------------------------------------------------------------------------
@@ -111,6 +120,56 @@ def _mock_response(
return resp
def _make_url_dispatcher(
courses: list[dict[str, Any]] | None = None,
pages: list[dict[str, Any]] | None = None,
assignments: list[dict[str, Any]] | None = None,
announcements: list[dict[str, Any]] | None = None,
page_error: bool = False,
) -> Any:
"""Return a callable that dispatches mock responses based on the request URL.
Meant to be assigned to ``mock_requests.get.side_effect``.
"""
api_prefix = f"{FAKE_BASE_URL}/api/v1"
def _dispatcher(url: str, **_kwargs: Any) -> MagicMock:
if page_error:
return _mock_response(500, {})
if url == f"{api_prefix}/courses":
return _mock_response(json_data=courses or [])
if "/pages" in url:
return _mock_response(json_data=pages or [])
if "/assignments" in url:
return _mock_response(json_data=assignments or [])
if "announcements" in url:
return _mock_response(json_data=announcements or [])
return _mock_response(json_data=[])
return _dispatcher
def _run_checkpoint(
connector: CanvasConnector,
checkpoint: CanvasConnectorCheckpoint,
start: float = 0.0,
end: float = datetime(2099, 1, 1, tzinfo=timezone.utc).timestamp(),
) -> tuple[
list[Document | HierarchyNode | ConnectorFailure], CanvasConnectorCheckpoint
]:
"""Run load_from_checkpoint once and collect yielded items + returned checkpoint."""
gen = connector.load_from_checkpoint(start, end, checkpoint)
items: list[Document | HierarchyNode | ConnectorFailure] = []
new_checkpoint: CanvasConnectorCheckpoint | None = None
try:
while True:
items.append(next(gen))
except StopIteration as e:
new_checkpoint = e.value
assert new_checkpoint is not None
return items, new_checkpoint
# ---------------------------------------------------------------------------
# CanvasApiClient.__init__ tests
# ---------------------------------------------------------------------------
@@ -269,15 +328,6 @@ class TestGet:
assert exc_info.value.status_code == 404
@patch("onyx.connectors.canvas.client.rl_requests")
def test_raises_on_429(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(429, {})
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
assert exc_info.value.status_code == 429
@patch("onyx.connectors.canvas.client.rl_requests")
def test_skips_params_when_using_full_url(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
@@ -454,6 +504,149 @@ class TestPaginate:
assert pages == []
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_extracts_message_from_error_dict(
self, mock_requests: MagicMock
) -> None:
"""Shape 1: {"error": {"message": "Not authorized"}}"""
mock_requests.get.return_value = _mock_response(
403, {"error": {"message": "Not authorized"}}
)
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
with pytest.raises(OnyxError) as exc_info:
client.get("courses")
result = exc_info.value.detail
expected = "Not authorized"
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_extracts_message_from_error_string(
self, mock_requests: MagicMock
) -> None:
"""Shape 2: {"error": "Invalid access token"}"""
mock_requests.get.return_value = _mock_response(
401, {"error": "Invalid access token"}
)
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
with pytest.raises(OnyxError) as exc_info:
client.get("courses")
result = exc_info.value.detail
expected = "Invalid access token"
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_extracts_message_from_errors_list(
self, mock_requests: MagicMock
) -> None:
"""Shape 3: {"errors": [{"message": "Invalid query"}]}"""
mock_requests.get.return_value = _mock_response(
400, {"errors": [{"message": "Invalid query"}]}
)
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
with pytest.raises(OnyxError) as exc_info:
client.get("courses")
result = exc_info.value.detail
expected = "Invalid query"
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_dict_takes_priority_over_errors_list(
self, mock_requests: MagicMock
) -> None:
"""When both error shapes are present, error dict wins."""
mock_requests.get.return_value = _mock_response(
403, {"error": "Specific error", "errors": [{"message": "Generic"}]}
)
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
with pytest.raises(OnyxError) as exc_info:
client.get("courses")
result = exc_info.value.detail
expected = "Specific error"
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_falls_back_to_reason_when_no_json_message(
self, mock_requests: MagicMock
) -> None:
"""Empty error body falls back to response.reason."""
mock_requests.get.return_value = _mock_response(500, {})
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
with pytest.raises(OnyxError) as exc_info:
client.get("courses")
result = exc_info.value.detail
expected = "Error" # from _mock_response's reason for >= 300
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_invalid_json_on_success_raises(self, mock_requests: MagicMock) -> None:
"""Invalid JSON on a 2xx response raises OnyxError."""
resp = MagicMock()
resp.status_code = 200
resp.json.side_effect = ValueError("No JSON")
resp.headers = {"Link": ""}
mock_requests.get.return_value = resp
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
with pytest.raises(OnyxError, match="Invalid JSON"):
client.get("courses")
@patch("onyx.connectors.canvas.client.rl_requests")
def test_invalid_json_on_error_falls_back_to_reason(
self, mock_requests: MagicMock
) -> None:
"""Invalid JSON on a 4xx response falls back to response.reason."""
resp = MagicMock()
resp.status_code = 500
resp.reason = "Internal Server Error"
resp.json.side_effect = ValueError("No JSON")
resp.headers = {"Link": ""}
mock_requests.get.return_value = resp
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
with pytest.raises(OnyxError) as exc_info:
client.get("courses")
result = exc_info.value.detail
expected = "Internal Server Error"
assert result == expected
# ---------------------------------------------------------------------------
# CanvasApiClient._parse_next_link tests
@@ -588,6 +781,16 @@ class TestConnectorUrlNormalization:
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_load_credentials_insufficient_permissions(
self, mock_requests: MagicMock
) -> None:
mock_requests.get.return_value = _mock_response(403, {})
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
with pytest.raises(InsufficientPermissionsError):
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
# ---------------------------------------------------------------------------
# CanvasConnector — document conversion
@@ -766,10 +969,6 @@ class TestValidateConnectorSettings:
def test_validate_insufficient_permissions(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(403, InsufficientPermissionsError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_rate_limited(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(429, ConnectorValidationError, mock_requests)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_validate_unexpected_error(self, mock_requests: MagicMock) -> None:
self._assert_validate_raises(500, UnexpectedValidationError, mock_requests)
@@ -874,3 +1073,652 @@ class TestListAnnouncements:
result = connector._list_announcements(course_id=1)
assert result == []
class TestCheckpoint:
def test_build_dummy_checkpoint(self) -> None:
connector = _build_connector()
cp = connector.build_dummy_checkpoint()
assert cp.has_more is True
assert cp.course_ids == []
assert cp.current_course_index == 0
assert cp.stage == CanvasStage.PAGES
def test_validate_checkpoint_json(self) -> None:
connector = _build_connector()
cp = CanvasConnectorCheckpoint(
has_more=True,
course_ids=[1, 2],
current_course_index=1,
stage=CanvasStage.ASSIGNMENTS,
)
json_str = cp.model_dump_json()
restored = connector.validate_checkpoint_json(json_str)
assert restored.course_ids == [1, 2]
assert restored.current_course_index == 1
assert restored.stage == CanvasStage.ASSIGNMENTS
assert restored.has_more is True
# ---------------------------------------------------------------------------
# load_from_checkpoint tests
# ---------------------------------------------------------------------------
class TestLoadFromCheckpoint:
@patch("onyx.connectors.canvas.client.rl_requests")
def test_first_call_materializes_courses(self, mock_requests: MagicMock) -> None:
"""First call should populate course_ids and yield no documents."""
mock_requests.get.side_effect = _make_url_dispatcher(
courses=[_mock_course(1), _mock_course(2, "Data Structures", "CS201")]
)
connector = _build_connector()
cp = connector.build_dummy_checkpoint()
items, new_cp = _run_checkpoint(connector, cp)
assert items == []
assert new_cp.course_ids == [1, 2]
assert new_cp.current_course_index == 0
assert new_cp.stage == CanvasStage.PAGES
assert new_cp.has_more is True
@patch("onyx.connectors.canvas.client.rl_requests")
def test_processes_pages_stage(self, mock_requests: MagicMock) -> None:
"""Pages stage yields page documents within the time window."""
mock_requests.get.side_effect = _make_url_dispatcher(
pages=[_mock_page(10, "Syllabus", "2025-06-15T12:00:00Z")]
)
connector = _build_connector()
cp = CanvasConnectorCheckpoint(
has_more=True,
course_ids=[1],
current_course_index=0,
stage=CanvasStage.PAGES,
)
start = datetime(2025, 6, 1, 0, 0, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, 0, 0, tzinfo=timezone.utc).timestamp()
items, new_cp = _run_checkpoint(connector, cp, start, end)
expected_count = 1
expected_id = "canvas-page-1-10"
assert len(items) == expected_count
assert isinstance(items[0], Document)
assert items[0].id == expected_id
assert new_cp.stage == CanvasStage.ASSIGNMENTS
@patch("onyx.connectors.canvas.client.rl_requests")
def test_advances_through_all_stages(self, mock_requests: MagicMock) -> None:
"""Calling checkpoint 3 times advances pages -> assignments -> announcements -> next course."""
page = _mock_page(10, updated_at="2025-06-15T12:00:00Z")
assignment = _mock_assignment(20, updated_at="2025-06-15T12:00:00Z")
announcement = _mock_announcement(30, posted_at="2025-06-15T12:00:00Z")
mock_requests.get.side_effect = _make_url_dispatcher(
pages=[page], assignments=[assignment], announcements=[announcement]
)
connector = _build_connector()
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
cp = CanvasConnectorCheckpoint(
has_more=True,
course_ids=[1],
current_course_index=0,
stage=CanvasStage.PAGES,
)
# Stage 1: pages
items1, cp = _run_checkpoint(connector, cp, start, end)
assert cp.stage == CanvasStage.ASSIGNMENTS
assert len(items1) == 1
# Stage 2: assignments
mock_requests.get.side_effect = _make_url_dispatcher(assignments=[assignment])
items2, cp = _run_checkpoint(connector, cp, start, end)
assert cp.stage == CanvasStage.ANNOUNCEMENTS
assert len(items2) == 1
# Stage 3: announcements -> advances course index
mock_requests.get.side_effect = _make_url_dispatcher(
announcements=[announcement]
)
items3, cp = _run_checkpoint(connector, cp, start, end)
assert cp.current_course_index == 1
assert cp.stage == CanvasStage.PAGES
assert cp.has_more is False
@patch("onyx.connectors.canvas.client.rl_requests")
def test_filters_by_time_window(self, mock_requests: MagicMock) -> None:
"""Only documents within (start, end] are yielded."""
old_page = _mock_page(10, updated_at="2025-01-01T00:00:00Z")
new_page = _mock_page(11, title="New Page", updated_at="2025-06-15T12:00:00Z")
mock_requests.get.side_effect = _make_url_dispatcher(pages=[new_page, old_page])
connector = _build_connector()
cp = CanvasConnectorCheckpoint(
has_more=True,
course_ids=[1],
current_course_index=0,
stage=CanvasStage.PAGES,
)
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
items, _ = _run_checkpoint(connector, cp, start, end)
expected_count = 1
expected_id = "canvas-page-1-11"
assert len(items) == expected_count
assert isinstance(items[0], Document)
assert items[0].id == expected_id
@patch("onyx.connectors.canvas.client.rl_requests")
def test_skips_announcement_without_posted_at(
self, mock_requests: MagicMock
) -> None:
announcement = _mock_announcement()
announcement["posted_at"] = None
mock_requests.get.side_effect = _make_url_dispatcher(
announcements=[announcement]
)
connector = _build_connector()
cp = CanvasConnectorCheckpoint(
has_more=True,
course_ids=[1],
current_course_index=0,
stage=CanvasStage.ANNOUNCEMENTS,
)
items, _ = _run_checkpoint(connector, cp)
assert len(items) == 0
def test_stage_failure_advances_stage_and_yields_failure(self) -> None:
"""A 500 on a stage fetch yields a stage-level ConnectorFailure and
advances to the next stage, so the framework doesn't loop on the
same failing state forever."""
connector = _build_connector()
cp = CanvasConnectorCheckpoint(
has_more=True,
course_ids=[1, 2],
current_course_index=0,
stage=CanvasStage.PAGES,
)
with patch.object(
connector,
"_fetch_stage_page",
side_effect=OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"boom",
status_code_override=500,
),
):
items, new_cp = _run_checkpoint(connector, cp)
expected_entity_id = "canvas-pages-1"
assert len(items) == 1
assert isinstance(items[0], ConnectorFailure)
assert items[0].failed_entity is not None
assert items[0].failed_entity.entity_id == expected_entity_id
assert new_cp.stage == CanvasStage.ASSIGNMENTS
assert new_cp.current_course_index == 0
assert new_cp.next_url is None
assert new_cp.has_more is True
def test_course_404_advances_course_and_yields_failure(self) -> None:
"""A 404 on a stage fetch means the whole course is inaccessible —
yield a course-level ConnectorFailure and skip to the next course
instead of burning API calls on every stage of a missing course."""
connector = _build_connector()
cp = CanvasConnectorCheckpoint(
has_more=True,
course_ids=[1, 2],
current_course_index=0,
stage=CanvasStage.PAGES,
)
with patch.object(
connector,
"_fetch_stage_page",
side_effect=OnyxError(
OnyxErrorCode.NOT_FOUND,
"course gone",
status_code_override=404,
),
):
items, new_cp = _run_checkpoint(connector, cp)
expected_entity_id = "canvas-course-1"
expected_next_course_index = 1
assert len(items) == 1
assert isinstance(items[0], ConnectorFailure)
assert items[0].failed_entity is not None
assert items[0].failed_entity.entity_id == expected_entity_id
assert new_cp.current_course_index == expected_next_course_index
assert new_cp.stage == CanvasStage.PAGES
assert new_cp.next_url is None
assert new_cp.has_more is True
def test_fatal_auth_failure_during_stage_fetch_propagates(self) -> None:
connector = _build_connector()
cp = CanvasConnectorCheckpoint(
has_more=True,
course_ids=[1],
current_course_index=0,
stage=CanvasStage.PAGES,
)
with patch("onyx.connectors.canvas.client.rl_requests") as mock_requests:
mock_requests.get.return_value = _mock_response(401, {})
with pytest.raises(CredentialExpiredError):
_run_checkpoint(connector, cp)
def test_security_failure_during_stage_fetch_propagates(self) -> None:
connector = _build_connector()
cp = CanvasConnectorCheckpoint(
has_more=True,
course_ids=[1],
current_course_index=0,
stage=CanvasStage.PAGES,
)
with patch.object(
connector,
"_fetch_stage_page",
side_effect=OnyxError(OnyxErrorCode.BAD_GATEWAY, "bad next link"),
):
with pytest.raises(OnyxError, match="bad next link"):
_run_checkpoint(connector, cp)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_per_document_conversion_failure_yields_connector_failure(
self, mock_requests: MagicMock
) -> None:
"""Bad data for one page yields ConnectorFailure, doesn't stop processing."""
bad_page = {
"page_id": 10,
"url": "test",
"title": "Test",
"body": None,
"created_at": "2025-06-15T12:00:00Z",
"updated_at": "bad-date",
}
mock_requests.get.side_effect = _make_url_dispatcher(pages=[bad_page])
connector = _build_connector()
cp = CanvasConnectorCheckpoint(
has_more=True,
course_ids=[1],
current_course_index=0,
stage=CanvasStage.PAGES,
)
items, new_cp = _run_checkpoint(connector, cp)
assert len(items) == 1
assert isinstance(items[0], ConnectorFailure)
assert new_cp.stage == CanvasStage.ASSIGNMENTS
@patch("onyx.connectors.canvas.client.rl_requests")
def test_all_courses_done_sets_has_more_false(
self, mock_requests: MagicMock
) -> None:
mock_requests.get.side_effect = _make_url_dispatcher()
connector = _build_connector()
cp = CanvasConnectorCheckpoint(
has_more=True, course_ids=[1], current_course_index=1
)
items, new_cp = _run_checkpoint(connector, cp)
assert items == []
assert new_cp.has_more is False
def test_invalid_stage_raises_value_error(self) -> None:
connector = _build_connector()
cp = CanvasConnectorCheckpoint(
has_more=True,
course_ids=[1],
current_course_index=0,
stage=CanvasStage.PAGES,
)
cp.stage = "invalid" # type: ignore[assignment]
with pytest.raises(ValueError, match="Invalid checkpoint stage"):
_run_checkpoint(connector, cp)
# ---------------------------------------------------------------------------
# load_from_checkpoint_with_perm_sync tests
# ---------------------------------------------------------------------------
class TestLoadFromCheckpointWithPermSync:
@patch("onyx.connectors.canvas.connector.get_course_permissions")
@patch("onyx.connectors.canvas.client.rl_requests")
def test_documents_have_external_access(
self, mock_requests: MagicMock, mock_perms: MagicMock
) -> None:
"""load_from_checkpoint_with_perm_sync attaches ExternalAccess to documents."""
expected_access = ExternalAccess(
external_user_emails={"student@school.edu"},
external_user_group_ids=set(),
is_public=False,
)
mock_perms.return_value = expected_access
mock_requests.get.side_effect = _make_url_dispatcher(
pages=[_mock_page(10, "Syllabus", "2025-06-15T12:00:00Z")]
)
connector = _build_connector()
cp = CanvasConnectorCheckpoint(
has_more=True,
course_ids=[1],
current_course_index=0,
stage=CanvasStage.PAGES,
)
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
gen = connector.load_from_checkpoint_with_perm_sync(start, end, cp)
items: list[Document | HierarchyNode | ConnectorFailure] = []
new_cp: CanvasConnectorCheckpoint | None = None
try:
while True:
items.append(next(gen))
except StopIteration as e:
new_cp = e.value
assert new_cp is not None
assert len(items) == 1
assert isinstance(items[0], Document)
assert items[0].external_access == expected_access
assert new_cp.stage == CanvasStage.ASSIGNMENTS
mock_perms.assert_called_once()
# ---------------------------------------------------------------------------
# Helper function tests
# ---------------------------------------------------------------------------
class TestParseCanvasDt:
def test_z_suffix_parsed_as_utc(self) -> None:
result = _parse_canvas_dt("2025-06-15T12:00:00Z")
expected = datetime(2025, 6, 15, 12, 0, 0, tzinfo=timezone.utc)
assert result == expected
def test_plus_offset_parsed_as_utc(self) -> None:
result = _parse_canvas_dt("2025-06-15T12:00:00+00:00")
expected = datetime(2025, 6, 15, 12, 0, 0, tzinfo=timezone.utc)
assert result == expected
def test_result_is_timezone_aware(self) -> None:
result = _parse_canvas_dt("2025-01-01T00:00:00Z")
assert result.tzinfo is not None
class TestUnixToCanvasTime:
def test_known_epoch_produces_expected_string(self) -> None:
epoch = datetime(2025, 6, 15, 12, 0, 0, tzinfo=timezone.utc).timestamp()
result = _unix_to_canvas_time(epoch)
assert result == "2025-06-15T12:00:00Z"
def test_round_trips_with_parse_canvas_dt(self) -> None:
epoch = datetime(2025, 3, 10, 8, 30, 0, tzinfo=timezone.utc).timestamp()
result = _parse_canvas_dt(_unix_to_canvas_time(epoch))
expected = datetime(2025, 3, 10, 8, 30, 0, tzinfo=timezone.utc)
assert result == expected
class TestInTimeWindow:
def test_inside_window(self) -> None:
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
result = _in_time_window("2025-06-15T12:00:00Z", start, end)
assert result is True
def test_before_window(self) -> None:
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
result = _in_time_window("2025-05-01T12:00:00Z", start, end)
assert result is False
def test_after_window(self) -> None:
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
result = _in_time_window("2025-07-15T12:00:00Z", start, end)
assert result is False
def test_start_boundary_is_exclusive(self) -> None:
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
result = _in_time_window("2025-06-01T00:00:00Z", start, end)
assert result is False
def test_end_boundary_is_inclusive(self) -> None:
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
result = _in_time_window("2025-06-30T00:00:00Z", start, end)
assert result is True
class TestFetchStagePage:
def test_uses_full_url_when_next_url_set(self) -> None:
connector = _build_connector()
with patch.object(
connector.canvas_client, "get", return_value=([{"id": 1}], None)
) as mock_get:
result, next_url = connector._fetch_stage_page(
next_url="https://myschool.instructure.com/api/v1/courses?page=2",
endpoint="courses/1/pages",
params={"per_page": "100"},
)
mock_get.assert_called_once_with(
full_url="https://myschool.instructure.com/api/v1/courses?page=2"
)
assert result == [{"id": 1}]
def test_uses_endpoint_and_params_when_no_next_url(self) -> None:
connector = _build_connector()
with patch.object(
connector.canvas_client, "get", return_value=([{"id": 1}], None)
) as mock_get:
result, next_url = connector._fetch_stage_page(
next_url=None,
endpoint="courses/1/pages",
params={"per_page": "100"},
)
mock_get.assert_called_once_with(
endpoint="courses/1/pages", params={"per_page": "100"}
)
def test_returns_empty_list_for_none_response(self) -> None:
connector = _build_connector()
with patch.object(connector.canvas_client, "get", return_value=(None, None)):
result, next_url = connector._fetch_stage_page(
next_url=None,
endpoint="courses/1/pages",
params={},
)
assert result == []
assert next_url is None
class TestProcessItems:
def test_pages_in_window_converted(self) -> None:
connector = _build_connector()
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
results, early_exit = connector._process_items(
response=[_mock_page(10, "Syllabus", "2025-06-15T12:00:00Z")],
stage=CanvasStage.PAGES,
course_id=1,
start=start,
end=end,
include_permissions=False,
)
assert len(results) == 1
assert isinstance(results[0], Document)
assert early_exit is False
def test_pages_outside_window_skipped(self) -> None:
connector = _build_connector()
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
results, early_exit = connector._process_items(
response=[_mock_page(10, "Old", "2025-01-01T12:00:00Z")],
stage=CanvasStage.PAGES,
course_id=1,
start=start,
end=end,
include_permissions=False,
)
assert results == []
assert early_exit is True
def test_assignments_in_window_converted(self) -> None:
connector = _build_connector()
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
results, early_exit = connector._process_items(
response=[_mock_assignment(20, "HW1", 1, "2025-06-15T12:00:00Z")],
stage=CanvasStage.ASSIGNMENTS,
course_id=1,
start=start,
end=end,
include_permissions=False,
)
assert len(results) == 1
assert isinstance(results[0], Document)
assert early_exit is False
def test_announcements_in_window_converted(self) -> None:
connector = _build_connector()
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
results, early_exit = connector._process_items(
response=[_mock_announcement(30, "News", 1, "2025-06-15T12:00:00Z")],
stage=CanvasStage.ANNOUNCEMENTS,
course_id=1,
start=start,
end=end,
include_permissions=False,
)
assert len(results) == 1
assert isinstance(results[0], Document)
assert early_exit is False
def test_bad_item_yields_connector_failure(self) -> None:
connector = _build_connector()
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
bad_page = {
"page_id": 10,
"url": "test",
"title": "Test",
"body": None,
"created_at": "2025-06-15T12:00:00Z",
"updated_at": "bad-date",
}
results, early_exit = connector._process_items(
response=[bad_page],
stage=CanvasStage.PAGES,
course_id=1,
start=start,
end=end,
include_permissions=False,
)
assert len(results) == 1
assert isinstance(results[0], ConnectorFailure)
def test_page_early_exit_on_old_item(self) -> None:
"""Pages sorted desc — item before start triggers early exit."""
connector = _build_connector()
start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
results, early_exit = connector._process_items(
response=[
_mock_page(10, "New", "2025-06-15T12:00:00Z"),
_mock_page(11, "Old", "2025-05-01T12:00:00Z"),
_mock_page(12, "Older", "2025-04-01T12:00:00Z"),
],
stage=CanvasStage.PAGES,
course_id=1,
start=start,
end=end,
include_permissions=False,
)
assert len(results) == 1
assert early_exit is True
class TestMaybeAttachPermissions:
def test_attaches_permissions_when_true(self) -> None:
connector = _build_connector()
doc = MagicMock(spec=Document)
doc.external_access = None
expected_access = ExternalAccess(
external_user_emails={"student@school.edu"},
external_user_group_ids=set(),
is_public=False,
)
with patch.object(
connector, "_get_course_permissions", return_value=expected_access
):
result = connector._maybe_attach_permissions(
doc, course_id=1, include_permissions=True
)
assert result.external_access == expected_access
def test_no_op_when_false(self) -> None:
connector = _build_connector()
doc = MagicMock(spec=Document)
doc.external_access = None
result = connector._maybe_attach_permissions(
doc, course_id=1, include_permissions=False
)
assert result.external_access is None


@@ -0,0 +1,172 @@
"""
Tests verifying that GithubConnector implements SlimConnector and SlimConnectorWithPermSync
correctly, and that pruning uses the cheap slim path (no lazy loading).
"""
from collections.abc import Generator
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import PropertyMock
import pytest
from onyx.access.models import ExternalAccess
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.connectors.github.connector import GithubConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import SlimDocument
def _make_pr(html_url: str) -> MagicMock:
pr = MagicMock()
pr.html_url = html_url
pr.pull_request = None
# commits and changed_files should never be accessed during slim retrieval
type(pr).commits = PropertyMock(side_effect=AssertionError("lazy load triggered"))
type(pr).changed_files = PropertyMock(
side_effect=AssertionError("lazy load triggered")
)
return pr
def _make_issue(html_url: str) -> MagicMock:
issue = MagicMock()
issue.html_url = html_url
issue.pull_request = None
return issue
def _make_connector(include_issues: bool = False) -> GithubConnector:
connector = GithubConnector(
repo_owner="test-org",
repositories="test-repo",
include_prs=True,
include_issues=include_issues,
)
connector.github_client = MagicMock()
return connector
@pytest.fixture(autouse=True)
def patch_deserialize_repository(mock_repo: MagicMock) -> Generator[None, None, None]:
with patch(
"onyx.connectors.github.connector.deserialize_repository",
return_value=mock_repo,
):
yield
@pytest.fixture
def mock_repo() -> MagicMock:
repo = MagicMock()
repo.name = "test-repo"
repo.id = 123
repo.raw_headers = {"x-github-request-id": "test"}
repo.raw_data = {"id": 123, "name": "test-repo", "full_name": "test-org/test-repo"}
prs = [
_make_pr(f"https://github.com/test-org/test-repo/pull/{i}") for i in range(1, 4)
]
mock_paginated = MagicMock()
mock_paginated.get_page.side_effect = lambda page: prs if page == 0 else []
repo.get_pulls.return_value = mock_paginated
return repo
def test_github_connector_implements_slim_connector() -> None:
connector = _make_connector()
assert isinstance(connector, SlimConnector)
def test_github_connector_implements_slim_connector_with_perm_sync() -> None:
connector = _make_connector()
assert isinstance(connector, SlimConnectorWithPermSync)
def test_retrieve_all_slim_docs_returns_pr_urls(mock_repo: MagicMock) -> None:
connector = _make_connector()
with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
batches = list(connector.retrieve_all_slim_docs())
all_docs = [doc for batch in batches for doc in batch]
assert len(all_docs) == 3
assert all(isinstance(doc, SlimDocument) for doc in all_docs)
assert {doc.id for doc in all_docs if isinstance(doc, SlimDocument)} == {
"https://github.com/test-org/test-repo/pull/1",
"https://github.com/test-org/test-repo/pull/2",
"https://github.com/test-org/test-repo/pull/3",
}
def test_retrieve_all_slim_docs_has_no_external_access(mock_repo: MagicMock) -> None:
"""Pruning does not need permissions — external_access should be None."""
connector = _make_connector()
with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
batches = list(connector.retrieve_all_slim_docs())
all_docs = [doc for batch in batches for doc in batch]
assert all(doc.external_access is None for doc in all_docs)
def test_retrieve_all_slim_docs_perm_sync_populates_external_access(
mock_repo: MagicMock,
) -> None:
connector = _make_connector()
mock_access = MagicMock(spec=ExternalAccess)
with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
with patch(
"onyx.connectors.github.connector.get_external_access_permission",
return_value=mock_access,
) as mock_perm:
batches = list(connector.retrieve_all_slim_docs_perm_sync())
# permission fetched at least once per repo (once per page in checkpoint-based flow)
mock_perm.assert_called_with(mock_repo, connector.github_client)
all_docs = [doc for batch in batches for doc in batch]
assert all(doc.external_access is mock_access for doc in all_docs)
def test_retrieve_all_slim_docs_skips_pr_issues(mock_repo: MagicMock) -> None:
"""Issues that are actually PRs should be skipped when include_issues=True."""
connector = _make_connector(include_issues=True)
pr_issue = MagicMock()
pr_issue.html_url = "https://github.com/test-org/test-repo/pull/99"
pr_issue.pull_request = MagicMock() # non-None means it's a PR
real_issue = _make_issue("https://github.com/test-org/test-repo/issues/1")
issues = [pr_issue, real_issue]
mock_issues_paginated = MagicMock()
mock_issues_paginated.get_page.side_effect = lambda page: (
issues if page == 0 else []
)
mock_repo.get_issues.return_value = mock_issues_paginated
with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
batches = list(connector.retrieve_all_slim_docs())
issue_ids = {
doc.id
for batch in batches
for doc in batch
if isinstance(doc, SlimDocument) and "issues" in doc.id
}
assert issue_ids == {"https://github.com/test-org/test-repo/issues/1"}
def test_pruning_routes_to_slim_connector_path(mock_repo: MagicMock) -> None:
"""extract_ids_from_runnable_connector must use SlimConnector, not CheckpointedConnector."""
connector = _make_connector()
with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
# If the CheckpointedConnector fallback were used instead, it would call
# load_from_checkpoint which hits _convert_pr_to_document and lazy loads.
# We verify the slim path is taken by checking load_from_checkpoint is NOT called.
with patch.object(connector, "load_from_checkpoint") as mock_load:
result = extract_ids_from_runnable_connector(connector)
mock_load.assert_not_called()
assert len(result.raw_id_to_parent) == 3
assert "https://github.com/test-org/test-repo/pull/1" in result.raw_id_to_parent


@@ -0,0 +1,200 @@
"""Unit tests for GoogleDriveConnector slim retrieval routing.
Verifies that:
- GoogleDriveConnector implements SlimConnector so pruning takes the ID-only path
- retrieve_all_slim_docs() calls _extract_slim_docs_from_google_drive with include_permissions=False
- retrieve_all_slim_docs_perm_sync() calls _extract_slim_docs_from_google_drive with include_permissions=True
- celery_utils routing picks retrieve_all_slim_docs() for GoogleDriveConnector
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_drive.models import DriveRetrievalStage
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import SlimDocument
from onyx.utils.threadpool_concurrency import ThreadSafeDict
def _make_done_checkpoint() -> GoogleDriveCheckpoint:
return GoogleDriveCheckpoint(
retrieved_folder_and_drive_ids=set(),
completion_stage=DriveRetrievalStage.DONE,
completion_map=ThreadSafeDict(),
all_retrieved_file_ids=set(),
has_more=False,
)
def _make_connector() -> GoogleDriveConnector:
connector = GoogleDriveConnector(include_my_drives=True)
connector._creds = MagicMock()
connector._primary_admin_email = "admin@example.com"
return connector
class TestGoogleDriveSlimConnectorInterface:
def test_implements_slim_connector(self) -> None:
connector = _make_connector()
assert isinstance(connector, SlimConnector)
def test_implements_slim_connector_with_perm_sync(self) -> None:
connector = _make_connector()
assert isinstance(connector, SlimConnectorWithPermSync)
def test_slim_connector_checked_before_perm_sync(self) -> None:
"""SlimConnector must appear before SlimConnectorWithPermSync in MRO
so celery_utils isinstance check routes to retrieve_all_slim_docs()."""
mro = GoogleDriveConnector.__mro__
slim_idx = mro.index(SlimConnector)
perm_sync_idx = mro.index(SlimConnectorWithPermSync)
assert slim_idx < perm_sync_idx
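# Rough sketch of the routing this MRO check protects, assuming celery_utils
# prefers the permission-free slim path for pruning and only falls back to the
# perm-sync variant. Not a copy of extract_ids_from_runnable_connector; the
# real check order and error handling may differ.
def _route_slim_retrieval_sketch(connector: object) -> object:
    if isinstance(connector, SlimConnector):
        # Cheap, ID-only retrieval used by pruning.
        return connector.retrieve_all_slim_docs()
    if isinstance(connector, SlimConnectorWithPermSync):
        # Slower path that also resolves external permissions.
        return connector.retrieve_all_slim_docs_perm_sync()
    raise TypeError("connector exposes no slim retrieval path")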
class TestRetrieveAllSlimDocs:
def test_does_not_call_extract_when_checkpoint_is_done(self) -> None:
connector = _make_connector()
slim_doc = MagicMock(
spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
)
with patch.object(
connector, "build_dummy_checkpoint", return_value=_make_done_checkpoint()
):
with patch.object(
connector,
"_extract_slim_docs_from_google_drive",
return_value=iter([[slim_doc]]),
) as mock_extract:
list(connector.retrieve_all_slim_docs())
mock_extract.assert_not_called() # loop exits immediately since checkpoint is DONE
def test_calls_extract_with_include_permissions_false_non_done_checkpoint(
self,
) -> None:
connector = _make_connector()
slim_doc = MagicMock(
spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
)
# Checkpoint starts at START, _extract advances it to DONE
with patch.object(connector, "build_dummy_checkpoint") as mock_build:
start_checkpoint = GoogleDriveCheckpoint(
retrieved_folder_and_drive_ids=set(),
completion_stage=DriveRetrievalStage.START,
completion_map=ThreadSafeDict(),
all_retrieved_file_ids=set(),
has_more=False,
)
mock_build.return_value = start_checkpoint
def _advance_checkpoint(**_kwargs: object) -> object:
start_checkpoint.completion_stage = DriveRetrievalStage.DONE
yield [slim_doc]
with patch.object(
connector,
"_extract_slim_docs_from_google_drive",
side_effect=_advance_checkpoint,
) as mock_extract:
list(connector.retrieve_all_slim_docs())
mock_extract.assert_called_once()
_, kwargs = mock_extract.call_args
assert kwargs.get("include_permissions") is False
def test_yields_slim_documents(self) -> None:
connector = _make_connector()
slim_doc = MagicMock(
spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
)
start_checkpoint = GoogleDriveCheckpoint(
retrieved_folder_and_drive_ids=set(),
completion_stage=DriveRetrievalStage.START,
completion_map=ThreadSafeDict(),
all_retrieved_file_ids=set(),
has_more=False,
)
with patch.object(
connector, "build_dummy_checkpoint", return_value=start_checkpoint
):
def _advance_and_yield(**_kwargs: object) -> object:
start_checkpoint.completion_stage = DriveRetrievalStage.DONE
yield [slim_doc]
with patch.object(
connector,
"_extract_slim_docs_from_google_drive",
side_effect=_advance_and_yield,
):
batches = list(connector.retrieve_all_slim_docs())
assert len(batches) == 1
assert batches[0][0] is slim_doc
class TestRetrieveAllSlimDocsPermSync:
def test_calls_extract_with_include_permissions_true(self) -> None:
connector = _make_connector()
slim_doc = MagicMock(
spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
)
start_checkpoint = GoogleDriveCheckpoint(
retrieved_folder_and_drive_ids=set(),
completion_stage=DriveRetrievalStage.START,
completion_map=ThreadSafeDict(),
all_retrieved_file_ids=set(),
has_more=False,
)
with patch.object(
connector, "build_dummy_checkpoint", return_value=start_checkpoint
):
def _advance_and_yield(**_kwargs: object) -> object:
start_checkpoint.completion_stage = DriveRetrievalStage.DONE
yield [slim_doc]
with patch.object(
connector,
"_extract_slim_docs_from_google_drive",
side_effect=_advance_and_yield,
) as mock_extract:
list(connector.retrieve_all_slim_docs_perm_sync())
mock_extract.assert_called_once()
_, kwargs = mock_extract.call_args
assert (
kwargs.get("include_permissions") is None
or kwargs.get("include_permissions") is True
)
class TestCeleryUtilsRouting:
def test_pruning_uses_retrieve_all_slim_docs(self) -> None:
"""extract_ids_from_runnable_connector must call retrieve_all_slim_docs,
not retrieve_all_slim_docs_perm_sync, for GoogleDriveConnector."""
connector = _make_connector()
slim_doc = MagicMock(
spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
)
with (
patch.object(
connector, "retrieve_all_slim_docs", return_value=iter([[slim_doc]])
) as mock_slim,
patch.object(
connector, "retrieve_all_slim_docs_perm_sync"
) as mock_perm_sync,
):
extract_ids_from_runnable_connector(
connector, connector_type="google_drive"
)
mock_slim.assert_called_once()
mock_perm_sync.assert_not_called()


@@ -0,0 +1,182 @@
from typing import Any
import pytest
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from onyx.connectors.google_utils.google_kv import get_auth_url
from onyx.connectors.google_utils.google_kv import get_google_app_cred
from onyx.connectors.google_utils.google_kv import get_service_account_key
from onyx.connectors.google_utils.google_kv import upsert_google_app_cred
from onyx.connectors.google_utils.google_kv import upsert_service_account_key
from onyx.server.documents.models import GoogleAppCredentials
from onyx.server.documents.models import GoogleAppWebCredentials
from onyx.server.documents.models import GoogleServiceAccountKey
def _make_app_creds() -> GoogleAppCredentials:
return GoogleAppCredentials(
web=GoogleAppWebCredentials(
client_id="client-id.apps.googleusercontent.com",
project_id="test-project",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
client_secret="secret",
redirect_uris=["https://example.com/callback"],
javascript_origins=["https://example.com"],
)
)
def _make_service_account_key() -> GoogleServiceAccountKey:
return GoogleServiceAccountKey(
type="service_account",
project_id="test-project",
private_key_id="private-key-id",
private_key="-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
client_email="test@test-project.iam.gserviceaccount.com",
client_id="123",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test",
universe_domain="googleapis.com",
)
def test_upsert_google_app_cred_stores_dict(monkeypatch: Any) -> None:
stored: dict[str, Any] = {}
class _StubKvStore:
def store(self, key: str, value: object, encrypt: bool) -> None:
stored["key"] = key
stored["value"] = value
stored["encrypt"] = encrypt
monkeypatch.setattr(
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
)
upsert_google_app_cred(_make_app_creds(), DocumentSource.GOOGLE_DRIVE)
assert stored["key"] == KV_GOOGLE_DRIVE_CRED_KEY
assert stored["encrypt"] is True
assert isinstance(stored["value"], dict)
assert stored["value"]["web"]["client_id"] == "client-id.apps.googleusercontent.com"
def test_upsert_service_account_key_stores_dict(monkeypatch: Any) -> None:
stored: dict[str, Any] = {}
class _StubKvStore:
def store(self, key: str, value: object, encrypt: bool) -> None:
stored["key"] = key
stored["value"] = value
stored["encrypt"] = encrypt
monkeypatch.setattr(
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
)
upsert_service_account_key(_make_service_account_key(), DocumentSource.GOOGLE_DRIVE)
assert stored["key"] == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
assert stored["encrypt"] is True
assert isinstance(stored["value"], dict)
assert stored["value"]["project_id"] == "test-project"
@pytest.mark.parametrize("legacy_string", [False, True])
def test_get_google_app_cred_accepts_dict_and_legacy_string(
monkeypatch: Any, legacy_string: bool
) -> None:
payload: dict[str, Any] = _make_app_creds().model_dump(mode="json")
stored_value: object = (
payload if not legacy_string else _make_app_creds().model_dump_json()
)
class _StubKvStore:
def load(self, key: str) -> object:
assert key == KV_GOOGLE_DRIVE_CRED_KEY
return stored_value
monkeypatch.setattr(
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
)
creds = get_google_app_cred(DocumentSource.GOOGLE_DRIVE)
assert creds.web.client_id == "client-id.apps.googleusercontent.com"
@pytest.mark.parametrize("legacy_string", [False, True])
def test_get_service_account_key_accepts_dict_and_legacy_string(
monkeypatch: Any, legacy_string: bool
) -> None:
stored_value: object = (
_make_service_account_key().model_dump(mode="json")
if not legacy_string
else _make_service_account_key().model_dump_json()
)
class _StubKvStore:
def load(self, key: str) -> object:
assert key == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
return stored_value
monkeypatch.setattr(
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
)
key = get_service_account_key(DocumentSource.GOOGLE_DRIVE)
assert key.client_email == "test@test-project.iam.gserviceaccount.com"
@pytest.mark.parametrize("legacy_string", [False, True])
def test_get_auth_url_accepts_dict_and_legacy_string(
monkeypatch: Any, legacy_string: bool
) -> None:
payload = _make_app_creds().model_dump(mode="json")
stored_value: object = (
payload if not legacy_string else _make_app_creds().model_dump_json()
)
stored_state: dict[str, object] = {}
class _StubKvStore:
def load(self, key: str) -> object:
assert key == KV_GOOGLE_DRIVE_CRED_KEY
return stored_value
def store(self, key: str, value: object, encrypt: bool) -> None:
stored_state["key"] = key
stored_state["value"] = value
stored_state["encrypt"] = encrypt
class _StubFlow:
def authorization_url(self, prompt: str) -> tuple[str, None]:
assert prompt == "consent"
return "https://accounts.google.com/o/oauth2/auth?state=test-state", None
monkeypatch.setattr(
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
)
def _from_client_config(
_app_config: object, *, scopes: object, redirect_uri: object
) -> _StubFlow:
del scopes, redirect_uri
return _StubFlow()
monkeypatch.setattr(
"onyx.connectors.google_utils.google_kv.InstalledAppFlow.from_client_config",
_from_client_config,
)
auth_url = get_auth_url(42, DocumentSource.GOOGLE_DRIVE)
assert auth_url.startswith("https://accounts.google.com")
assert stored_state["value"] == {"value": "test-state"}
assert stored_state["encrypt"] is True
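# Sketch of the tolerant load behavior the parametrized tests above describe:
# newer writes store a dict, older installs may have persisted the JSON string,
# and both should deserialize. Assumes pydantic v2 model_validate helpers; the
# real google_kv code may normalize the stored value differently.
def _load_app_cred_sketch(raw: object) -> GoogleAppCredentials:
    if isinstance(raw, str):
        # Legacy format: the whole credential blob was stored as a JSON string.
        return GoogleAppCredentials.model_validate_json(raw)
    # Current format: a plain dict stored via model_dump(mode="json").
    return GoogleAppCredentials.model_validate(raw)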


@@ -0,0 +1,86 @@
"""Tests for get_index_attempt_errors_across_connectors."""
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
from onyx.db.index_attempt import get_index_attempt_errors_across_connectors
from onyx.db.models import IndexAttemptError
def _make_error(
id: int = 1,
cc_pair_id: int = 1,
error_type: str | None = "TimeoutError",
is_resolved: bool = False,
) -> IndexAttemptError:
"""Create a mock IndexAttemptError."""
error = MagicMock(spec=IndexAttemptError)
error.id = id
error.connector_credential_pair_id = cc_pair_id
error.error_type = error_type
error.is_resolved = is_resolved
return error
class TestGetIndexAttemptErrorsAcrossConnectors:
def test_returns_errors_and_count(self) -> None:
mock_session = MagicMock()
mock_errors = [_make_error(id=1), _make_error(id=2)]
mock_session.scalar.return_value = 2
mock_session.scalars.return_value.all.return_value = mock_errors
errors, total = get_index_attempt_errors_across_connectors(
db_session=mock_session,
)
assert total == 2
assert len(errors) == 2
def test_returns_empty_when_no_errors(self) -> None:
mock_session = MagicMock()
mock_session.scalar.return_value = 0
mock_session.scalars.return_value.all.return_value = []
errors, total = get_index_attempt_errors_across_connectors(
db_session=mock_session,
)
assert total == 0
assert errors == []
def test_null_count_returns_zero(self) -> None:
mock_session = MagicMock()
mock_session.scalar.return_value = None
mock_session.scalars.return_value.all.return_value = []
errors, total = get_index_attempt_errors_across_connectors(
db_session=mock_session,
)
assert total == 0
def test_passes_filters_to_query(self) -> None:
"""Verify that filter parameters result in .where() calls on the statement."""
mock_session = MagicMock()
mock_session.scalar.return_value = 0
mock_session.scalars.return_value.all.return_value = []
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
end = datetime(2026, 12, 31, tzinfo=timezone.utc)
# Should not raise — just verifying the function accepts all filter params
get_index_attempt_errors_across_connectors(
db_session=mock_session,
cc_pair_id=42,
error_type="TimeoutError",
start_time=start,
end_time=end,
unresolved_only=True,
page=2,
page_size=10,
)
# The function should have called scalar (for count) and scalars (for results)
assert mock_session.scalar.called
assert mock_session.scalars.called
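# Hedged sketch of the count-plus-page query shape these tests assume: one
# scalar() call for the total and one scalars().all() for the page, with
# optional filters applied via where(). Column names match the mock above, but
# the exact filter set and pagination math in the real function may differ.
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session

def _errors_page_sketch(
    db_session: Session,
    cc_pair_id: int | None = None,
    unresolved_only: bool = False,
    page: int = 0,
    page_size: int = 50,
) -> tuple[list[IndexAttemptError], int]:
    stmt = select(IndexAttemptError)
    if cc_pair_id is not None:
        stmt = stmt.where(IndexAttemptError.connector_credential_pair_id == cc_pair_id)
    if unresolved_only:
        stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
    # Count first (a NULL count becomes 0), then fetch one page of rows.
    total = db_session.scalar(select(func.count()).select_from(stmt.subquery())) or 0
    rows = db_session.scalars(stmt.offset(page * page_size).limit(page_size)).all()
    return list(rows), total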


@@ -1,9 +1,13 @@
import io
from typing import cast
from unittest.mock import MagicMock
import openpyxl
from openpyxl.worksheet.worksheet import Worksheet
from onyx.file_processing.extract_file_text import _clean_worksheet_matrix
from onyx.file_processing.extract_file_text import _worksheet_to_matrix
from onyx.file_processing.extract_file_text import xlsx_sheet_extraction
from onyx.file_processing.extract_file_text import xlsx_to_text
@@ -196,3 +200,182 @@ class TestXlsxToText:
assert "r1c1" in lines[0] and "r1c2" in lines[0]
assert "r2c1" in lines[1] and "r2c2" in lines[1]
assert "r3c1" in lines[2] and "r3c2" in lines[2]
class TestWorksheetToMatrixJaggedRows:
"""openpyxl read_only mode can yield rows of differing widths when
trailing cells are empty. The matrix must be padded to a rectangle
so downstream column cleanup can index safely."""
def test_pads_shorter_trailing_rows(self) -> None:
ws = MagicMock()
ws.iter_rows.return_value = iter(
[
("A", "B", "C"),
("X", "Y"),
("P",),
]
)
matrix = _worksheet_to_matrix(ws)
assert matrix == [["A", "B", "C"], ["X", "Y", ""], ["P", "", ""]]
def test_pads_when_first_row_is_shorter(self) -> None:
ws = MagicMock()
ws.iter_rows.return_value = iter(
[
("A",),
("X", "Y", "Z"),
]
)
matrix = _worksheet_to_matrix(ws)
assert matrix == [["A", "", ""], ["X", "Y", "Z"]]
def test_clean_worksheet_matrix_no_index_error_on_jagged_rows(self) -> None:
"""Regression: previously raised IndexError when a later row was
shorter than the first row and the out-of-range column on the
first row was empty (so the short-circuit in `all()` did not
save us)."""
ws = MagicMock()
ws.iter_rows.return_value = iter(
[
("A", "", "", "B"),
("X", "Y"),
]
)
matrix = _worksheet_to_matrix(ws)
# Must not raise.
cleaned = _clean_worksheet_matrix(matrix)
assert cleaned == [["A", "", "", "B"], ["X", "Y", "", ""]]
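# Minimal sketch of the rectangularization the jagged-row tests above rely on,
# assuming _worksheet_to_matrix reads values via iter_rows(values_only=True)
# and pads with "" out to the widest row. The real helper also handles value
# coercion and other cleanup; this only shows the padding idea.
def _worksheet_to_matrix_sketch(ws: Worksheet) -> list[list[str]]:
    rows = [
        ["" if cell is None else str(cell) for cell in row]
        for row in ws.iter_rows(values_only=True)
    ]
    if not rows:
        return []
    # Pad every row to the widest row so column-wise cleanup can index safely.
    width = max(len(row) for row in rows)
    return [row + [""] * (width - len(row)) for row in rows]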
class TestXlsxSheetExtraction:
def test_one_tuple_per_sheet(self) -> None:
xlsx = _make_xlsx(
{
"Revenue": [["Month", "Amount"], ["Jan", "100"]],
"Expenses": [["Category", "Cost"], ["Rent", "500"]],
}
)
sheets = xlsx_sheet_extraction(xlsx)
assert len(sheets) == 2
# Order preserved from workbook sheet order
titles = [title for _csv, title in sheets]
assert titles == ["Revenue", "Expenses"]
# Content present in the right tuple
revenue_csv, _ = sheets[0]
expenses_csv, _ = sheets[1]
assert "Month" in revenue_csv
assert "Jan" in revenue_csv
assert "Category" in expenses_csv
assert "Rent" in expenses_csv
def test_tuple_structure_is_csv_text_then_title(self) -> None:
"""The tuple order is (csv_text, sheet_title) — pin it so callers
that unpack positionally don't silently break."""
xlsx = _make_xlsx({"MySheet": [["a", "b"]]})
sheets = xlsx_sheet_extraction(xlsx)
assert len(sheets) == 1
csv_text, title = sheets[0]
assert title == "MySheet"
assert "a" in csv_text
assert "b" in csv_text
def test_empty_sheet_is_skipped(self) -> None:
"""A sheet whose CSV output is empty/whitespace-only should NOT
appear in the result — the `if csv_text.strip():` guard filters
it out."""
xlsx = _make_xlsx(
{
"Data": [["a", "b"]],
"Empty": [],
}
)
sheets = xlsx_sheet_extraction(xlsx)
assert len(sheets) == 1
assert sheets[0][1] == "Data"
def test_empty_workbook_returns_empty_list(self) -> None:
"""All sheets empty → empty list (not a list of empty tuples)."""
xlsx = _make_xlsx({"Sheet1": [], "Sheet2": []})
sheets = xlsx_sheet_extraction(xlsx)
assert sheets == []
def test_single_sheet(self) -> None:
xlsx = _make_xlsx({"Only": [["x", "y"], ["1", "2"]]})
sheets = xlsx_sheet_extraction(xlsx)
assert len(sheets) == 1
csv_text, title = sheets[0]
assert title == "Only"
assert "x" in csv_text
assert "1" in csv_text
def test_bad_zip_returns_empty_list(self) -> None:
bad_file = io.BytesIO(b"not a zip file")
sheets = xlsx_sheet_extraction(bad_file, file_name="test.xlsx")
assert sheets == []
def test_bad_zip_tilde_file_returns_empty_list(self) -> None:
"""`~$`-prefixed files are Excel lock files; failure should log
at debug (not warning) and still return []."""
bad_file = io.BytesIO(b"not a zip file")
sheets = xlsx_sheet_extraction(bad_file, file_name="~$temp.xlsx")
assert sheets == []
def test_csv_content_matches_xlsx_to_text_per_sheet(self) -> None:
"""For a single-sheet workbook, xlsx_to_text output should equal
the csv_text from xlsx_sheet_extraction — they share the same
per-sheet CSV-ification logic."""
single_sheet_data = [["Name", "Age"], ["Alice", "30"]]
expected_text = xlsx_to_text(_make_xlsx({"People": single_sheet_data}))
sheets = xlsx_sheet_extraction(_make_xlsx({"People": single_sheet_data}))
assert len(sheets) == 1
csv_text, title = sheets[0]
assert title == "People"
assert csv_text.strip() == expected_text.strip()
def test_commas_in_cells_are_quoted(self) -> None:
xlsx = _make_xlsx({"S1": [["hello, world", "normal"]]})
sheets = xlsx_sheet_extraction(xlsx)
assert len(sheets) == 1
csv_text, _ = sheets[0]
assert '"hello, world"' in csv_text
def test_long_empty_row_run_capped_within_sheet(self) -> None:
"""The matrix cleanup applies per-sheet: >2 empty rows collapse
to 2, which keeps the sheet non-empty and it still appears in
the result."""
xlsx = _make_xlsx(
{
"S1": [
["header"],
[""],
[""],
[""],
[""],
["data"],
]
}
)
sheets = xlsx_sheet_extraction(xlsx)
assert len(sheets) == 1
csv_text, _ = sheets[0]
lines = csv_text.strip().split("\n")
# header + 2 empty (capped) + data = 4 lines
assert len(lines) == 4
assert "header" in lines[0]
assert "data" in lines[-1]
def test_sheet_title_with_special_chars_preserved(self) -> None:
"""Spaces, punctuation, unicode in sheet titles are preserved
verbatim — the title is used as a link anchor downstream."""
xlsx = _make_xlsx(
{
"Q1 Revenue (USD)": [["a", "b"]],
"Données": [["c", "d"]],
}
)
sheets = xlsx_sheet_extraction(xlsx)
titles = [title for _csv, title in sheets]
assert "Q1 Revenue (USD)" in titles
assert "Données" in titles


@@ -0,0 +1,558 @@
"""End-to-end tests for `TabularChunker.chunk_section`.
Each test is structured as:
INPUT — the CSV text passed to the chunker + token budget + link
EXPECTED — the exact chunk texts the chunker should emit
ACT — a single call to `chunk_section`
ASSERT — literal equality against the expected chunk texts
A character-level tokenizer (1 char == 1 token) is used so token-budget
arithmetic is deterministic and expected chunks can be spelled out
exactly.
"""
from onyx.connectors.models import Section
from onyx.connectors.models import TabularSection
from onyx.indexing.chunking.section_chunker import AccumulatorState
from onyx.indexing.chunking.tabular_section_chunker import TabularChunker
from onyx.natural_language_processing.utils import BaseTokenizer
class CharTokenizer(BaseTokenizer):
def encode(self, string: str) -> list[int]:
return [ord(c) for c in string]
def tokenize(self, string: str) -> list[str]:
return list(string)
def decode(self, tokens: list[int]) -> str:
return "".join(chr(t) for t in tokens)
def _make_chunker() -> TabularChunker:
return TabularChunker(tokenizer=CharTokenizer())
_DEFAULT_LINK = "https://example.com/doc"
def _tabular_section(
text: str,
link: str = _DEFAULT_LINK,
heading: str | None = "sheet:Test",
) -> Section:
return TabularSection(text=text, link=link, heading=heading)
class TestTabularChunkerChunkSection:
def test_simple_csv_all_rows_fit_one_chunk(self) -> None:
# --- INPUT -----------------------------------------------------
csv_text = "Name,Age,City\n" "Alice,30,NYC\n" "Bob,25,SF\n"
heading = "sheet:People"
content_token_limit = 500
# --- EXPECTED --------------------------------------------------
expected_texts = [
(
"sheet:People\n"
"Columns: Name, Age, City\n"
"Name=Alice, Age=30, City=NYC\n"
"Name=Bob, Age=25, City=SF"
),
]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=content_token_limit,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
assert [p.is_continuation for p in out.payloads] == [False]
assert all(p.links == {0: _DEFAULT_LINK} for p in out.payloads)
assert out.accumulator.is_empty()
def test_overflow_splits_into_two_deterministic_chunks(self) -> None:
# --- INPUT -----------------------------------------------------
# prelude = "sheet:S\nColumns: col, val" (25 chars = 25 tokens)
# At content_token_limit=57, row_budget = max(16, 57-31-1) = 25.
# Each row "col=a, val=1" is 12 tokens; two rows + \n = 25 (fits),
# three rows + 2×\n = 38 (overflows) → split after 2 rows.
csv_text = "col,val\n" "a,1\n" "b,2\n" "c,3\n" "d,4\n"
heading = "sheet:S"
content_token_limit = 57
# --- EXPECTED --------------------------------------------------
expected_texts = [
("sheet:S\n" "Columns: col, val\n" "col=a, val=1\n" "col=b, val=2"),
("sheet:S\n" "Columns: col, val\n" "col=c, val=3\n" "col=d, val=4"),
]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=content_token_limit,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
# First chunk is fresh; subsequent chunks mark as continuations.
assert [p.is_continuation for p in out.payloads] == [False, True]
# Link carries through every chunk.
assert all(p.links == {0: _DEFAULT_LINK} for p in out.payloads)
# Add back in shortly
# def test_header_only_csv_produces_single_prelude_chunk(self) -> None:
# # --- INPUT -----------------------------------------------------
# csv_text = "col1,col2\n"
# link = "sheet:Headers"
# # --- EXPECTED --------------------------------------------------
# expected_texts = [
# "sheet:Headers\nColumns: col1, col2",
# ]
# # --- ACT -------------------------------------------------------
# out = _make_chunker().chunk_section(
# _tabular_section(csv_text, link=link),
# AccumulatorState(),
# content_token_limit=500,
# )
# # --- ASSERT ----------------------------------------------------
# assert [p.text for p in out.payloads] == expected_texts
def test_empty_cells_dropped_from_chunk_text(self) -> None:
# --- INPUT -----------------------------------------------------
# Alice's Age is empty; Bob's City is empty. Empty cells should
# not appear as `field=` pairs in the output.
csv_text = "Name,Age,City\n" "Alice,,NYC\n" "Bob,25,\n"
heading = "sheet:P"
# --- EXPECTED --------------------------------------------------
expected_texts = [
(
"sheet:P\n"
"Columns: Name, Age, City\n"
"Name=Alice, City=NYC\n"
"Name=Bob, Age=25"
),
]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=500,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
def test_quoted_commas_in_csv_preserved_as_one_field(self) -> None:
# --- INPUT -----------------------------------------------------
# "Hello, world" is quoted in the CSV, so csv.reader parses it as
# a single field. The surrounding quotes are stripped during
# decoding, so the chunk text carries the bare value.
csv_text = "Name,Notes\n" 'Alice,"Hello, world"\n'
heading = "sheet:P"
# --- EXPECTED --------------------------------------------------
expected_texts = [
("sheet:P\n" "Columns: Name, Notes\n" "Name=Alice, Notes=Hello, world"),
]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=500,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
def test_blank_rows_in_csv_are_skipped(self) -> None:
# --- INPUT -----------------------------------------------------
# Stray blank rows in the CSV (e.g. export artifacts) shouldn't
# produce ghost rows in the output.
csv_text = "A,B\n" "\n" "1,2\n" "\n" "\n" "3,4\n"
heading = "sheet:S"
# --- EXPECTED --------------------------------------------------
expected_texts = [
("sheet:S\n" "Columns: A, B\n" "A=1, B=2\n" "A=3, B=4"),
]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=500,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
def test_accumulator_flushes_before_tabular_chunks(self) -> None:
# --- INPUT -----------------------------------------------------
# A text accumulator was populated by the prior text section.
# Tabular sections are structural boundaries, so the pending
# text is flushed as its own chunk before the tabular content.
pending_text = "prior paragraph from an earlier text section"
pending_link = "prev-link"
csv_text = "a,b\n" "1,2\n"
heading = "sheet:S"
# --- EXPECTED --------------------------------------------------
expected_texts = [
pending_text, # flushed accumulator
("sheet:S\n" "Columns: a, b\n" "a=1, b=2"),
]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(
text=pending_text,
link_offsets={0: pending_link},
),
content_token_limit=500,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
# Flushed chunk keeps the prior text's link; tabular chunk uses
# the tabular section's link.
assert out.payloads[0].links == {0: pending_link}
assert out.payloads[1].links == {0: _DEFAULT_LINK}
# Accumulator resets — tabular section is a structural boundary.
assert out.accumulator.is_empty()
def test_multi_row_packing_under_budget_emits_single_chunk(self) -> None:
# --- INPUT -----------------------------------------------------
# Three small rows (20 tokens each) under a generous
# content_token_limit=100 should pack into ONE chunk — prelude
# emitted once, rows stacked beneath it.
csv_text = (
"x\n" "aaaaaaaaaaaaaaaaaa\n" "bbbbbbbbbbbbbbbbbb\n" "cccccccccccccccccc\n"
)
heading = "S"
content_token_limit = 100
# --- EXPECTED --------------------------------------------------
# Each formatted row "x=<18-char value>" = 20 tokens.
# Full chunk with sheet + Columns + 3 rows =
# 1 + 1 + 10 + 1 + (20 + 1 + 20 + 1 + 20) = 75 tokens ≤ 100.
# Single chunk carries all three rows.
expected_texts = [
"S\n"
"Columns: x\n"
"x=aaaaaaaaaaaaaaaaaa\n"
"x=bbbbbbbbbbbbbbbbbb\n"
"x=cccccccccccccccccc"
]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=content_token_limit,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
assert [p.is_continuation for p in out.payloads] == [False]
assert all(len(p.text) <= content_token_limit for p in out.payloads)
def test_packing_reserves_prelude_budget_so_every_chunk_has_full_prelude(
self,
) -> None:
# --- INPUT -----------------------------------------------------
# Budget (30) is large enough for all 5 bare rows (row_block =
# 24 tokens) to pack as one chunk if the prelude were optional,
        # but sheet + Columns + 5 rows would be 13 + 24 = 37 tokens > 30. The
        # packing logic reserves space for the prelude: only 3 rows
        # pack per chunk (13 prelude overhead + 14 row tokens = 27 ≤ 30).
# Every emitted chunk therefore carries its full prelude rather
# than dropping Columns at emit time.
csv_text = "x\n" "aa\n" "bb\n" "cc\n" "dd\n" "ee\n"
heading = "S"
content_token_limit = 30
# --- EXPECTED --------------------------------------------------
# Prelude overhead = 'S\nColumns: x\n' = 1+1+10+1 = 13.
# Each row "x=XX" = 4 tokens, row separator "\n" = 1.
# 3 rows: 13 + (4+1+4+1+4) = 27 ≤ 30 ✓
# 4 rows: 13 + (4+1+4+1+4+1+4) = 32 > 30 ✗
# → 3 rows in the first chunk, 2 rows in the second.
expected_texts = [
"S\nColumns: x\nx=aa\nx=bb\nx=cc",
"S\nColumns: x\nx=dd\nx=ee",
]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=content_token_limit,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
# Every chunk fits under the budget AND carries its full
# prelude — that's the whole point of this check.
assert all(len(p.text) <= content_token_limit for p in out.payloads)
assert all("Columns: x" in p.text for p in out.payloads)
def test_oversized_row_splits_into_field_pieces_no_prelude(self) -> None:
# --- INPUT -----------------------------------------------------
# Single-row CSV whose formatted form ("field 1=1, ..." = 53
# tokens) exceeds content_token_limit (20). Per the chunker's
# rules, oversized rows are split at field boundaries into
# pieces each ≤ max_tokens, and no prelude is added to split
# pieces (they already consume the full budget). A 53-token row
# packs into 3 field-boundary pieces under a 20-token budget.
csv_text = "field 1,field 2,field 3,field 4,field 5\n" "1,2,3,4,5\n"
heading = "S"
content_token_limit = 20
# --- EXPECTED --------------------------------------------------
# Row = "field 1=1, field 2=2, field 3=3, field 4=4, field 5=5"
# Fields @ 9 tokens each, ", " sep = 2 tokens.
# "field 1=1, field 2=2" = 9+2+9 = 20 tokens ≤ 20 ✓
# + ", field 3=3" = 20+2+9 = 31 > 20 → flush, start new
# "field 3=3, field 4=4" = 9+2+9 = 20 ≤ 20 ✓
# + ", field 5=5" = 20+2+9 = 31 > 20 → flush, start new
# "field 5=5" = 9 ≤ 20 ✓
# ceil(53 / 20) = 3 chunks.
expected_texts = [
"field 1=1, field 2=2",
"field 3=3, field 4=4",
"field 5=5",
]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=content_token_limit,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
# Invariant: no chunk exceeds max_tokens.
assert all(len(p.text) <= content_token_limit for p in out.payloads)
# is_continuation: first chunk False, rest True.
assert [p.is_continuation for p in out.payloads] == [False, True, True]
def test_empty_tabular_section_flushes_accumulator_and_resets_it(
self,
) -> None:
# --- INPUT -----------------------------------------------------
# Tabular sections are structural boundaries, so any pending text
# buffer is flushed to a chunk before parsing the tabular content
# — even if the tabular section itself is empty. The accumulator
# is then reset.
pending_text = "prior paragraph"
pending_link_offsets = {0: "prev-link"}
# --- EXPECTED --------------------------------------------------
expected_texts = [pending_text]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section("", heading="sheet:Empty"),
AccumulatorState(
text=pending_text,
link_offsets=pending_link_offsets,
),
content_token_limit=500,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
assert out.accumulator.is_empty()
def test_single_oversized_field_token_splits_at_id_boundaries(self) -> None:
# --- INPUT -----------------------------------------------------
# A single `field=value` pair that itself exceeds max_tokens can't
# be split at field boundaries — there's only one field. The
# chunker falls back to encoding the pair to token ids and
# slicing at max-token-sized windows.
#
# CSV has one column "x" with a 50-char value. Formatted pair =
# "x=" + 50 a's = 52 tokens. Budget = 10.
csv_text = "x\n" + ("a" * 50) + "\n"
heading = "S"
content_token_limit = 10
# --- EXPECTED --------------------------------------------------
# 52-char pair at 10 tokens per window = 6 pieces:
# [0:10) "x=aaaaaaaa" (10)
# [10:20) "aaaaaaaaaa" (10)
# [20:30) "aaaaaaaaaa" (10)
# [30:40) "aaaaaaaaaa" (10)
# [40:50) "aaaaaaaaaa" (10)
# [50:52) "aa" (2)
# Split pieces carry no prelude (they already consume the budget).
expected_texts = [
"x=aaaaaaaa",
"aaaaaaaaaa",
"aaaaaaaaaa",
"aaaaaaaaaa",
"aaaaaaaaaa",
"aa",
]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=content_token_limit,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
# Every piece is ≤ max_tokens — the invariant the token-level
# fallback exists to enforce.
assert all(len(p.text) <= content_token_limit for p in out.payloads)
def test_underscored_column_gets_friendly_alias_in_parens(self) -> None:
# --- INPUT -----------------------------------------------------
# Column headers with underscores get a space-substituted friendly
# alias appended in parens on the `Columns:` line. Plain headers
# pass through untouched.
csv_text = "MTTR_hours,id,owner_name\n" "3,42,Alice\n"
heading = "sheet:M"
# --- EXPECTED --------------------------------------------------
expected_texts = [
(
"sheet:M\n"
"Columns: MTTR_hours (MTTR hours), id, owner_name (owner name)\n"
"MTTR_hours=3, id=42, owner_name=Alice"
),
]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=500,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
def test_oversized_row_between_small_rows_preserves_flanking_chunks(
self,
) -> None:
# --- INPUT -----------------------------------------------------
# State-machine check: small row, oversized row, small row. The
# first small row should become a preluded chunk; the oversized
# row flushes it and emits split fragments without prelude; then
# the last small row picks up from wherever the split left off.
#
# Headers a,b,c,d. Row 1 and row 3 each have only column `a`
# populated (tiny). Row 2 is a "fat" row with all four columns
# populated.
csv_text = "a,b,c,d\n" "1,,,\n" "xxx,yyy,zzz,www\n" "2,,,\n"
heading = "S"
content_token_limit = 20
# --- EXPECTED --------------------------------------------------
# Prelude = 'S\nColumns: a, b, c, d\n' = 1+1+19+1 = 22 > 20, so
# sheet fits with the row but full Columns header does not.
# Row 1 formatted = "a=1" (3). build_chunk_from_scratch:
# cols+row = 20+3 = 23 > 20 → skip cols. sheet+row = 1+1+3 = 5
# ≤ 20 → chunk = "S\na=1".
# Row 2 formatted = "a=xxx, b=yyy, c=zzz, d=www" (26 > 20) →
# flush "S\na=1" and split at pair boundaries:
# "a=xxx, b=yyy, c=zzz" (19 ≤ 20 ✓)
# "d=www" (5)
# Row 3 formatted = "a=2" (3). can_pack onto "d=www" (5):
# 5 + 3 + 1 = 9 ≤ 20 ✓ → packs. Trailing fragment from the
# split absorbs the next small row, which is the current v2
# behavior (the fragment becomes `current_chunk` and the next
# small row is appended with the standard packing rules).
expected_texts = [
"S\na=1",
"a=xxx, b=yyy, c=zzz",
"d=www\na=2",
]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=content_token_limit,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
assert all(len(p.text) <= content_token_limit for p in out.payloads)
def test_prelude_layering_column_header_fits_but_sheet_header_does_not(
self,
) -> None:
# --- INPUT -----------------------------------------------------
# Budget lets `Columns: x\nx=y` fit but not the additional sheet
# header on top. The chunker should add the column header and
# drop the sheet header.
#
# sheet = "LongSheetName" (13), cols = "Columns: x" (10),
# row = "x=y" (3). Budget = 15.
# cols + row: 10+1+3 = 14 ≤ 15 ✓
# sheet + cols + row: 13+1+10+1+3 = 28 > 15 ✗
csv_text = "x\n" "y\n"
heading = "LongSheetName"
content_token_limit = 15
# --- EXPECTED --------------------------------------------------
expected_texts = ["Columns: x\nx=y"]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=content_token_limit,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
def test_prelude_layering_sheet_header_fits_but_column_header_does_not(
self,
) -> None:
# --- INPUT -----------------------------------------------------
# Budget is too small for the column header but leaves room for
# the short sheet header. The chunker should fall back to just
# sheet + row (its layered "try cols, then try sheet on top of
# whatever we have" logic means sheet is attempted on the bare
# row when cols didn't fit).
#
# sheet = "S" (1), cols = "Columns: ABC, DEF" (17),
# row = "ABC=1, DEF=2" (12). Budget = 20.
# cols + row: 17+1+12 = 30 > 20 ✗
# sheet + row: 1+1+12 = 14 ≤ 20 ✓
csv_text = "ABC,DEF\n" "1,2\n"
heading = "S"
content_token_limit = 20
# --- EXPECTED --------------------------------------------------
expected_texts = ["S\nABC=1, DEF=2"]
# --- ACT -------------------------------------------------------
out = _make_chunker().chunk_section(
_tabular_section(csv_text, heading=heading),
AccumulatorState(),
content_token_limit=content_token_limit,
)
# --- ASSERT ----------------------------------------------------
assert [p.text for p in out.payloads] == expected_texts
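# Usage sketch: run a tiny CSV through the same helpers the tests above use.
# Purely illustrative; the printed chunk texts are not asserted anywhere.
if __name__ == "__main__":
    demo_out = _make_chunker().chunk_section(
        _tabular_section("Name,Age\nAlice,30\nBob,25\n", heading="sheet:Demo"),
        AccumulatorState(),
        content_token_limit=500,
    )
    for payload in demo_out.payloads:
        print(payload.is_continuation, payload.text)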


@@ -0,0 +1,188 @@
"""Unit tests for MinimalPersonaSnapshot.from_model knowledge_sources aggregation."""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FederatedConnectorSource
from onyx.server.features.document_set.models import DocumentSetSummary
from onyx.server.features.persona.models import MinimalPersonaSnapshot
_STUB_DS_SUMMARY = DocumentSetSummary(
id=1,
name="stub",
description=None,
cc_pair_summaries=[],
is_up_to_date=True,
is_public=True,
users=[],
groups=[],
)
def _make_persona(**overrides: object) -> MagicMock:
"""Build a mock Persona with sensible defaults.
Every relationship defaults to empty so tests only need to set the
fields they care about.
"""
p = MagicMock()
p.id = 1
p.name = "test"
p.description = ""
p.tools = []
p.starter_messages = None
p.document_sets = []
p.hierarchy_nodes = []
p.attached_documents = []
p.user_files = []
p.llm_model_version_override = None
p.llm_model_provider_override = None
p.uploaded_image_id = None
p.icon_name = None
p.is_public = True
p.is_listed = True
p.display_priority = None
p.is_featured = False
p.builtin_persona = False
p.labels = []
p.user = None
for k, v in overrides.items():
setattr(p, k, v)
return p
def _make_cc_pair(source: DocumentSource) -> MagicMock:
cc = MagicMock()
cc.connector.source = source
cc.name = source.value
cc.id = 1
cc.access_type = "PUBLIC"
return cc
def _make_doc_set(
cc_pairs: list[MagicMock] | None = None,
fed_connectors: list[MagicMock] | None = None,
) -> MagicMock:
ds = MagicMock()
ds.id = 1
ds.name = "ds"
ds.description = None
ds.is_up_to_date = True
ds.is_public = True
ds.users = []
ds.groups = []
ds.connector_credential_pairs = cc_pairs or []
ds.federated_connectors = fed_connectors or []
return ds
def _make_federated_ds_mapping(
source: FederatedConnectorSource,
) -> MagicMock:
mapping = MagicMock()
mapping.federated_connector.source = source
mapping.federated_connector_id = 1
mapping.entities = {}
return mapping
def _make_hierarchy_node(source: DocumentSource) -> MagicMock:
node = MagicMock()
node.source = source
return node
def _make_attached_document(source: DocumentSource) -> MagicMock:
doc = MagicMock()
doc.parent_hierarchy_node = MagicMock()
doc.parent_hierarchy_node.source = source
return doc
@patch(
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
return_value=_STUB_DS_SUMMARY,
)
def test_empty_persona_has_no_knowledge_sources(_mock_ds: MagicMock) -> None:
persona = _make_persona()
snapshot = MinimalPersonaSnapshot.from_model(persona)
assert snapshot.knowledge_sources == []
@patch(
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
return_value=_STUB_DS_SUMMARY,
)
def test_user_files_adds_user_file_source(_mock_ds: MagicMock) -> None:
persona = _make_persona(user_files=[MagicMock()])
snapshot = MinimalPersonaSnapshot.from_model(persona)
assert DocumentSource.USER_FILE in snapshot.knowledge_sources
@patch(
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
return_value=_STUB_DS_SUMMARY,
)
def test_no_user_files_excludes_user_file_source(_mock_ds: MagicMock) -> None:
cc = _make_cc_pair(DocumentSource.CONFLUENCE)
ds = _make_doc_set(cc_pairs=[cc])
persona = _make_persona(document_sets=[ds])
snapshot = MinimalPersonaSnapshot.from_model(persona)
assert DocumentSource.USER_FILE not in snapshot.knowledge_sources
assert DocumentSource.CONFLUENCE in snapshot.knowledge_sources
@patch(
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
return_value=_STUB_DS_SUMMARY,
)
def test_federated_connector_in_doc_set(_mock_ds: MagicMock) -> None:
fed = _make_federated_ds_mapping(FederatedConnectorSource.FEDERATED_SLACK)
ds = _make_doc_set(fed_connectors=[fed])
persona = _make_persona(document_sets=[ds])
snapshot = MinimalPersonaSnapshot.from_model(persona)
assert DocumentSource.SLACK in snapshot.knowledge_sources
@patch(
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
return_value=_STUB_DS_SUMMARY,
)
def test_hierarchy_nodes_and_attached_documents(_mock_ds: MagicMock) -> None:
node = _make_hierarchy_node(DocumentSource.GOOGLE_DRIVE)
doc = _make_attached_document(DocumentSource.SHAREPOINT)
persona = _make_persona(hierarchy_nodes=[node], attached_documents=[doc])
snapshot = MinimalPersonaSnapshot.from_model(persona)
assert DocumentSource.GOOGLE_DRIVE in snapshot.knowledge_sources
assert DocumentSource.SHAREPOINT in snapshot.knowledge_sources
@patch(
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
return_value=_STUB_DS_SUMMARY,
)
def test_all_source_types_combined(_mock_ds: MagicMock) -> None:
cc = _make_cc_pair(DocumentSource.CONFLUENCE)
fed = _make_federated_ds_mapping(FederatedConnectorSource.FEDERATED_SLACK)
ds = _make_doc_set(cc_pairs=[cc], fed_connectors=[fed])
node = _make_hierarchy_node(DocumentSource.GOOGLE_DRIVE)
doc = _make_attached_document(DocumentSource.SHAREPOINT)
persona = _make_persona(
document_sets=[ds],
hierarchy_nodes=[node],
attached_documents=[doc],
user_files=[MagicMock()],
)
snapshot = MinimalPersonaSnapshot.from_model(persona)
sources = set(snapshot.knowledge_sources)
assert sources == {
DocumentSource.CONFLUENCE,
DocumentSource.SLACK,
DocumentSource.GOOGLE_DRIVE,
DocumentSource.SHAREPOINT,
DocumentSource.USER_FILE,
}
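# Hedged sketch of the aggregation these tests describe: knowledge_sources is
# presumably the union of cc-pair sources, federated-connector sources mapped
# to their plain DocumentSource, hierarchy-node sources, attached-document
# parents, and a USER_FILE marker when user files exist. The federated mapping
# below is illustrative only; from_model may derive it differently.
_FED_TO_SOURCE_SKETCH = {FederatedConnectorSource.FEDERATED_SLACK: DocumentSource.SLACK}

def _knowledge_sources_sketch(persona: MagicMock) -> set[DocumentSource]:
    sources: set[DocumentSource] = set()
    for ds in persona.document_sets:
        for cc_pair in ds.connector_credential_pairs:
            sources.add(cc_pair.connector.source)
        for mapping in ds.federated_connectors:
            sources.add(_FED_TO_SOURCE_SKETCH[mapping.federated_connector.source])
    for node in persona.hierarchy_nodes:
        sources.add(node.source)
    for doc in persona.attached_documents:
        sources.add(doc.parent_hierarchy_node.source)
    if persona.user_files:
        sources.add(DocumentSource.USER_FILE)
    return sources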


@@ -100,6 +100,39 @@ class TestGenerateOllamaDisplayName:
result = generate_ollama_display_name("llama3.3:70b")
assert "3.3" in result or "3 3" in result # Either format is acceptable
def test_non_size_tag_shown(self) -> None:
"""Test that non-size tags like 'e4b' are included in the display name."""
result = generate_ollama_display_name("gemma4:e4b")
assert "Gemma" in result
assert "4" in result
assert "E4B" in result
def test_size_with_cloud_modifier(self) -> None:
"""Test size tag with cloud modifier."""
result = generate_ollama_display_name("deepseek-v3.1:671b-cloud")
assert "DeepSeek" in result
assert "671B" in result
assert "Cloud" in result
def test_size_with_multiple_modifiers(self) -> None:
"""Test size tag with multiple modifiers."""
result = generate_ollama_display_name("qwen3-vl:235b-instruct-cloud")
assert "Qwen" in result
assert "235B" in result
assert "Instruct" in result
assert "Cloud" in result
def test_quantization_tag_shown(self) -> None:
"""Test that quantization tags are included in the display name."""
result = generate_ollama_display_name("llama3:q4_0")
assert "Llama" in result
assert "Q4_0" in result
def test_cloud_only_tag(self) -> None:
"""Test standalone cloud tag."""
result = generate_ollama_display_name("glm-4.6:cloud")
assert "CLOUD" in result
class TestStripOpenrouterVendorPrefix:
"""Tests for OpenRouter vendor prefix stripping."""


@@ -95,9 +95,9 @@ class TestForceAddSearchToolGuard:
without a vector DB."""
import inspect
from onyx.tools.tool_constructor import construct_tools
from onyx.tools.tool_constructor import _construct_tools_impl
source = inspect.getsource(construct_tools)
source = inspect.getsource(_construct_tools_impl)
assert (
"DISABLE_VECTOR_DB" in source
), "construct_tools should reference DISABLE_VECTOR_DB to suppress force-adding SearchTool"

Some files were not shown because too many files have changed in this diff.