mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-16 23:16:46 +00:00
Compare commits
60 Commits
fix/agent-
...
v3.2.0-clo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
546da624a1 | ||
|
|
1a88dea760 | ||
|
|
53d2d647c5 | ||
|
|
560a8f7ab4 | ||
|
|
eaabb19c72 | ||
|
|
d3e5e16150 | ||
|
|
d3739611ba | ||
|
|
73f9a47364 | ||
|
|
a808445d96 | ||
|
|
c31215197a | ||
|
|
9ebd9ebd73 | ||
|
|
f0bb0a6bb0 | ||
|
|
01bec19d19 | ||
|
|
7b40c2cde7 | ||
|
|
e2c38d2899 | ||
|
|
24768f9e4f | ||
|
|
aec1c169b6 | ||
|
|
5a16ad3473 | ||
|
|
7e28e59f23 | ||
|
|
879ae6c02d | ||
|
|
f84f367eb4 | ||
|
|
d81efe3877 | ||
|
|
d4619f93c4 | ||
|
|
70fcfb1d73 | ||
|
|
32ba393b32 | ||
|
|
f9d2bf78ed | ||
|
|
5567a078fe | ||
|
|
fc0e8560bc | ||
|
|
60b2701eed | ||
|
|
3682d9844b | ||
|
|
a420f9a37c | ||
|
|
20c5107ba6 | ||
|
|
357bc91aee | ||
|
|
09653872a2 | ||
|
|
ff01a53f83 | ||
|
|
03ddd5ca9b | ||
|
|
8c49e4573c | ||
|
|
f1696ffa16 | ||
|
|
a427cb5b0c | ||
|
|
f7e4be18dd | ||
|
|
0f31c490fa | ||
|
|
c9a4a6e42b | ||
|
|
558c9df3c7 | ||
|
|
30003036d3 | ||
|
|
4b2f18c239 | ||
|
|
4290b097f5 | ||
|
|
b0f621a08b | ||
|
|
112edf41c5 | ||
|
|
74eb1d7212 | ||
|
|
e62d592b11 | ||
|
|
57a0d25321 | ||
|
|
887f79d7a5 | ||
|
|
65fd1c3ec8 | ||
|
|
6e3ee287b9 | ||
|
|
dee0b7867e | ||
|
|
77beb8044e | ||
|
|
750d3ac4ed | ||
|
|
6c02087ba4 | ||
|
|
0425283ed0 | ||
|
|
da97a57c58 |
@@ -2,6 +2,7 @@ FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0a
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
default-jre \
|
||||
fd-find \
|
||||
fzf \
|
||||
git \
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"name": "Onyx Dev Sandbox",
|
||||
"image": "onyxdotapp/onyx-devcontainer@sha256:12184169c5bcc9cca0388286d5ffe504b569bc9c37bfa631b76ee8eee2064055",
|
||||
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW"],
|
||||
"image": "onyxdotapp/onyx-devcontainer@sha256:0f02d9299928849c7b15f3b348dcfdcdcb64411ff7a4580cbc026a6ee7aa1554",
|
||||
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW", "--network=onyx_default"],
|
||||
"mounts": [
|
||||
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
|
||||
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
|
||||
@@ -12,10 +12,13 @@
|
||||
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
|
||||
],
|
||||
"containerEnv": {
|
||||
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock"
|
||||
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock",
|
||||
"POSTGRES_HOST": "relational_db",
|
||||
"REDIS_HOST": "cache"
|
||||
},
|
||||
"remoteUser": "${localEnv:DEVCONTAINER_REMOTE_USER:dev}",
|
||||
"updateRemoteUserUID": false,
|
||||
"initializeCommand": "docker network create onyx_default 2>/dev/null || true",
|
||||
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
|
||||
"workspaceFolder": "/workspace",
|
||||
"postStartCommand": "sudo bash /workspace/.devcontainer/init-dev-user.sh && sudo bash /workspace/.devcontainer/init-firewall.sh",
|
||||
|
||||
@@ -4,22 +4,12 @@ set -euo pipefail
|
||||
|
||||
echo "Setting up firewall..."
|
||||
|
||||
# Preserve docker dns resolution
|
||||
DOCKER_DNS_RULES=$(iptables-save | grep -E "^-A.*-d 127.0.0.11/32" || true)
|
||||
|
||||
# Flush all rules
|
||||
iptables -t nat -F
|
||||
iptables -t nat -X
|
||||
iptables -t mangle -F
|
||||
iptables -t mangle -X
|
||||
# Only flush the filter table. The nat and mangle tables are managed by Docker
|
||||
# (DNS DNAT to 127.0.0.11, container networking, etc.) and must not be touched —
|
||||
# flushing them breaks Docker's embedded DNS resolver.
|
||||
iptables -F
|
||||
iptables -X
|
||||
|
||||
# Restore docker dns rules
|
||||
if [ -n "$DOCKER_DNS_RULES" ]; then
|
||||
echo "$DOCKER_DNS_RULES" | iptables-restore -n
|
||||
fi
|
||||
|
||||
# Create ipset for allowed destinations
|
||||
ipset create allowed-domains hash:net || true
|
||||
ipset flush allowed-domains
|
||||
@@ -34,6 +24,7 @@ done
|
||||
|
||||
# Resolve allowed domains
|
||||
ALLOWED_DOMAINS=(
|
||||
"github.com"
|
||||
"registry.npmjs.org"
|
||||
"api.anthropic.com"
|
||||
"api-staging.anthropic.com"
|
||||
@@ -65,6 +56,14 @@ if [ -n "$DOCKER_GATEWAY" ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
# Allow traffic to all attached Docker network subnets so the container can
|
||||
# reach sibling services (e.g. relational_db, cache) on shared compose networks.
|
||||
for subnet in $(ip -4 -o addr show scope global | awk '{print $4}'); do
|
||||
if ! ipset add allowed-domains "$subnet" -exist 2>&1; then
|
||||
echo "warning: failed to add Docker subnet $subnet to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
|
||||
# Set default policies to DROP
|
||||
iptables -P FORWARD DROP
|
||||
iptables -P INPUT DROP
|
||||
|
||||
12
.vscode/launch.json
vendored
12
.vscode/launch.json
vendored
@@ -475,6 +475,18 @@
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Start Monitoring Stack (Prometheus + Grafana)",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "docker",
|
||||
"runtimeArgs": ["compose", "up", "-d"],
|
||||
"cwd": "${workspaceFolder}/profiling",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47
|
||||
|
||||
LABEL com.danswer.maintainer="founders@onyx.app"
|
||||
LABEL com.danswer.description="This image is the web/frontend container of Onyx which \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Base stage with dependencies
|
||||
FROM python:3.11.7-slim-bookworm AS base
|
||||
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47 AS base
|
||||
|
||||
ENV DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
HF_HOME=/app/.cache/huggingface
|
||||
@@ -50,6 +50,10 @@ COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py
|
||||
COPY ./onyx/utils/middleware.py /app/onyx/utils/middleware.py
|
||||
COPY ./onyx/utils/tenant.py /app/onyx/utils/tenant.py
|
||||
|
||||
# Sentry configuration (used when SENTRY_DSN is set)
|
||||
COPY ./onyx/configs/__init__.py /app/onyx/configs/__init__.py
|
||||
COPY ./onyx/configs/sentry.py /app/onyx/configs/sentry.py
|
||||
|
||||
# Place to fetch version information
|
||||
COPY ./onyx/__init__.py /app/onyx/__init__.py
|
||||
|
||||
|
||||
@@ -208,7 +208,7 @@ def do_run_migrations(
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
@@ -380,7 +380,7 @@ def run_migrations_offline() -> None:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
@@ -421,7 +421,7 @@ def run_migrations_offline() -> None:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
@@ -464,7 +464,7 @@ def run_migrations_online() -> None:
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
|
||||
@@ -25,7 +25,7 @@ def upgrade() -> None:
|
||||
|
||||
# Use batch mode to modify the enum type
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
batch_op.alter_column(
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC",
|
||||
@@ -71,7 +71,7 @@ def downgrade() -> None:
|
||||
op.drop_column("user__user_group", "is_curator")
|
||||
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
batch_op.alter_column(
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
|
||||
|
||||
@@ -63,7 +63,7 @@ def upgrade() -> None:
|
||||
"time_created",
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
existing_server_default=sa.text("now()"), # type: ignore
|
||||
existing_server_default=sa.text("now()"),
|
||||
)
|
||||
op.alter_column(
|
||||
"index_attempt",
|
||||
@@ -85,7 +85,7 @@ def downgrade() -> None:
|
||||
"time_created",
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=True,
|
||||
existing_server_default=sa.text("now()"), # type: ignore
|
||||
existing_server_default=sa.text("now()"),
|
||||
)
|
||||
op.drop_index(op.f("ix_accesstoken_created_at"), table_name="accesstoken")
|
||||
op.drop_table("accesstoken")
|
||||
|
||||
@@ -19,7 +19,7 @@ depends_on: None = None
|
||||
|
||||
def upgrade() -> None:
|
||||
sequence = Sequence("connector_credential_pair_id_seq")
|
||||
op.execute(CreateSequence(sequence)) # type: ignore
|
||||
op.execute(CreateSequence(sequence))
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
|
||||
@@ -49,7 +49,7 @@ def run_migrations_offline() -> None:
|
||||
url = build_connection_string()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
@@ -61,7 +61,7 @@ def run_migrations_offline() -> None:
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
target_metadata=target_metadata,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
@@ -96,11 +96,14 @@ def get_model_app() -> FastAPI:
|
||||
title="Onyx Model Server", version=__version__, lifespan=lifespan
|
||||
)
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -10,6 +10,7 @@ from celery import bootsteps # type: ignore
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import before_task_publish
|
||||
from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from celery.states import READY_STATES
|
||||
@@ -62,11 +63,14 @@ logger = setup_logger()
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
@@ -94,6 +98,17 @@ class TenantAwareTask(Task):
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@before_task_publish.connect
|
||||
def on_before_task_publish(
|
||||
headers: dict[str, Any] | None = None,
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Stamp the current wall-clock time into the task message headers so that
|
||||
workers can compute queue wait time (time between publish and execution)."""
|
||||
if headers is not None:
|
||||
headers["enqueued_at"] = time.time()
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None, # noqa: ARG001
|
||||
|
||||
@@ -59,6 +59,11 @@ from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_started
|
||||
from onyx.server.metrics.deletion_metrics import observe_deletion_taskset_duration
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
@@ -300,6 +305,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
recent_index_attempts
|
||||
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
inc_deletion_blocked(tenant_id, "indexing")
|
||||
raise TaskDependencyError(
|
||||
"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -307,11 +313,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
inc_deletion_blocked(tenant_id, "pruning")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
if redis_connector.permissions.fenced:
|
||||
inc_deletion_blocked(tenant_id, "permissions")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (permissions in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
@@ -359,6 +367,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# set this only after all tasks have been added
|
||||
fence_payload.num_tasks = tasks_generated
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
inc_deletion_started(tenant_id)
|
||||
|
||||
return tasks_generated
|
||||
|
||||
@@ -527,6 +536,12 @@ def monitor_connector_deletion_taskset(
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "success", duration)
|
||||
inc_deletion_completed(tenant_id, "success")
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
@@ -545,6 +560,11 @@ def monitor_connector_deletion_taskset(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
|
||||
)
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "failure", duration)
|
||||
inc_deletion_completed(tenant_id, "failure")
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
@@ -721,5 +741,6 @@ def validate_connector_deletion_fence(
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
inc_deletion_fence_reset(tenant_id)
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
@@ -135,10 +135,13 @@ def _docfetching_task(
|
||||
# Since connector_indexing_proxy_task spawns a new process using this function as
|
||||
# the entrypoint, we init Sentry here.
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -50,6 +51,7 @@ from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -85,6 +87,8 @@ from onyx.db.indexing_coordination import INDEXING_PROGRESS_TIMEOUT_HOURS
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
@@ -105,6 +109,9 @@ from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_utils import is_fence
|
||||
from onyx.server.metrics.connector_health_metrics import on_connector_error_state_change
|
||||
from onyx.server.metrics.connector_health_metrics import on_connector_indexing_success
|
||||
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
@@ -400,7 +407,6 @@ def check_indexing_completion(
|
||||
tenant_id: str,
|
||||
task: Task,
|
||||
) -> None:
|
||||
|
||||
logger.info(
|
||||
f"Checking for indexing completion: attempt={index_attempt_id} tenant={tenant_id}"
|
||||
)
|
||||
@@ -521,13 +527,25 @@ def check_indexing_completion(
|
||||
|
||||
# Update CC pair status if successful
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session, attempt.connector_credential_pair_id
|
||||
db_session,
|
||||
attempt.connector_credential_pair_id,
|
||||
eager_load_connector=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise RuntimeError(
|
||||
f"CC pair {attempt.connector_credential_pair_id} not found in database"
|
||||
)
|
||||
|
||||
source = cc_pair.connector.source.value
|
||||
connector_name = cc_pair.connector.name or f"cc_pair_{cc_pair.id}"
|
||||
on_index_attempt_status_change(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_name=connector_name,
|
||||
status=attempt.status.value,
|
||||
)
|
||||
|
||||
if attempt.status.is_successful():
|
||||
# NOTE: we define the last successful index time as the time the last successful
|
||||
# attempt finished. This is distinct from the poll_range_end of the last successful
|
||||
@@ -548,10 +566,41 @@ def check_indexing_completion(
|
||||
event=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
)
|
||||
|
||||
on_connector_indexing_success(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_name=connector_name,
|
||||
docs_indexed=attempt.new_docs_indexed or 0,
|
||||
success_timestamp=attempt.time_updated.timestamp(),
|
||||
)
|
||||
|
||||
# Clear repeated error state on success
|
||||
if cc_pair.in_repeated_error_state:
|
||||
cc_pair.in_repeated_error_state = False
|
||||
|
||||
# Delete any existing error notification for this CC pair so a
|
||||
# fresh one is created if the connector fails again later.
|
||||
for notif in get_notifications(
|
||||
user=None,
|
||||
db_session=db_session,
|
||||
notif_type=NotificationType.CONNECTOR_REPEATED_ERRORS,
|
||||
include_dismissed=True,
|
||||
):
|
||||
if (
|
||||
notif.additional_data
|
||||
and notif.additional_data.get("cc_pair_id") == cc_pair.id
|
||||
):
|
||||
db_session.delete(notif)
|
||||
|
||||
db_session.commit()
|
||||
on_connector_error_state_change(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_name=connector_name,
|
||||
in_error=False,
|
||||
)
|
||||
|
||||
if attempt.status == IndexingStatus.SUCCESS:
|
||||
logger.info(
|
||||
@@ -608,6 +657,27 @@ def active_indexing_attempt(
|
||||
return bool(active_indexing_attempt)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _KickoffResult:
|
||||
"""Tracks diagnostic counts from a _kickoff_indexing_tasks run."""
|
||||
|
||||
created: int = 0
|
||||
skipped_active: int = 0
|
||||
skipped_not_found: int = 0
|
||||
skipped_not_indexable: int = 0
|
||||
failed_to_create: int = 0
|
||||
|
||||
@property
|
||||
def evaluated(self) -> int:
|
||||
return (
|
||||
self.created
|
||||
+ self.skipped_active
|
||||
+ self.skipped_not_found
|
||||
+ self.skipped_not_indexable
|
||||
+ self.failed_to_create
|
||||
)
|
||||
|
||||
|
||||
def _kickoff_indexing_tasks(
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
@@ -617,12 +687,12 @@ def _kickoff_indexing_tasks(
|
||||
redis_client: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str,
|
||||
) -> int:
|
||||
) -> _KickoffResult:
|
||||
"""Kick off indexing tasks for the given cc_pair_ids and search_settings.
|
||||
|
||||
Returns the number of tasks successfully created.
|
||||
Returns a _KickoffResult with diagnostic counts.
|
||||
"""
|
||||
tasks_created = 0
|
||||
result = _KickoffResult()
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
@@ -633,6 +703,7 @@ def _kickoff_indexing_tasks(
|
||||
search_settings_id=search_settings.id,
|
||||
db_session=db_session,
|
||||
):
|
||||
result.skipped_active += 1
|
||||
continue
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
@@ -643,6 +714,7 @@ def _kickoff_indexing_tasks(
|
||||
task_logger.warning(
|
||||
f"_kickoff_indexing_tasks - CC pair not found: cc_pair={cc_pair_id}"
|
||||
)
|
||||
result.skipped_not_found += 1
|
||||
continue
|
||||
|
||||
# Heavyweight check after fetching cc pair
|
||||
@@ -657,6 +729,7 @@ def _kickoff_indexing_tasks(
|
||||
f"search_settings={search_settings.id}, "
|
||||
f"secondary_index_building={secondary_index_building}"
|
||||
)
|
||||
result.skipped_not_indexable += 1
|
||||
continue
|
||||
|
||||
task_logger.debug(
|
||||
@@ -696,13 +769,14 @@ def _kickoff_indexing_tasks(
|
||||
task_logger.info(
|
||||
f"Connector indexing queued: index_attempt={attempt_id} cc_pair={cc_pair.id} search_settings={search_settings.id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
result.created += 1
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Failed to create indexing task: cc_pair={cc_pair.id} search_settings={search_settings.id}"
|
||||
)
|
||||
result.failed_to_create += 1
|
||||
|
||||
return tasks_created
|
||||
return result
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -728,6 +802,8 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
task_logger.warning("check_for_indexing - Starting")
|
||||
|
||||
tasks_created = 0
|
||||
primary_result = _KickoffResult()
|
||||
secondary_result: _KickoffResult | None = None
|
||||
locked = False
|
||||
redis_client = get_redis_client()
|
||||
redis_client_replica = get_redis_replica_client()
|
||||
@@ -848,6 +924,43 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
cc_pair_id=cc_pair_id,
|
||||
in_repeated_error_state=True,
|
||||
)
|
||||
error_connector_name = (
|
||||
cc_pair.connector.name or f"cc_pair_{cc_pair.id}"
|
||||
)
|
||||
on_connector_error_state_change(
|
||||
tenant_id=tenant_id,
|
||||
source=cc_pair.connector.source.value,
|
||||
cc_pair_id=cc_pair_id,
|
||||
connector_name=error_connector_name,
|
||||
in_error=True,
|
||||
)
|
||||
|
||||
connector_name = (
|
||||
cc_pair.name
|
||||
or cc_pair.connector.name
|
||||
or f"CC pair {cc_pair.id}"
|
||||
)
|
||||
source = cc_pair.connector.source.value
|
||||
connector_url = f"/admin/connector/{cc_pair.id}"
|
||||
create_notification(
|
||||
user_id=None,
|
||||
notif_type=NotificationType.CONNECTOR_REPEATED_ERRORS,
|
||||
db_session=db_session,
|
||||
title=f"Connector '{connector_name}' has entered repeated error state",
|
||||
description=(
|
||||
f"The {source} connector has failed repeatedly and "
|
||||
f"has been flagged. View indexing history in the "
|
||||
f"Advanced section: {connector_url}"
|
||||
),
|
||||
additional_data={"cc_pair_id": cc_pair.id},
|
||||
)
|
||||
|
||||
task_logger.error(
|
||||
f"Connector entered repeated error state: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"connector={cc_pair.connector.name} "
|
||||
f"source={source}"
|
||||
)
|
||||
# When entering repeated error state, also pause the connector
|
||||
# to prevent continued indexing retry attempts burning through embedding credits.
|
||||
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
|
||||
@@ -863,7 +976,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
# Heavy check, should_index(), is called in _kickoff_indexing_tasks
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Primary first
|
||||
tasks_created += _kickoff_indexing_tasks(
|
||||
primary_result = _kickoff_indexing_tasks(
|
||||
celery_app=self.app,
|
||||
db_session=db_session,
|
||||
search_settings=current_search_settings,
|
||||
@@ -873,6 +986,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
lock_beat=lock_beat,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
tasks_created += primary_result.created
|
||||
|
||||
# Secondary indexing (only if secondary search settings exist and switchover_type is not INSTANT)
|
||||
if (
|
||||
@@ -880,7 +994,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
and secondary_search_settings.switchover_type != SwitchoverType.INSTANT
|
||||
and secondary_cc_pair_ids
|
||||
):
|
||||
tasks_created += _kickoff_indexing_tasks(
|
||||
secondary_result = _kickoff_indexing_tasks(
|
||||
celery_app=self.app,
|
||||
db_session=db_session,
|
||||
search_settings=secondary_search_settings,
|
||||
@@ -890,6 +1004,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
lock_beat=lock_beat,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
tasks_created += secondary_result.created
|
||||
elif (
|
||||
secondary_search_settings
|
||||
and secondary_search_settings.switchover_type == SwitchoverType.INSTANT
|
||||
@@ -1002,7 +1117,26 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
task_logger.info(
|
||||
f"check_for_indexing finished: "
|
||||
f"elapsed={time_elapsed:.2f}s "
|
||||
f"primary=[evaluated={primary_result.evaluated} "
|
||||
f"created={primary_result.created} "
|
||||
f"skipped_active={primary_result.skipped_active} "
|
||||
f"skipped_not_found={primary_result.skipped_not_found} "
|
||||
f"skipped_not_indexable={primary_result.skipped_not_indexable} "
|
||||
f"failed={primary_result.failed_to_create}]"
|
||||
+ (
|
||||
f" secondary=[evaluated={secondary_result.evaluated} "
|
||||
f"created={secondary_result.created} "
|
||||
f"skipped_active={secondary_result.skipped_active} "
|
||||
f"skipped_not_found={secondary_result.skipped_not_found} "
|
||||
f"skipped_not_indexable={secondary_result.skipped_not_indexable} "
|
||||
f"failed={secondary_result.failed_to_create}]"
|
||||
if secondary_result
|
||||
else ""
|
||||
)
|
||||
)
|
||||
return tasks_created
|
||||
|
||||
|
||||
|
||||
@@ -172,6 +172,10 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
task_logger.debug(
|
||||
"Verified tenant info, migration record, and search settings."
|
||||
)
|
||||
|
||||
# 2.e. Build sanitized to original doc ID mapping to check for
|
||||
# conflicts in the event we sanitize a doc ID to an
|
||||
# already-existing doc ID.
|
||||
@@ -325,6 +329,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
task_logger.debug("Released the OpenSearch migration lock.")
|
||||
else:
|
||||
task_logger.warning(
|
||||
"The OpenSearch migration lock was not owned on completion of the migration task."
|
||||
|
||||
@@ -51,7 +51,6 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries
|
||||
from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes
|
||||
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
|
||||
@@ -643,16 +642,6 @@ def connector_pruning_generator_task(
|
||||
raw_id_to_parent=all_connector_doc_ids,
|
||||
)
|
||||
|
||||
# Link hierarchy nodes to documents for sources where pages can be
|
||||
# both hierarchy nodes AND documents (e.g. Notion, Confluence)
|
||||
all_doc_id_list = list(all_connector_doc_ids.keys())
|
||||
link_hierarchy_nodes_to_documents(
|
||||
db_session=db_session,
|
||||
document_ids=all_doc_id_list,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
diff_start = time.monotonic()
|
||||
try:
|
||||
# a list of docs in our local index
|
||||
|
||||
@@ -248,6 +248,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
),
|
||||
)
|
||||
mark_document_as_modified(document_id, db_session)
|
||||
db_session.commit()
|
||||
completion_status = (
|
||||
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Celery
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -68,6 +69,7 @@ from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.build.indexing.persistent_document_writer import (
|
||||
get_persistent_document_writer,
|
||||
)
|
||||
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
|
||||
@@ -267,6 +269,14 @@ def run_docfetching_entrypoint(
|
||||
)
|
||||
credential_id = attempt.connector_credential_pair.credential_id
|
||||
|
||||
on_index_attempt_status_change(
|
||||
tenant_id=tenant_id,
|
||||
source=attempt.connector_credential_pair.connector.source.value,
|
||||
cc_pair_id=connector_credential_pair_id,
|
||||
connector_name=connector_name or f"cc_pair_{connector_credential_pair_id}",
|
||||
status="in_progress",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Docfetching starting{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
@@ -556,6 +566,27 @@ def connector_document_extraction(
|
||||
|
||||
# save record of any failures at the connector level
|
||||
if failure is not None:
|
||||
if failure.exception is not None:
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "connector_fetch")
|
||||
scope.set_tag("connector_source", db_connector.source.value)
|
||||
scope.set_tag("cc_pair_id", str(cc_pair_id))
|
||||
scope.set_tag("index_attempt_id", str(index_attempt_id))
|
||||
scope.set_tag("tenant_id", tenant_id)
|
||||
if failure.failed_document:
|
||||
scope.set_tag(
|
||||
"doc_id", failure.failed_document.document_id
|
||||
)
|
||||
if failure.failed_entity:
|
||||
scope.set_tag(
|
||||
"entity_id", failure.failed_entity.entity_id
|
||||
)
|
||||
scope.fingerprint = [
|
||||
"connector-fetch-failure",
|
||||
db_connector.source.value,
|
||||
type(failure.exception).__name__,
|
||||
]
|
||||
sentry_sdk.capture_exception(failure.exception)
|
||||
total_failures += 1
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_index_attempt_error(
|
||||
|
||||
@@ -364,7 +364,7 @@ def _get_or_extract_plaintext(
|
||||
plaintext_io = file_store.read_file(plaintext_key, mode="b")
|
||||
return plaintext_io.read().decode("utf-8")
|
||||
except Exception:
|
||||
logger.exception(f"Error when reading file, id={file_id}")
|
||||
logger.info(f"Cache miss for file with id={file_id}")
|
||||
|
||||
# Cache miss — extract and store.
|
||||
content_text = extract_fn()
|
||||
|
||||
@@ -4,8 +4,6 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import create_tool_call_failure_messages
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
@@ -635,7 +633,6 @@ def run_llm_loop(
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -1020,20 +1017,16 @@ def run_llm_loop(
|
||||
persisted_memory_id: int | None = None
|
||||
if user_memory_context and user_memory_context.user_id:
|
||||
if tool_response.rich_response.index_to_replace is not None:
|
||||
memory = update_memory_at_index(
|
||||
persisted_memory_id = update_memory_at_index(
|
||||
user_id=user_memory_context.user_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
new_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id if memory else None
|
||||
else:
|
||||
memory = add_memory(
|
||||
persisted_memory_id = add_memory(
|
||||
user_id=user_memory_context.user_id,
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id
|
||||
operation: Literal["add", "update"] = (
|
||||
"update"
|
||||
if tool_response.rich_response.index_to_replace is not None
|
||||
|
||||
@@ -67,7 +67,6 @@ from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
@@ -1006,93 +1005,86 @@ def _run_models(
|
||||
model_llm = setup.llms[model_idx]
|
||||
|
||||
try:
|
||||
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
|
||||
# Do NOT write to the outer db_session (or any shared DB state) from here;
|
||||
# all DB writes in this thread must go through thread_db_session.
|
||||
with get_session_with_current_tenant() as thread_db_session:
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
db_session=thread_db_session,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
# Each function opens short-lived DB sessions on demand.
|
||||
# Do NOT pass a long-lived session here — it would hold a
|
||||
# connection for the entire LLM loop (minutes), and cloud
|
||||
# infrastructure may drop idle connections.
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
)
|
||||
model_tools = [
|
||||
tool for tool_list in thread_tool_dict.values() for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
model_tools = [
|
||||
tool
|
||||
for tool_list in thread_tool_dict.values()
|
||||
for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError(
|
||||
"Deep research is not supported for projects"
|
||||
)
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError("Deep research is not supported for projects")
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
model_succeeded[model_idx] = True
|
||||
|
||||
|
||||
@@ -1125,6 +1125,32 @@ DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
|
||||
# Number of pre-provisioned tenants to maintain
|
||||
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
|
||||
|
||||
# Master switch for the tenant work-gating feature. Controls the `enabled`
|
||||
# axis only — flipping this True puts the feature in shadow mode (compute
|
||||
# the gate, log skip counts, but do not actually skip). The `enforce` axis
|
||||
# is Redis-only with a hard-coded default of False, so this env flag alone
|
||||
# cannot cause real tenants to be skipped. Default off.
|
||||
ENABLE_TENANT_WORK_GATING = (
|
||||
os.environ.get("ENABLE_TENANT_WORK_GATING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Membership TTL for the `active_tenants` sorted set. Members older than this
|
||||
# are treated as inactive by the gate read path. Must be > the full-fanout
|
||||
# interval so self-healing re-adds a genuinely-working tenant before their
|
||||
# membership expires. Default 30 min.
|
||||
TENANT_WORK_GATING_TTL_SECONDS = int(
|
||||
os.environ.get("TENANT_WORK_GATING_TTL_SECONDS", 30 * 60)
|
||||
)
|
||||
|
||||
# Minimum wall-clock interval between full-fanout cycles. When this many
|
||||
# seconds have elapsed since the last bypass, the generator ignores the gate
|
||||
# on the next invocation and dispatches to every non-gated tenant, letting
|
||||
# consumers re-populate the active set. Schedule-independent so beat drift
|
||||
# or backlog can't make the self-heal bursty or sparse. Default 20 min.
|
||||
TENANT_WORK_GATING_FULL_FANOUT_INTERVAL_SECONDS = int(
|
||||
os.environ.get("TENANT_WORK_GATING_FULL_FANOUT_INTERVAL_SECONDS", 20 * 60)
|
||||
)
|
||||
|
||||
|
||||
# Image summarization configuration
|
||||
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get(
|
||||
|
||||
@@ -283,6 +283,7 @@ class NotificationType(str, Enum):
|
||||
RELEASE_NOTES = "release_notes"
|
||||
ASSISTANT_FILES_READY = "assistant_files_ready"
|
||||
FEATURE_ANNOUNCEMENT = "feature_announcement"
|
||||
CONNECTOR_REPEATED_ERRORS = "connector_repeated_errors"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
|
||||
48
backend/onyx/configs/sentry.py
Normal file
48
backend/onyx/configs/sentry.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import Any
|
||||
|
||||
from sentry_sdk.types import Event
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_instance_id_resolved = False
|
||||
|
||||
|
||||
def _add_instance_tags(
|
||||
event: Event,
|
||||
hint: dict[str, Any], # noqa: ARG001
|
||||
) -> Event | None:
|
||||
"""Sentry before_send hook that lazily attaches instance identification tags.
|
||||
|
||||
On the first event, resolves the instance UUID from the KV store (requires DB)
|
||||
and sets it as a global Sentry tag. Subsequent events pick it up automatically.
|
||||
"""
|
||||
global _instance_id_resolved
|
||||
|
||||
if _instance_id_resolved:
|
||||
return event
|
||||
|
||||
try:
|
||||
import sentry_sdk
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
if MULTI_TENANT:
|
||||
instance_id = "multi-tenant-cloud"
|
||||
else:
|
||||
from onyx.utils.telemetry import get_or_generate_uuid
|
||||
|
||||
instance_id = get_or_generate_uuid()
|
||||
|
||||
sentry_sdk.set_tag("instance_id", instance_id)
|
||||
|
||||
# Also set on this event since set_tag won't retroactively apply
|
||||
event.setdefault("tags", {})["instance_id"] = instance_id
|
||||
|
||||
# Only mark resolved after success — if DB wasn't ready, retry next event
|
||||
_instance_id_resolved = True
|
||||
except Exception:
|
||||
logger.debug("Failed to resolve instance_id for Sentry tagging")
|
||||
|
||||
return event
|
||||
@@ -26,6 +26,10 @@ from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
process_onyx_metadata,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
|
||||
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
|
||||
tabular_file_to_sections,
|
||||
)
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
@@ -38,6 +42,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -451,6 +456,40 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
logger.exception(f"Error processing image {key}")
|
||||
continue
|
||||
|
||||
# Handle tabular files (xlsx, csv, tsv) — produce one
|
||||
# TabularSection per sheet (or per file for csv/tsv)
|
||||
# instead of a flat TextSection.
|
||||
if is_tabular_file(file_name):
|
||||
try:
|
||||
downloaded_file = self._download_object(key)
|
||||
if downloaded_file is None:
|
||||
continue
|
||||
tabular_sections = tabular_file_to_sections(
|
||||
BytesIO(downloaded_file),
|
||||
file_name=file_name,
|
||||
link=link,
|
||||
)
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
sections=(
|
||||
tabular_sections
|
||||
if tabular_sections
|
||||
else [TabularSection(link=link, text="")]
|
||||
),
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=file_name,
|
||||
doc_updated_at=last_modified,
|
||||
metadata={},
|
||||
)
|
||||
)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
except Exception:
|
||||
logger.exception(f"Error processing tabular file {key}")
|
||||
continue
|
||||
|
||||
# Handle text and document files
|
||||
try:
|
||||
downloaded_file = self._download_object(key)
|
||||
|
||||
@@ -27,16 +27,19 @@ _STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
|
||||
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
|
||||
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
404: OnyxErrorCode.BAD_GATEWAY,
|
||||
429: OnyxErrorCode.RATE_LIMITED,
|
||||
}
|
||||
|
||||
|
||||
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
|
||||
"""Map an HTTP status code to the appropriate OnyxErrorCode.
|
||||
|
||||
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
|
||||
Expects a >= 400 status code. Known codes (401, 403, 404) are
|
||||
mapped to specific error codes; all other codes (unrecognised 4xx
|
||||
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
|
||||
|
||||
Note: 429 is intentionally omitted — the rl_requests wrapper
|
||||
handles rate limits transparently at the HTTP layer, so 429
|
||||
responses never reach this function.
|
||||
"""
|
||||
if status_code in _STATUS_TO_ERROR_CODE:
|
||||
return _STATUS_TO_ERROR_CODE[status_code]
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
from typing import NoReturn
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
@@ -25,8 +24,11 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
@@ -47,10 +49,6 @@ def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
|
||||
raise InsufficientPermissionsError(
|
||||
"Canvas API token does not have sufficient permissions (HTTP 403)."
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
|
||||
)
|
||||
elif e.status_code >= 500:
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
|
||||
@@ -61,6 +59,60 @@ def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
|
||||
)
|
||||
|
||||
|
||||
class CanvasStage(StrEnum):
|
||||
PAGES = "pages"
|
||||
ASSIGNMENTS = "assignments"
|
||||
ANNOUNCEMENTS = "announcements"
|
||||
|
||||
|
||||
_STAGE_CONFIG: dict[CanvasStage, dict[str, Any]] = {
|
||||
CanvasStage.PAGES: {
|
||||
"endpoint": "courses/{course_id}/pages",
|
||||
"params": {
|
||||
"per_page": "100",
|
||||
"include[]": "body",
|
||||
"published": "true",
|
||||
"sort": "updated_at",
|
||||
"order": "desc",
|
||||
},
|
||||
},
|
||||
CanvasStage.ASSIGNMENTS: {
|
||||
"endpoint": "courses/{course_id}/assignments",
|
||||
"params": {"per_page": "100", "published": "true"},
|
||||
},
|
||||
CanvasStage.ANNOUNCEMENTS: {
|
||||
"endpoint": "announcements",
|
||||
"params": {
|
||||
"per_page": "100",
|
||||
"context_codes[]": "course_{course_id}",
|
||||
"active_only": "true",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _parse_canvas_dt(timestamp_str: str) -> datetime:
|
||||
"""Parse a Canvas ISO-8601 timestamp (e.g. '2025-06-15T12:00:00Z')
|
||||
into a timezone-aware UTC datetime.
|
||||
|
||||
Canvas returns timestamps with a trailing 'Z' instead of '+00:00',
|
||||
so we normalise before parsing.
|
||||
"""
|
||||
return datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
|
||||
|
||||
def _unix_to_canvas_time(epoch: float) -> str:
|
||||
"""Convert a Unix timestamp to Canvas ISO-8601 format (e.g. '2025-06-15T12:00:00Z')."""
|
||||
return datetime.fromtimestamp(epoch, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def _in_time_window(timestamp_str: str, start: float, end: float) -> bool:
|
||||
"""Check whether a Canvas ISO-8601 timestamp falls within (start, end]."""
|
||||
return start < _parse_canvas_dt(timestamp_str).timestamp() <= end
|
||||
|
||||
|
||||
class CanvasCourse(BaseModel):
|
||||
id: int
|
||||
name: str | None = None
|
||||
@@ -145,9 +197,6 @@ class CanvasAnnouncement(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
|
||||
|
||||
|
||||
class CanvasConnectorCheckpoint(ConnectorCheckpoint):
|
||||
"""Checkpoint state for resumable Canvas indexing.
|
||||
|
||||
@@ -165,15 +214,30 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
course_ids: list[int] = []
|
||||
current_course_index: int = 0
|
||||
stage: CanvasStage = "pages"
|
||||
stage: CanvasStage = CanvasStage.PAGES
|
||||
next_url: str | None = None
|
||||
|
||||
def advance_course(self) -> None:
|
||||
"""Move to the next course and reset within-course state."""
|
||||
self.current_course_index += 1
|
||||
self.stage = "pages"
|
||||
self.stage = CanvasStage.PAGES
|
||||
self.next_url = None
|
||||
|
||||
def advance_stage(self) -> None:
|
||||
"""Advance past the current stage.
|
||||
|
||||
Moves to the next stage within the same course, or to the next
|
||||
course if the current stage is the last one. Resets next_url so
|
||||
the next call starts fresh on the new stage.
|
||||
"""
|
||||
self.next_url = None
|
||||
stages: list[CanvasStage] = list(CanvasStage)
|
||||
next_idx = stages.index(self.stage) + 1
|
||||
if next_idx < len(stages):
|
||||
self.stage = stages[next_idx]
|
||||
else:
|
||||
self.advance_course()
|
||||
|
||||
|
||||
class CanvasConnector(
|
||||
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
|
||||
@@ -295,13 +359,7 @@ class CanvasConnector(
|
||||
if body_text:
|
||||
text_parts.append(body_text)
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
if page.updated_at
|
||||
else None
|
||||
)
|
||||
doc_updated_at = _parse_canvas_dt(page.updated_at) if page.updated_at else None
|
||||
|
||||
document = self._build_document(
|
||||
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
|
||||
@@ -325,17 +383,11 @@ class CanvasConnector(
|
||||
if desc_text:
|
||||
text_parts.append(desc_text)
|
||||
if assignment.due_at:
|
||||
due_dt = datetime.fromisoformat(
|
||||
assignment.due_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
due_dt = _parse_canvas_dt(assignment.due_at)
|
||||
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(
|
||||
assignment.updated_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
if assignment.updated_at
|
||||
else None
|
||||
_parse_canvas_dt(assignment.updated_at) if assignment.updated_at else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
@@ -361,11 +413,7 @@ class CanvasConnector(
|
||||
text_parts.append(msg_text)
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(
|
||||
announcement.posted_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
if announcement.posted_at
|
||||
else None
|
||||
_parse_canvas_dt(announcement.posted_at) if announcement.posted_at else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
@@ -400,6 +448,314 @@ class CanvasConnector(
|
||||
self._canvas_client = client
|
||||
return None
|
||||
|
||||
def _fetch_stage_page(
|
||||
self,
|
||||
next_url: str | None,
|
||||
endpoint: str,
|
||||
params: dict[str, Any],
|
||||
) -> tuple[list[Any], str | None]:
|
||||
"""Fetch one page of API results for the current stage.
|
||||
|
||||
Returns (items, next_url). All error handling is done by the
|
||||
caller (_load_from_checkpoint).
|
||||
"""
|
||||
if next_url:
|
||||
# Resuming mid-pagination: the next_url from Canvas's
|
||||
# Link header already contains endpoint + query params.
|
||||
response, result_next_url = self.canvas_client.get(full_url=next_url)
|
||||
else:
|
||||
# First request for this stage: build from endpoint + params.
|
||||
response, result_next_url = self.canvas_client.get(
|
||||
endpoint=endpoint, params=params
|
||||
)
|
||||
return response or [], result_next_url
|
||||
|
||||
def _process_items(
|
||||
self,
|
||||
response: list[Any],
|
||||
stage: CanvasStage,
|
||||
course_id: int,
|
||||
start: float,
|
||||
end: float,
|
||||
include_permissions: bool,
|
||||
) -> tuple[list[Document | ConnectorFailure], bool]:
|
||||
"""Process a page of API results into documents.
|
||||
|
||||
Returns (docs, early_exit). early_exit is True when pages
|
||||
(sorted desc by updated_at) hit an item older than start,
|
||||
signaling that pagination should stop.
|
||||
"""
|
||||
results: list[Document | ConnectorFailure] = []
|
||||
early_exit = False
|
||||
|
||||
for item in response:
|
||||
try:
|
||||
if stage == CanvasStage.PAGES:
|
||||
page = CanvasPage.from_api(item, course_id=course_id)
|
||||
if not page.updated_at:
|
||||
continue
|
||||
# Pages are sorted by updated_at desc — once we see
|
||||
# an item at or before `start`, all remaining items
|
||||
# on this and subsequent pages are older too.
|
||||
if not _in_time_window(page.updated_at, start, end):
|
||||
if _parse_canvas_dt(page.updated_at).timestamp() <= start:
|
||||
early_exit = True
|
||||
break
|
||||
# ts > end: page is newer than our window, skip it
|
||||
continue
|
||||
doc = self._convert_page_to_document(page)
|
||||
results.append(
|
||||
self._maybe_attach_permissions(
|
||||
doc, course_id, include_permissions
|
||||
)
|
||||
)
|
||||
|
||||
elif stage == CanvasStage.ASSIGNMENTS:
|
||||
assignment = CanvasAssignment.from_api(item, course_id=course_id)
|
||||
if not assignment.updated_at or not _in_time_window(
|
||||
assignment.updated_at, start, end
|
||||
):
|
||||
continue
|
||||
doc = self._convert_assignment_to_document(assignment)
|
||||
results.append(
|
||||
self._maybe_attach_permissions(
|
||||
doc, course_id, include_permissions
|
||||
)
|
||||
)
|
||||
|
||||
elif stage == CanvasStage.ANNOUNCEMENTS:
|
||||
announcement = CanvasAnnouncement.from_api(
|
||||
item, course_id=course_id
|
||||
)
|
||||
if not announcement.posted_at:
|
||||
logger.debug(
|
||||
f"Skipping announcement {announcement.id} in "
|
||||
f"course {course_id}: no posted_at"
|
||||
)
|
||||
continue
|
||||
if not _in_time_window(announcement.posted_at, start, end):
|
||||
continue
|
||||
doc = self._convert_announcement_to_document(announcement)
|
||||
results.append(
|
||||
self._maybe_attach_permissions(
|
||||
doc, course_id, include_permissions
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
item_id = item.get("id") or item.get("page_id", "unknown")
|
||||
if stage == CanvasStage.PAGES:
|
||||
doc_link = (
|
||||
f"{self.canvas_base_url}/courses/{course_id}"
|
||||
f"/pages/{item.get('url', '')}"
|
||||
)
|
||||
else:
|
||||
doc_link = item.get("html_url", "")
|
||||
results.append(
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=f"canvas-{stage.removesuffix('s')}-{course_id}-{item_id}",
|
||||
document_link=doc_link,
|
||||
),
|
||||
failure_message=f"Failed to process {stage.removesuffix('s')}: {e}",
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
|
||||
return results, early_exit
|
||||
|
||||
def _maybe_attach_permissions(
|
||||
self,
|
||||
document: Document,
|
||||
course_id: int,
|
||||
include_permissions: bool,
|
||||
) -> Document:
|
||||
if include_permissions:
|
||||
document.external_access = self._get_course_permissions(course_id)
|
||||
return document
|
||||
|
||||
def _load_from_checkpoint(
    self,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
    checkpoint: CanvasConnectorCheckpoint,
    include_permissions: bool = False,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
    """Shared implementation for load_from_checkpoint and load_from_checkpoint_with_perm_sync.

    Each invocation handles at most one page of one stage of one course:
    it yields the processed items/failures for that page and returns an
    advanced deep copy of the checkpoint so the framework can resume.

    Args:
        start: Lower bound of the poll window (seconds since epoch).
        end: Upper bound of the poll window (seconds since epoch).
        checkpoint: Prior progress state; never mutated — a deep copy is
            advanced and returned.
        include_permissions: When True, permissions are attached to each
            produced document (handled inside _process_items).
    """
    new_checkpoint = checkpoint.model_copy(deep=True)

    # First call: materialize the list of course IDs.
    # On failure, let the exception propagate so the framework fails the
    # attempt cleanly. Swallowing errors here would leave the checkpoint
    # state unchanged and cause an infinite retry loop.
    if not new_checkpoint.course_ids:
        try:
            courses = self._list_courses()
        except OnyxError as e:
            if e.status_code in (401, 403):
                _handle_canvas_api_error(e)  # NoReturn — always raises
            raise
        new_checkpoint.course_ids = [c.id for c in courses]
        logger.info(f"Found {len(courses)} Canvas courses to process")
        new_checkpoint.has_more = len(new_checkpoint.course_ids) > 0
        return new_checkpoint

    # All courses done.
    if new_checkpoint.current_course_index >= len(new_checkpoint.course_ids):
        new_checkpoint.has_more = False
        return new_checkpoint

    course_id = new_checkpoint.course_ids[new_checkpoint.current_course_index]
    try:
        stage = CanvasStage(new_checkpoint.stage)
    except ValueError as e:
        # A stage value outside the enum means the persisted checkpoint
        # is corrupt — fail loudly rather than guessing.
        raise ValueError(
            f"Invalid checkpoint stage: {new_checkpoint.stage!r}. "
            f"Valid stages: {[s.value for s in CanvasStage]}"
        ) from e

    # Build endpoint + params from the static template.
    config = _STAGE_CONFIG[stage]
    endpoint = config["endpoint"].format(course_id=course_id)
    params = {k: v.format(course_id=course_id) for k, v in config["params"].items()}
    # Only the announcements API supports server-side date filtering
    # (start_date/end_date). Pages support server-side sorting
    # (sort=updated_at desc) enabling early exit, but not date
    # filtering. Assignments support neither. Both are filtered
    # client-side via _in_time_window after fetching.
    if stage == CanvasStage.ANNOUNCEMENTS:
        params["start_date"] = _unix_to_canvas_time(start)
        params["end_date"] = _unix_to_canvas_time(end)

    try:
        response, result_next_url = self._fetch_stage_page(
            next_url=new_checkpoint.next_url,
            endpoint=endpoint,
            params=params,
        )
    except OnyxError as oe:
        # Security errors from _parse_next_link (host/scheme
        # mismatch on pagination URLs) have no status code override
        # and must not be silenced.
        is_api_error = oe._status_code_override is not None
        if not is_api_error:
            raise
        if oe.status_code in (401, 403):
            _handle_canvas_api_error(oe)  # NoReturn — always raises

        # 404 means the course itself is gone or inaccessible. The
        # other stages on this course will hit the same 404, so skip
        # the whole course rather than burning API calls on each stage.
        if oe.status_code == 404:
            logger.warning(
                f"Canvas course {course_id} not found while fetching "
                f"{stage} (HTTP 404). Skipping course."
            )
            yield ConnectorFailure(
                failed_entity=EntityFailure(
                    entity_id=f"canvas-course-{course_id}",
                ),
                failure_message=(f"Canvas course {course_id} not found: {oe}"),
                exception=oe,
            )
            new_checkpoint.advance_course()
        else:
            # Other HTTP errors: record the failure and skip only the
            # current stage, leaving the rest of the course intact.
            logger.warning(
                f"Failed to fetch {stage} for course {course_id}: {oe}. "
                f"Skipping remainder of this stage."
            )
            yield ConnectorFailure(
                failed_entity=EntityFailure(
                    entity_id=f"canvas-{stage}-{course_id}",
                ),
                failure_message=(
                    f"Failed to fetch {stage} for course {course_id}: {oe}"
                ),
                exception=oe,
            )
            new_checkpoint.advance_stage()
        new_checkpoint.has_more = new_checkpoint.current_course_index < len(
            new_checkpoint.course_ids
        )
        return new_checkpoint
    except Exception as e:
        # Unknown error — skip the stage and try to continue.
        logger.warning(
            f"Failed to fetch {stage} for course {course_id}: {e}. "
            f"Skipping remainder of this stage."
        )
        yield ConnectorFailure(
            failed_entity=EntityFailure(
                entity_id=f"canvas-{stage}-{course_id}",
            ),
            failure_message=(
                f"Failed to fetch {stage} for course {course_id}: {e}"
            ),
            exception=e,
        )
        new_checkpoint.advance_stage()
        new_checkpoint.has_more = new_checkpoint.current_course_index < len(
            new_checkpoint.course_ids
        )
        return new_checkpoint

    # Process fetched items
    results, early_exit = self._process_items(
        response, stage, course_id, start, end, include_permissions
    )
    for result in results:
        yield result

    # If we hit an item older than our window (pages sorted desc),
    # skip remaining pagination and advance to the next stage.
    if early_exit:
        result_next_url = None

    # If there are more pages, save the cursor and return
    if result_next_url:
        new_checkpoint.next_url = result_next_url
    else:
        # Stage complete — advance to next stage (or next course if last).
        new_checkpoint.advance_stage()

    new_checkpoint.has_more = new_checkpoint.current_course_index < len(
        new_checkpoint.course_ids
    )
    return new_checkpoint
|
||||
|
||||
@override
def load_from_checkpoint(
    self,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
    checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
    """Load documents from the checkpoint without permission metadata."""
    return self._load_from_checkpoint(
        start=start,
        end=end,
        checkpoint=checkpoint,
        include_permissions=False,
    )
|
||||
|
||||
@override
def load_from_checkpoint_with_perm_sync(
    self,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
    checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
    """Load documents from checkpoint with permission information included."""
    return self._load_from_checkpoint(
        start=start,
        end=end,
        checkpoint=checkpoint,
        include_permissions=True,
    )
|
||||
|
||||
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
    """Create the initial checkpoint used to kick off an indexing run."""
    initial_checkpoint = CanvasConnectorCheckpoint(has_more=True)
    return initial_checkpoint
|
||||
|
||||
@override
def validate_checkpoint_json(
    self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
    """Deserialize and validate a checkpoint from its JSON representation."""
    parsed = CanvasConnectorCheckpoint.model_validate_json(checkpoint_json)
    return parsed
|
||||
|
||||
@override
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Canvas connector settings by testing API access."""
|
||||
@@ -415,38 +771,6 @@ class CanvasConnector(
|
||||
f"Unexpected error during Canvas settings validation: {exc}"
|
||||
)
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> CanvasConnectorCheckpoint:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
|
||||
@@ -171,7 +171,10 @@ class ClickupConnector(LoadConnector, PollConnector):
|
||||
document.metadata[extra_field] = task[extra_field]
|
||||
|
||||
if self.retrieve_task_comments:
|
||||
document.sections.extend(self._get_task_comments(task["id"]))
|
||||
document.sections = [
|
||||
*document.sections,
|
||||
*self._get_task_comments(task["id"]),
|
||||
]
|
||||
|
||||
doc_batch.append(document)
|
||||
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
import csv
|
||||
import io
|
||||
from typing import IO
|
||||
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.file_processing.extract_file_text import file_io_to_text
|
||||
from onyx.file_processing.extract_file_text import xlsx_sheet_extraction
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_tabular_file(file_name: str) -> bool:
    """Return True when the filename ends in a recognized tabular extension."""
    name = file_name.lower()
    return name.endswith(tuple(OnyxFileExtensions.TABULAR_EXTENSIONS))
|
||||
|
||||
|
||||
def _tsv_to_csv(tsv_text: str) -> str:
|
||||
"""Re-serialize tab-separated text as CSV so downstream parsers that
|
||||
assume the default Excel dialect read the columns correctly."""
|
||||
out = io.StringIO()
|
||||
csv.writer(out, lineterminator="\n").writerows(
|
||||
csv.reader(io.StringIO(tsv_text), dialect="excel-tab")
|
||||
)
|
||||
return out.getvalue().rstrip("\n")
|
||||
|
||||
|
||||
def tabular_file_to_sections(
    file: IO[bytes],
    file_name: str,
    link: str = "",
) -> list[TabularSection]:
    """Convert a tabular file into one or more TabularSections.

    - .xlsx → one TabularSection per non-empty sheet.
    - .csv / .tsv → a single TabularSection containing the full decoded
      file.

    Returns an empty list when the file yields no extractable content.
    Raises ValueError for filenames with an unsupported extension.
    """
    name_lower = file_name.lower()
    section_link = link or file_name

    if name_lower.endswith(".xlsx"):
        sheet_sections = []
        for csv_text, sheet_title in xlsx_sheet_extraction(
            file, file_name=file_name
        ):
            sheet_sections.append(
                TabularSection(
                    link=section_link,
                    text=csv_text,
                    heading=f"{file_name} :: {sheet_title}",
                )
            )
        return sheet_sections

    if not name_lower.endswith((".csv", ".tsv")):
        raise ValueError(f"{file_name!r} is not a tabular file")

    try:
        decoded = file_io_to_text(file).strip()
    except Exception:
        logger.exception(f"Failure decoding {file_name}")
        raise

    if not decoded:
        return []
    if name_lower.endswith(".tsv"):
        # Normalize tab-delimited content to CSV for downstream parsers.
        decoded = _tsv_to_csv(decoded)
    return [TabularSection(link=section_link, text=decoded)]
|
||||
@@ -15,6 +15,10 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rate_limit_builder
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rl_requests
|
||||
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
|
||||
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
|
||||
tabular_file_to_sections,
|
||||
)
|
||||
from onyx.connectors.drupal_wiki.models import DrupalWikiCheckpoint
|
||||
from onyx.connectors.drupal_wiki.models import DrupalWikiPage
|
||||
from onyx.connectors.drupal_wiki.models import DrupalWikiPageResponse
|
||||
@@ -33,6 +37,7 @@ from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -213,7 +218,7 @@ class DrupalWikiConnector(
|
||||
attachment: dict[str, Any],
|
||||
page_id: int,
|
||||
download_url: str,
|
||||
) -> tuple[list[TextSection | ImageSection], str | None]:
|
||||
) -> tuple[list[TextSection | ImageSection | TabularSection], str | None]:
|
||||
"""
|
||||
Process a single attachment and return generated sections.
|
||||
|
||||
@@ -226,7 +231,7 @@ class DrupalWikiConnector(
|
||||
Tuple of (sections, error_message). If error_message is not None, the
|
||||
sections list should be treated as invalid.
|
||||
"""
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
sections: list[TextSection | ImageSection | TabularSection] = []
|
||||
|
||||
try:
|
||||
if not self._validate_attachment_filetype(attachment):
|
||||
@@ -273,6 +278,25 @@ class DrupalWikiConnector(
|
||||
|
||||
return sections, None
|
||||
|
||||
# Tabular attachments (xlsx, csv, tsv) — produce
|
||||
# TabularSections instead of a flat TextSection.
|
||||
if is_tabular_file(file_name):
|
||||
try:
|
||||
sections.extend(
|
||||
tabular_file_to_sections(
|
||||
BytesIO(raw_bytes),
|
||||
file_name=file_name,
|
||||
link=download_url,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to extract tabular sections from {file_name}"
|
||||
)
|
||||
if not sections:
|
||||
return [], f"No content extracted from tabular file {file_name}"
|
||||
return sections, None
|
||||
|
||||
image_counter = 0
|
||||
|
||||
def _store_embedded_image(image_data: bytes, image_name: str) -> None:
|
||||
@@ -497,7 +521,7 @@ class DrupalWikiConnector(
|
||||
page_url = build_drupal_wiki_document_id(self.base_url, page.id)
|
||||
|
||||
# Create sections with just the page content
|
||||
sections: list[TextSection | ImageSection] = [
|
||||
sections: list[TextSection | ImageSection | TabularSection] = [
|
||||
TextSection(text=text_content, link=page_url)
|
||||
]
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import IO
|
||||
@@ -12,11 +13,16 @@ from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
process_onyx_metadata,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
|
||||
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
|
||||
tabular_file_to_sections,
|
||||
)
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -179,8 +185,32 @@ def _process_file(
|
||||
link = onyx_metadata.link or link
|
||||
|
||||
# Build sections: first the text as a single Section
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
if extraction_result.text_content.strip():
|
||||
sections: list[TextSection | ImageSection | TabularSection] = []
|
||||
if is_tabular_file(file_name):
|
||||
# Produce TabularSections
|
||||
lowered_name = file_name.lower()
|
||||
if lowered_name.endswith(".xlsx"):
|
||||
file.seek(0)
|
||||
tabular_source: IO[bytes] = file
|
||||
else:
|
||||
tabular_source = BytesIO(
|
||||
extraction_result.text_content.encode("utf-8", errors="replace")
|
||||
)
|
||||
try:
|
||||
sections.extend(
|
||||
tabular_file_to_sections(
|
||||
file=tabular_source,
|
||||
file_name=file_name,
|
||||
link=link or "",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process tabular file {file_name}: {e}")
|
||||
return []
|
||||
if not sections:
|
||||
logger.warning(f"No content extracted from tabular file {file_name}")
|
||||
return []
|
||||
elif extraction_result.text_content.strip():
|
||||
logger.debug(f"Creating TextSection for {file_name} with link: {link}")
|
||||
sections.append(
|
||||
TextSection(link=link, text=extraction_result.text_content.strip())
|
||||
|
||||
@@ -22,6 +22,7 @@ from typing_extensions import override
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.connector_runner import CheckpointOutputWrapper
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
@@ -35,10 +36,16 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import IndexingHeartbeatInterface
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -427,7 +434,11 @@ def make_cursor_url_callback(
|
||||
return cursor_url_callback
|
||||
|
||||
|
||||
class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint]):
|
||||
class GithubConnector(
|
||||
CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint],
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
repo_owner: str,
|
||||
@@ -559,6 +570,7 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
include_permissions: bool = False,
|
||||
is_slim: bool = False,
|
||||
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub")
|
||||
@@ -614,36 +626,46 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
|
||||
for pr in pr_batch:
|
||||
num_prs += 1
|
||||
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
done_with_prs = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
continue
|
||||
try:
|
||||
yield _convert_pr_to_document(
|
||||
cast(PullRequest, pr), repo_external_access
|
||||
if is_slim:
|
||||
yield Document(
|
||||
id=pr.html_url,
|
||||
sections=[],
|
||||
external_access=repo_external_access,
|
||||
source=DocumentSource.GITHUB,
|
||||
semantic_identifier="",
|
||||
metadata={},
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting PR to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(pr.id), document_link=pr.html_url
|
||||
),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
done_with_prs = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
continue
|
||||
try:
|
||||
yield _convert_pr_to_document(
|
||||
cast(PullRequest, pr), repo_external_access
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting PR to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(pr.id), document_link=pr.html_url
|
||||
),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
# If we reach this point with a cursor url in the checkpoint, we were using
|
||||
# the fallback cursor-based pagination strategy. That strategy tries to get all
|
||||
@@ -689,38 +711,47 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
|
||||
for issue in issue_batch:
|
||||
num_issues += 1
|
||||
issue = cast(Issue, issue)
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
done_with_issues = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
continue
|
||||
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
|
||||
try:
|
||||
yield _convert_issue_to_document(issue, repo_external_access)
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting issue to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(issue.id),
|
||||
document_link=issue.html_url,
|
||||
),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
if is_slim:
|
||||
yield Document(
|
||||
id=issue.html_url,
|
||||
sections=[],
|
||||
external_access=repo_external_access,
|
||||
source=DocumentSource.GITHUB,
|
||||
semantic_identifier="",
|
||||
metadata={},
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# we iterate backwards in time, so at this point we stop processing issues
|
||||
if (
|
||||
start is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
done_with_issues = True
|
||||
break
|
||||
# Skip issues updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
continue
|
||||
try:
|
||||
yield _convert_issue_to_document(issue, repo_external_access)
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting issue to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(issue.id),
|
||||
document_link=issue.html_url,
|
||||
),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Fetched {num_issues} issues for repo: {repo.name}")
|
||||
# if we found any issues on the page, and we're not done, return the checkpoint.
|
||||
@@ -803,6 +834,60 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
|
||||
def _retrieve_slim_docs(
    self,
    include_permissions: bool,
    callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
    """Yield batches of SlimDocuments covering every PR and issue in all
    configured repos.

    Repeatedly drives _fetch_from_github in slim mode: each pass handles
    one page and hands back an updated checkpoint, which
    CheckpointOutputWrapper extracts while draining the generator. Rate
    limiting and pagination stay centralized in _fetch_from_github via
    _get_batch_rate_limited.
    """
    checkpoint = self.build_dummy_checkpoint()
    while checkpoint.has_more:
        page_docs: list[SlimDocument | HierarchyNode] = []
        output_wrapper: CheckpointOutputWrapper[GithubConnectorCheckpoint] = (
            CheckpointOutputWrapper()
        )
        fetch_gen = self._fetch_from_github(
            checkpoint, include_permissions=include_permissions, is_slim=True
        )
        for document, _, _, updated_checkpoint in output_wrapper(fetch_gen):
            if updated_checkpoint is not None:
                checkpoint = updated_checkpoint
            if document is None:
                continue
            page_docs.append(
                SlimDocument(
                    id=document.id, external_access=document.external_access
                )
            )
        if page_docs:
            yield page_docs
        if callback and callback.should_stop():
            raise RuntimeError("github_slim_docs: Stop signal detected")
|
||||
|
||||
@override
def retrieve_all_slim_docs(
    self,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
    """Slim-doc retrieval without permission data.

    Note: start/end are accepted for interface compatibility but not
    forwarded to the underlying fetch.
    """
    slim_doc_gen = self._retrieve_slim_docs(
        include_permissions=False,
        callback=callback,
    )
    return slim_doc_gen
|
||||
|
||||
@override
def retrieve_all_slim_docs_perm_sync(
    self,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
    """Slim-doc retrieval with permission data attached.

    Note: start/end are accepted for interface compatibility but not
    forwarded to the underlying fetch.
    """
    slim_doc_gen = self._retrieve_slim_docs(
        include_permissions=True,
        callback=callback,
    )
    return slim_doc_gen
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")
|
||||
|
||||
@@ -75,6 +75,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import NormalizationResult
|
||||
from onyx.connectors.interfaces import Resolver
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -207,6 +208,7 @@ class DriveIdStatus(Enum):
|
||||
|
||||
|
||||
class GoogleDriveConnector(
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
|
||||
Resolver,
|
||||
@@ -1754,6 +1756,7 @@ class GoogleDriveConnector(
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
files_batch: list[RetrievedDriveFile] = []
|
||||
slim_batch: list[SlimDocument | HierarchyNode] = []
|
||||
@@ -1763,9 +1766,13 @@ class GoogleDriveConnector(
|
||||
nonlocal files_batch, slim_batch
|
||||
|
||||
# Get new ancestor hierarchy nodes first
|
||||
permission_sync_context = PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
permission_sync_context = (
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
)
|
||||
if include_permissions
|
||||
else None
|
||||
)
|
||||
new_ancestors = self._get_new_ancestors_for_files(
|
||||
files=files_batch,
|
||||
@@ -1779,10 +1786,7 @@ class GoogleDriveConnector(
|
||||
if doc := build_slim_document(
|
||||
self.creds,
|
||||
file.drive_file,
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
),
|
||||
permission_sync_context,
|
||||
retriever_email=file.user_email,
|
||||
):
|
||||
slim_batch.append(doc)
|
||||
@@ -1822,11 +1826,12 @@ class GoogleDriveConnector(
|
||||
if files_batch:
|
||||
yield _yield_slim_batch()
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
def _retrieve_all_slim_docs_impl(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
try:
|
||||
checkpoint = self.build_dummy_checkpoint()
|
||||
@@ -1836,13 +1841,34 @@ class GoogleDriveConnector(
|
||||
start=start,
|
||||
end=end,
|
||||
callback=callback,
|
||||
include_permissions=include_permissions,
|
||||
)
|
||||
logger.info("Drive perm sync: Slim doc retrieval complete")
|
||||
|
||||
logger.info("Drive slim doc retrieval complete")
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
raise
|
||||
|
||||
@override
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
return self._retrieve_all_slim_docs_impl(
|
||||
start=start, end=end, callback=callback, include_permissions=False
|
||||
)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
return self._retrieve_all_slim_docs_impl(
|
||||
start=start, end=end, callback=callback, include_permissions=True
|
||||
)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self._creds is None:
|
||||
|
||||
@@ -13,6 +13,10 @@ from pydantic import BaseModel
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
|
||||
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
|
||||
tabular_file_to_sections,
|
||||
)
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
|
||||
from onyx.connectors.google_drive.models import GDriveMimeType
|
||||
@@ -28,15 +32,16 @@ from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import pptx_to_text
|
||||
from onyx.file_processing.extract_file_text import read_docx_file
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.extract_file_text import xlsx_to_text
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.file_types import OnyxMimeTypes
|
||||
from onyx.file_processing.file_types import SPREADSHEET_MIME_TYPE
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import (
|
||||
@@ -289,7 +294,7 @@ def _download_and_extract_sections_basic(
|
||||
service: GoogleDriveService,
|
||||
allow_images: bool,
|
||||
size_threshold: int,
|
||||
) -> list[TextSection | ImageSection]:
|
||||
) -> list[TextSection | ImageSection | TabularSection]:
|
||||
"""Extract text and images from a Google Drive file."""
|
||||
file_id = file["id"]
|
||||
file_name = file["name"]
|
||||
@@ -308,7 +313,7 @@ def _download_and_extract_sections_basic(
|
||||
return []
|
||||
|
||||
# Store images for later processing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
sections: list[TextSection | ImageSection | TabularSection] = []
|
||||
try:
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
image_data=response_call(),
|
||||
@@ -323,10 +328,9 @@ def _download_and_extract_sections_basic(
|
||||
logger.error(f"Failed to process image {file_name}: {e}")
|
||||
return sections
|
||||
|
||||
# For Google Docs, Sheets, and Slides, export as plain text
|
||||
# For Google Docs, Sheets, and Slides, export via the Drive API
|
||||
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
|
||||
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
|
||||
# Use the correct API call for exporting files
|
||||
request = service.files().export_media(
|
||||
fileId=file_id, mimeType=export_mime_type
|
||||
)
|
||||
@@ -335,6 +339,17 @@ def _download_and_extract_sections_basic(
|
||||
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
|
||||
return []
|
||||
|
||||
if export_mime_type in OnyxMimeTypes.TABULAR_MIME_TYPES:
|
||||
# Synthesize an extension on the filename
|
||||
ext = ".xlsx" if export_mime_type == SPREADSHEET_MIME_TYPE else ".csv"
|
||||
return list(
|
||||
tabular_file_to_sections(
|
||||
io.BytesIO(response),
|
||||
file_name=f"{file_name}{ext}",
|
||||
link=link,
|
||||
)
|
||||
)
|
||||
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
@@ -356,9 +371,15 @@ def _download_and_extract_sections_basic(
|
||||
|
||||
elif (
|
||||
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
or is_tabular_file(file_name)
|
||||
):
|
||||
text = xlsx_to_text(io.BytesIO(response_call()), file_name=file_name)
|
||||
return [TextSection(link=link, text=text)] if text else []
|
||||
return list(
|
||||
tabular_file_to_sections(
|
||||
io.BytesIO(response_call()),
|
||||
file_name=file_name,
|
||||
link=link,
|
||||
)
|
||||
)
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
@@ -369,7 +390,7 @@ def _download_and_extract_sections_basic(
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
|
||||
pdf_sections: list[TextSection | ImageSection] = [
|
||||
pdf_sections: list[TextSection | ImageSection | TabularSection] = [
|
||||
TextSection(link=link, text=text)
|
||||
]
|
||||
|
||||
@@ -410,8 +431,9 @@ def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int:
|
||||
|
||||
|
||||
def align_basic_advanced(
|
||||
basic_sections: list[TextSection | ImageSection], adv_sections: list[TextSection]
|
||||
) -> list[TextSection | ImageSection]:
|
||||
basic_sections: list[TextSection | ImageSection | TabularSection],
|
||||
adv_sections: list[TextSection],
|
||||
) -> list[TextSection | ImageSection | TabularSection]:
|
||||
"""Align the basic sections with the advanced sections.
|
||||
In particular, the basic sections contain all content of the file,
|
||||
including smart chips like dates and doc links. The advanced sections
|
||||
@@ -428,7 +450,7 @@ def align_basic_advanced(
|
||||
basic_full_text = "".join(
|
||||
[section.text for section in basic_sections if isinstance(section, TextSection)]
|
||||
)
|
||||
new_sections: list[TextSection | ImageSection] = []
|
||||
new_sections: list[TextSection | ImageSection | TabularSection] = []
|
||||
heading_start = 0
|
||||
for adv_ind in range(1, len(adv_sections)):
|
||||
heading = adv_sections[adv_ind].text.split(HEADING_DELIMITER)[0]
|
||||
@@ -599,7 +621,7 @@ def _convert_drive_item_to_document(
|
||||
"""
|
||||
Main entry point for converting a Google Drive file => Document object.
|
||||
"""
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
sections: list[TextSection | ImageSection | TabularSection] = []
|
||||
|
||||
# Only construct these services when needed
|
||||
def _get_drive_service() -> GoogleDriveService:
|
||||
@@ -639,7 +661,9 @@ def _convert_drive_item_to_document(
|
||||
doc_id=file.get("id", ""),
|
||||
)
|
||||
if doc_sections:
|
||||
sections = cast(list[TextSection | ImageSection], doc_sections)
|
||||
sections = cast(
|
||||
list[TextSection | ImageSection | TabularSection], doc_sections
|
||||
)
|
||||
if any(SMART_CHIP_CHAR in section.text for section in doc_sections):
|
||||
logger.debug(
|
||||
f"found smart chips in {file.get('name')}, aligning with basic sections"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
@@ -53,6 +54,21 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _load_google_json(raw: object) -> dict[str, Any]:
|
||||
"""Accept both the current (dict) and legacy (JSON string) KV payload shapes.
|
||||
|
||||
Payloads written before the fix for serializing Google credentials into
|
||||
``EncryptedJson`` columns are stored as JSON strings; new writes store dicts.
|
||||
Once every install has re-uploaded their Google credentials the legacy
|
||||
``str`` branch can be removed.
|
||||
"""
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return json.loads(raw)
|
||||
raise ValueError(f"Unexpected Google credential payload type: {type(raw)!r}")
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
@@ -162,12 +178,13 @@ def build_service_account_creds(
|
||||
|
||||
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
credential_json = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
@@ -188,12 +205,12 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
|
||||
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
return GoogleAppCredentials(**creds)
|
||||
|
||||
|
||||
def upsert_google_app_cred(
|
||||
@@ -201,10 +218,14 @@ def upsert_google_app_cred(
|
||||
) -> None:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
KV_GOOGLE_DRIVE_CRED_KEY,
|
||||
app_credentials.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_CRED_KEY, app_credentials.model_dump(mode="json"), encrypt=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -220,12 +241,14 @@ def delete_google_app_cred(source: DocumentSource) -> None:
|
||||
|
||||
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
creds = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
return GoogleServiceAccountKey(**creds)
|
||||
|
||||
|
||||
def upsert_service_account_key(
|
||||
@@ -234,12 +257,14 @@ def upsert_service_account_key(
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.json(),
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -123,6 +123,9 @@ class SlimConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
@@ -39,6 +40,7 @@ class SectionType(str, Enum):
|
||||
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
TABULAR = "tabular"
|
||||
|
||||
|
||||
class Section(BaseModel):
|
||||
@@ -48,6 +50,7 @@ class Section(BaseModel):
|
||||
link: str | None = None
|
||||
text: str | None = None
|
||||
image_file_id: str | None = None
|
||||
heading: str | None = None
|
||||
|
||||
|
||||
class TextSection(Section):
|
||||
@@ -70,6 +73,18 @@ class ImageSection(Section):
|
||||
return sys.getsizeof(self.image_file_id) + sys.getsizeof(self.link)
|
||||
|
||||
|
||||
class TabularSection(Section):
|
||||
"""Section containing tabular data (csv/tsv content, or one sheet of
|
||||
an xlsx workbook rendered as CSV)."""
|
||||
|
||||
type: Literal[SectionType.TABULAR] = SectionType.TABULAR
|
||||
text: str # CSV representation in a string
|
||||
link: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
return sys.getsizeof(self.text) + sys.getsizeof(self.link)
|
||||
|
||||
|
||||
class BasicExpertInfo(BaseModel):
|
||||
"""Basic Information for the owner of a document, any of the fields can be left as None
|
||||
Display fallback goes as follows:
|
||||
@@ -171,7 +186,7 @@ class DocumentBase(BaseModel):
|
||||
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
|
||||
|
||||
id: str | None = None
|
||||
sections: list[TextSection | ImageSection]
|
||||
sections: Sequence[TextSection | ImageSection | TabularSection]
|
||||
source: DocumentSource | None = None
|
||||
semantic_identifier: str # displayed in the UI as the main identifier for the doc
|
||||
# TODO(andrei): Ideally we could improve this to where each value is just a
|
||||
@@ -381,12 +396,9 @@ class IndexingDocument(Document):
|
||||
)
|
||||
else:
|
||||
section_len = sum(
|
||||
(
|
||||
len(section.text)
|
||||
if isinstance(section, TextSection) and section.text is not None
|
||||
else 0
|
||||
)
|
||||
len(section.text) if section.text is not None else 0
|
||||
for section in self.sections
|
||||
if isinstance(section, (TextSection, TabularSection))
|
||||
)
|
||||
|
||||
return title_len + section_len
|
||||
|
||||
@@ -41,6 +41,10 @@ from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
|
||||
from onyx.configs.app_configs import SHAREPOINT_CONNECTOR_SIZE_THRESHOLD
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
|
||||
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
|
||||
tabular_file_to_sections,
|
||||
)
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
@@ -60,6 +64,7 @@ from onyx.connectors.models import ExternalAccess
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.sharepoint.connector_utils import get_sharepoint_external_access
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
@@ -586,7 +591,7 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
driveitem, f"Failed to download via graph api: {e}", e
|
||||
)
|
||||
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
sections: list[TextSection | ImageSection | TabularSection] = []
|
||||
file_ext = get_file_ext(driveitem.name)
|
||||
|
||||
if not content_bytes:
|
||||
@@ -602,6 +607,19 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
)
|
||||
image_section.link = driveitem.web_url
|
||||
sections.append(image_section)
|
||||
elif is_tabular_file(driveitem.name):
|
||||
try:
|
||||
sections.extend(
|
||||
tabular_file_to_sections(
|
||||
file=io.BytesIO(content_bytes),
|
||||
file_name=driveitem.name,
|
||||
link=driveitem.web_url or "",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to extract tabular sections for '{driveitem.name}': {e}"
|
||||
)
|
||||
else:
|
||||
|
||||
def _store_embedded_image(img_data: bytes, img_name: str) -> None:
|
||||
|
||||
@@ -750,31 +750,3 @@ def resync_cc_pair(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_connector_health_for_metrics(
|
||||
db_session: Session,
|
||||
) -> list: # Returns list of Row tuples
|
||||
"""Return connector health data for Prometheus metrics.
|
||||
|
||||
Each row is (cc_pair_id, status, in_repeated_error_state,
|
||||
last_successful_index_time, name, source).
|
||||
"""
|
||||
return (
|
||||
db_session.query(
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.status,
|
||||
ConnectorCredentialPair.in_repeated_error_state,
|
||||
ConnectorCredentialPair.last_successful_index_time,
|
||||
ConnectorCredentialPair.name,
|
||||
Connector.source,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -335,6 +335,7 @@ def update_document_set(
|
||||
"Cannot update document set while it is syncing. Please wait for it to finish syncing, and then try again."
|
||||
)
|
||||
|
||||
document_set_row.name = document_set_update_request.name
|
||||
document_set_row.description = document_set_update_request.description
|
||||
if not DISABLE_VECTOR_DB:
|
||||
document_set_row.is_up_to_date = False
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
@@ -346,6 +347,25 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _safe_close_session(session: Session) -> None:
|
||||
"""Close a session, catching connection-closed errors during cleanup.
|
||||
|
||||
Long-running operations (e.g. multi-model LLM loops) can hold a session
|
||||
open for minutes. If the underlying connection is dropped by cloud
|
||||
infrastructure (load-balancer timeouts, PgBouncer, idle-in-transaction
|
||||
timeouts, etc.), the implicit rollback in Session.close() raises
|
||||
OperationalError or InterfaceError. Since the work is already complete,
|
||||
we log and move on — SQLAlchemy internally invalidates the connection
|
||||
for pool recycling.
|
||||
"""
|
||||
try:
|
||||
session.close()
|
||||
except DBAPIError:
|
||||
logger.warning(
|
||||
"DB connection lost during session cleanup — the connection will be invalidated and recycled by the pool."
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
"""
|
||||
@@ -358,8 +378,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
|
||||
# no need to use the schema translation map for self-hosted + default schema
|
||||
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
|
||||
with Session(bind=engine, expire_on_commit=False) as session:
|
||||
session = Session(bind=engine, expire_on_commit=False)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
return
|
||||
|
||||
# Create connection with schema translation to handle querying the right schema
|
||||
@@ -367,8 +390,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
with engine.connect().execution_options(
|
||||
schema_translate_map=schema_translate_map
|
||||
) as connection:
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
session = Session(bind=connection, expire_on_commit=False)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
@@ -2,8 +2,6 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import NamedTuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -30,17 +28,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# from sqlalchemy.sql.selectable import Select
|
||||
|
||||
# Comment out unused imports that cause mypy errors
|
||||
# from onyx.auth.models import UserRole
|
||||
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
|
||||
# from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier
|
||||
# from onyx.db.engine import async_query_for_dms
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -981,104 +968,48 @@ def get_index_attempt_errors_for_cc_pair(
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class ActiveIndexAttemptMetric(NamedTuple):
|
||||
"""Row returned by get_active_index_attempts_for_metrics."""
|
||||
|
||||
status: IndexingStatus
|
||||
source: "DocumentSource"
|
||||
cc_pair_id: int
|
||||
cc_pair_name: str | None
|
||||
attempt_count: int
|
||||
|
||||
|
||||
def get_active_index_attempts_for_metrics(
|
||||
def get_index_attempt_errors_across_connectors(
|
||||
db_session: Session,
|
||||
) -> list[ActiveIndexAttemptMetric]:
|
||||
"""Return non-terminal index attempts grouped by status, source, and connector.
|
||||
cc_pair_id: int | None = None,
|
||||
error_type: str | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
unresolved_only: bool = True,
|
||||
page: int = 0,
|
||||
page_size: int = 25,
|
||||
) -> tuple[list[IndexAttemptError], int]:
|
||||
"""Query index attempt errors across all connectors with optional filters.
|
||||
|
||||
Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count).
|
||||
Returns (errors, total_count) for pagination.
|
||||
"""
|
||||
from onyx.db.models import Connector
|
||||
stmt = select(IndexAttemptError)
|
||||
count_stmt = select(func.count()).select_from(IndexAttemptError)
|
||||
|
||||
terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
func.count(),
|
||||
if cc_pair_id is not None:
|
||||
stmt = stmt.where(IndexAttemptError.connector_credential_pair_id == cc_pair_id)
|
||||
count_stmt = count_stmt.where(
|
||||
IndexAttemptError.connector_credential_pair_id == cc_pair_id
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.filter(IndexAttempt.status.notin_(terminal_statuses))
|
||||
.group_by(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return [ActiveIndexAttemptMetric(*row) for row in rows]
|
||||
|
||||
if error_type is not None:
|
||||
stmt = stmt.where(IndexAttemptError.error_type == error_type)
|
||||
count_stmt = count_stmt.where(IndexAttemptError.error_type == error_type)
|
||||
|
||||
def get_failed_attempt_counts_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: failed_attempt_count} for all connectors.
|
||||
if unresolved_only:
|
||||
stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
|
||||
count_stmt = count_stmt.where(IndexAttemptError.is_resolved.is_(False))
|
||||
|
||||
When ``since`` is provided, only attempts created after that timestamp
|
||||
are counted. Defaults to the last 90 days to avoid unbounded historical
|
||||
aggregation.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
if start_time is not None:
|
||||
stmt = stmt.where(IndexAttemptError.time_created >= start_time)
|
||||
count_stmt = count_stmt.where(IndexAttemptError.time_created >= start_time)
|
||||
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.count(),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.FAILED)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
.all()
|
||||
)
|
||||
return {cc_id: count for cc_id, count in rows}
|
||||
if end_time is not None:
|
||||
stmt = stmt.where(IndexAttemptError.time_created <= end_time)
|
||||
count_stmt = count_stmt.where(IndexAttemptError.time_created <= end_time)
|
||||
|
||||
stmt = stmt.order_by(desc(IndexAttemptError.time_created))
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
|
||||
def get_docs_indexed_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: total_new_docs_indexed} across successful attempts.
|
||||
|
||||
Only counts attempts with status SUCCESS to avoid inflating counts with
|
||||
partial results from failed attempts. When ``since`` is provided, only
|
||||
attempts created after that timestamp are included.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
|
||||
query = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.SUCCESS)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
)
|
||||
rows = query.all()
|
||||
return {cc_id: int(total or 0) for cc_id, total in rows}
|
||||
total = db_session.scalar(count_stmt) or 0
|
||||
errors = list(db_session.scalars(stmt).all())
|
||||
return errors, total
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -83,47 +84,51 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
def add_memory(
|
||||
user_id: UUID,
|
||||
memory_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory:
|
||||
db_session: Session | None = None,
|
||||
) -> int:
|
||||
"""Insert a new Memory row for the given user.
|
||||
|
||||
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
|
||||
one (lowest id) is deleted before inserting the new one.
|
||||
|
||||
Returns the id of the newly created Memory row.
|
||||
"""
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
|
||||
|
||||
def update_memory_at_index(
|
||||
user_id: UUID,
|
||||
index: int,
|
||||
new_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory | None:
|
||||
db_session: Session | None = None,
|
||||
) -> int | None:
|
||||
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
|
||||
|
||||
Returns the updated Memory row, or None if the index is out of range.
|
||||
Returns the id of the updated Memory row, or None if the index is out of range.
|
||||
"""
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
|
||||
@@ -7,8 +7,6 @@ import time
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
@@ -22,6 +20,7 @@ from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
|
||||
@@ -184,6 +183,14 @@ def generate_final_report(
|
||||
return has_reasoned
|
||||
|
||||
|
||||
def _get_research_agent_tool_id() -> int:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def run_deep_research_llm_loop(
|
||||
emitter: Emitter,
|
||||
@@ -193,7 +200,6 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt: str | None, # noqa: ARG001
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
skip_clarification: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -717,6 +723,7 @@ def run_deep_research_llm_loop(
|
||||
simple_chat_history.append(assistant_with_tools)
|
||||
|
||||
# Now add TOOL_CALL_RESPONSE messages and tool call info for each result
|
||||
research_agent_tool_id = _get_research_agent_tool_id()
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
@@ -737,10 +744,7 @@ def run_deep_research_llm_loop(
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id,
|
||||
tool_id=research_agent_tool_id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
|
||||
@@ -379,13 +379,25 @@ def _worksheet_to_matrix(
|
||||
worksheet: Worksheet,
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Converts a singular worksheet to a matrix of values
|
||||
Converts a singular worksheet to a matrix of values.
|
||||
|
||||
Rows are padded to a uniform width. In openpyxl's read_only mode,
|
||||
iter_rows can yield rows of differing lengths (trailing empty cells
|
||||
are sometimes omitted), and downstream column cleanup assumes a
|
||||
rectangular matrix.
|
||||
"""
|
||||
rows: list[list[str]] = []
|
||||
max_len = 0
|
||||
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
|
||||
row = ["" if cell is None else str(cell) for cell in worksheet_row]
|
||||
if len(row) > max_len:
|
||||
max_len = len(row)
|
||||
rows.append(row)
|
||||
|
||||
for row in rows:
|
||||
if len(row) < max_len:
|
||||
row.extend([""] * (max_len - len(row)))
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
@@ -463,29 +475,13 @@ def _remove_empty_runs(
|
||||
return result
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
# TODO: switch back to this approach in a few months when markitdown
|
||||
# fixes their handling of excel files
|
||||
def xlsx_sheet_extraction(file: IO[Any], file_name: str = "") -> list[tuple[str, str]]:
|
||||
"""
|
||||
Converts each sheet in the excel file to a csv condensed string.
|
||||
Returns a string and the worksheet title for each worksheet
|
||||
|
||||
# md = get_markitdown_converter()
|
||||
# stream_info = StreamInfo(
|
||||
# mimetype=SPREADSHEET_MIME_TYPE, filename=file_name or None, extension=".xlsx"
|
||||
# )
|
||||
# try:
|
||||
# workbook = md.convert(to_bytesio(file), stream_info=stream_info)
|
||||
# except (
|
||||
# BadZipFile,
|
||||
# ValueError,
|
||||
# FileConversionException,
|
||||
# UnsupportedFormatException,
|
||||
# ) as e:
|
||||
# error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
|
||||
# if file_name.startswith("~"):
|
||||
# logger.debug(error_str + " (this is expected for files with ~)")
|
||||
# else:
|
||||
# logger.warning(error_str)
|
||||
# return ""
|
||||
# return workbook.markdown
|
||||
Returns a list of (csv_text, sheet)
|
||||
"""
|
||||
try:
|
||||
workbook = openpyxl.load_workbook(file, read_only=True)
|
||||
except BadZipFile as e:
|
||||
@@ -494,23 +490,30 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
logger.debug(error_str + " (this is expected for files with ~)")
|
||||
else:
|
||||
logger.warning(error_str)
|
||||
return ""
|
||||
return []
|
||||
except Exception as e:
|
||||
if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS):
|
||||
logger.error(
|
||||
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
|
||||
)
|
||||
return ""
|
||||
return []
|
||||
raise
|
||||
|
||||
text_content = []
|
||||
sheets: list[tuple[str, str]] = []
|
||||
for sheet in workbook.worksheets:
|
||||
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf, lineterminator="\n")
|
||||
writer.writerows(sheet_matrix)
|
||||
text_content.append(buf.getvalue().rstrip("\n"))
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
csv_text = buf.getvalue().rstrip("\n")
|
||||
if csv_text.strip():
|
||||
sheets.append((csv_text, sheet.title))
|
||||
return sheets
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
sheets = xlsx_sheet_extraction(file, file_name)
|
||||
return TEXT_SECTION_SEPARATOR.join(csv_text for csv_text, _title in sheets)
|
||||
|
||||
|
||||
def eml_to_text(file: IO[Any]) -> str:
|
||||
|
||||
@@ -7,6 +7,7 @@ from onyx.indexing.chunking.image_section_chunker import ImageChunker
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.tabular_section_chunker import TabularChunker
|
||||
from onyx.indexing.chunking.text_section_chunker import TextChunker
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
@@ -38,6 +39,7 @@ class DocumentChunker:
|
||||
chunk_splitter=chunk_splitter,
|
||||
),
|
||||
SectionType.IMAGE: ImageChunker(),
|
||||
SectionType.TABULAR: TabularChunker(tokenizer=tokenizer),
|
||||
}
|
||||
|
||||
def chunk(
|
||||
@@ -99,7 +101,9 @@ class DocumentChunker:
|
||||
payloads.extend(result.payloads)
|
||||
accumulator = result.accumulator
|
||||
|
||||
# Final flush — any leftover buffered text becomes one last payload.
|
||||
payloads.extend(accumulator.flush_to_list())
|
||||
|
||||
return payloads
|
||||
|
||||
def _select_chunker(self, section: Section) -> SectionChunker:
|
||||
|
||||
272
backend/onyx/indexing/chunking/tabular_section_chunker.py
Normal file
272
backend/onyx/indexing/chunking/tabular_section_chunker.py
Normal file
@@ -0,0 +1,272 @@
|
||||
import csv
|
||||
import io
|
||||
from collections.abc import Iterable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import split_text_by_tokens
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
COLUMNS_MARKER = "Columns:"
|
||||
FIELD_VALUE_SEPARATOR = ", "
|
||||
ROW_JOIN = "\n"
|
||||
NEWLINE_TOKENS = 1
|
||||
|
||||
|
||||
class _ParsedRow(BaseModel):
|
||||
header: list[str]
|
||||
row: list[str]
|
||||
|
||||
|
||||
class _TokenizedText(BaseModel):
|
||||
text: str
|
||||
token_count: int
|
||||
|
||||
|
||||
def format_row(header: list[str], row: list[str]) -> str:
|
||||
"""
|
||||
A header-row combination is formatted like this:
|
||||
field1=value1, field2=value2, field3=value3
|
||||
"""
|
||||
pairs = _row_to_pairs(header, row)
|
||||
formatted = FIELD_VALUE_SEPARATOR.join(f"{h}={v}" for h, v in pairs)
|
||||
return formatted
|
||||
|
||||
|
||||
def format_columns_header(headers: list[str]) -> str:
|
||||
"""
|
||||
Format the column header line. Underscored headers get a
|
||||
space-substituted friendly alias in parens.
|
||||
Example:
|
||||
headers = ["id", "MTTR_hours"]
|
||||
=> "Columns: id, MTTR_hours (MTTR hours)"
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for header in headers:
|
||||
friendly = header
|
||||
if "_" in header:
|
||||
friendly = f'{header} ({header.replace("_", " ")})'
|
||||
parts.append(friendly)
|
||||
return f"{COLUMNS_MARKER} " + FIELD_VALUE_SEPARATOR.join(parts)
|
||||
|
||||
|
||||
def parse_section(section: Section) -> list[_ParsedRow]:
|
||||
"""Parse CSV into headers + rows. First non-empty row is the header;
|
||||
blank rows are skipped."""
|
||||
section_text = section.text or ""
|
||||
if not section_text.strip():
|
||||
return []
|
||||
|
||||
reader = csv.reader(io.StringIO(section_text))
|
||||
non_empty_rows = [row for row in reader if any(cell.strip() for cell in row)]
|
||||
|
||||
if not non_empty_rows:
|
||||
return []
|
||||
|
||||
header, *data_rows = non_empty_rows
|
||||
return [_ParsedRow(header=header, row=row) for row in data_rows]
|
||||
|
||||
|
||||
def _row_to_pairs(headers: list[str], row: list[str]) -> list[tuple[str, str]]:
|
||||
return [(h, v) for h, v in zip(headers, row) if v.strip()]
|
||||
|
||||
|
||||
def pack_chunk(chunk: str, new_row: str) -> str:
|
||||
return chunk + "\n" + new_row
|
||||
|
||||
|
||||
def _split_row_by_pairs(
|
||||
pairs: list[tuple[str, str]],
|
||||
tokenizer: BaseTokenizer,
|
||||
max_tokens: int,
|
||||
) -> list[_TokenizedText]:
|
||||
"""Greedily pack pairs into max-sized pieces. Any single pair that
|
||||
itself exceeds ``max_tokens`` is token-split at id boundaries.
|
||||
No headers."""
|
||||
separator_tokens = count_tokens(FIELD_VALUE_SEPARATOR, tokenizer)
|
||||
pieces: list[_TokenizedText] = []
|
||||
current_parts: list[str] = []
|
||||
current_tokens = 0
|
||||
|
||||
for pair in pairs:
|
||||
pair_str = f"{pair[0]}={pair[1]}"
|
||||
pair_tokens = count_tokens(pair_str, tokenizer)
|
||||
increment = pair_tokens if not current_parts else separator_tokens + pair_tokens
|
||||
|
||||
if current_tokens + increment <= max_tokens:
|
||||
current_parts.append(pair_str)
|
||||
current_tokens += increment
|
||||
continue
|
||||
|
||||
if current_parts:
|
||||
pieces.append(
|
||||
_TokenizedText(
|
||||
text=FIELD_VALUE_SEPARATOR.join(current_parts),
|
||||
token_count=current_tokens,
|
||||
)
|
||||
)
|
||||
current_parts = []
|
||||
current_tokens = 0
|
||||
|
||||
if pair_tokens > max_tokens:
|
||||
for split_text in split_text_by_tokens(pair_str, tokenizer, max_tokens):
|
||||
pieces.append(
|
||||
_TokenizedText(
|
||||
text=split_text,
|
||||
token_count=count_tokens(split_text, tokenizer),
|
||||
)
|
||||
)
|
||||
else:
|
||||
current_parts = [pair_str]
|
||||
current_tokens = pair_tokens
|
||||
|
||||
if current_parts:
|
||||
pieces.append(
|
||||
_TokenizedText(
|
||||
text=FIELD_VALUE_SEPARATOR.join(current_parts),
|
||||
token_count=current_tokens,
|
||||
)
|
||||
)
|
||||
return pieces
|
||||
|
||||
|
||||
def _build_chunk_from_scratch(
|
||||
pairs: list[tuple[str, str]],
|
||||
formatted_row: str,
|
||||
row_tokens: int,
|
||||
column_header: str,
|
||||
column_header_tokens: int,
|
||||
sheet_header: str,
|
||||
sheet_header_tokens: int,
|
||||
tokenizer: BaseTokenizer,
|
||||
max_tokens: int,
|
||||
) -> list[_TokenizedText]:
|
||||
# 1. Row alone is too large — split by pairs, no headers.
|
||||
if row_tokens > max_tokens:
|
||||
return _split_row_by_pairs(pairs, tokenizer, max_tokens)
|
||||
|
||||
chunk = formatted_row
|
||||
chunk_tokens = row_tokens
|
||||
|
||||
# 2. Attempt to add column header
|
||||
candidate_tokens = column_header_tokens + NEWLINE_TOKENS + chunk_tokens
|
||||
if candidate_tokens <= max_tokens:
|
||||
chunk = column_header + ROW_JOIN + chunk
|
||||
chunk_tokens = candidate_tokens
|
||||
|
||||
# 3. Attempt to add sheet header
|
||||
if sheet_header:
|
||||
candidate_tokens = sheet_header_tokens + NEWLINE_TOKENS + chunk_tokens
|
||||
if candidate_tokens <= max_tokens:
|
||||
chunk = sheet_header + ROW_JOIN + chunk
|
||||
chunk_tokens = candidate_tokens
|
||||
|
||||
return [_TokenizedText(text=chunk, token_count=chunk_tokens)]
|
||||
|
||||
|
||||
def parse_to_chunks(
|
||||
rows: Iterable[_ParsedRow],
|
||||
sheet_header: str,
|
||||
tokenizer: BaseTokenizer,
|
||||
max_tokens: int,
|
||||
) -> list[str]:
|
||||
rows_list = list(rows)
|
||||
if not rows_list:
|
||||
return []
|
||||
|
||||
column_header = format_columns_header(rows_list[0].header)
|
||||
column_header_tokens = count_tokens(column_header, tokenizer)
|
||||
sheet_header_tokens = count_tokens(sheet_header, tokenizer) if sheet_header else 0
|
||||
|
||||
chunks: list[str] = []
|
||||
current_chunk = ""
|
||||
current_chunk_tokens = 0
|
||||
|
||||
for row in rows_list:
|
||||
pairs: list[tuple[str, str]] = _row_to_pairs(row.header, row.row)
|
||||
formatted = format_row(row.header, row.row)
|
||||
row_tokens = count_tokens(formatted, tokenizer)
|
||||
|
||||
if current_chunk:
|
||||
# Attempt to pack it in (additive approximation)
|
||||
if current_chunk_tokens + NEWLINE_TOKENS + row_tokens <= max_tokens:
|
||||
current_chunk = pack_chunk(current_chunk, formatted)
|
||||
current_chunk_tokens += NEWLINE_TOKENS + row_tokens
|
||||
continue
|
||||
# Doesn't fit — flush and start new
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = ""
|
||||
current_chunk_tokens = 0
|
||||
|
||||
# Build chunk from scratch
|
||||
for piece in _build_chunk_from_scratch(
|
||||
pairs=pairs,
|
||||
formatted_row=formatted,
|
||||
row_tokens=row_tokens,
|
||||
column_header=column_header,
|
||||
column_header_tokens=column_header_tokens,
|
||||
sheet_header=sheet_header,
|
||||
sheet_header_tokens=sheet_header_tokens,
|
||||
tokenizer=tokenizer,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = piece.text
|
||||
current_chunk_tokens = piece.token_count
|
||||
|
||||
# Flush remaining
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
class TabularChunker(SectionChunker):
|
||||
def __init__(self, tokenizer: BaseTokenizer) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def chunk_section(
|
||||
self,
|
||||
section: Section,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int,
|
||||
) -> SectionChunkerOutput:
|
||||
payloads = accumulator.flush_to_list()
|
||||
|
||||
parsed_rows = parse_section(section)
|
||||
if not parsed_rows:
|
||||
logger.warning(
|
||||
f"TabularChunker: skipping unparseable section (link={section.link})"
|
||||
)
|
||||
return SectionChunkerOutput(
|
||||
payloads=payloads, accumulator=AccumulatorState()
|
||||
)
|
||||
|
||||
sheet_header = section.heading or ""
|
||||
chunk_texts = parse_to_chunks(
|
||||
rows=parsed_rows,
|
||||
sheet_header=sheet_header,
|
||||
tokenizer=self.tokenizer,
|
||||
max_tokens=content_token_limit,
|
||||
)
|
||||
|
||||
for i, text in enumerate(chunk_texts):
|
||||
payloads.append(
|
||||
ChunkPayload(
|
||||
text=text,
|
||||
links={0: section.link or ""},
|
||||
is_continuation=(i > 0),
|
||||
)
|
||||
)
|
||||
return SectionChunkerOutput(payloads=payloads, accumulator=AccumulatorState())
|
||||
@@ -10,6 +10,7 @@ from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import split_text_by_tokens
|
||||
from onyx.utils.text_processing import clean_text
|
||||
from onyx.utils.text_processing import shared_precompare_cleanup
|
||||
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
|
||||
@@ -90,8 +91,8 @@ class TextChunker(SectionChunker):
|
||||
STRICT_CHUNK_TOKEN_LIMIT
|
||||
and count_tokens(split_text, self.tokenizer) > content_token_limit
|
||||
):
|
||||
smaller_chunks = self._split_oversized_chunk(
|
||||
split_text, content_token_limit
|
||||
smaller_chunks = split_text_by_tokens(
|
||||
split_text, self.tokenizer, content_token_limit
|
||||
)
|
||||
for j, small_chunk in enumerate(smaller_chunks):
|
||||
payloads.append(
|
||||
@@ -114,16 +115,3 @@ class TextChunker(SectionChunker):
|
||||
payloads=payloads,
|
||||
accumulator=AccumulatorState(),
|
||||
)
|
||||
|
||||
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
|
||||
tokens = self.tokenizer.tokenize(text)
|
||||
chunks: list[str] = []
|
||||
start = 0
|
||||
total_tokens = len(tokens)
|
||||
while start < total_tokens:
|
||||
end = min(start + content_token_limit, total_tokens)
|
||||
token_chunk = tokens[start:end]
|
||||
chunk_text = " ".join(token_chunk)
|
||||
chunks.append(chunk_text)
|
||||
start = end
|
||||
return chunks
|
||||
|
||||
@@ -3,6 +3,8 @@ from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
import sentry_sdk
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
@@ -291,6 +293,13 @@ def embed_chunks_with_failure_handling(
|
||||
)
|
||||
embedded_chunks.extend(doc_embedded_chunks)
|
||||
except Exception as e:
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "embedding")
|
||||
scope.set_tag("doc_id", doc_id)
|
||||
if tenant_id:
|
||||
scope.set_tag("tenant_id", tenant_id)
|
||||
scope.fingerprint = ["embedding-failure", type(e).__name__]
|
||||
sentry_sdk.capture_exception(e)
|
||||
logger.exception(f"Failed to embed chunks for document '{doc_id}'")
|
||||
failures.append(
|
||||
ConnectorFailure(
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Protocol
|
||||
|
||||
import sentry_sdk
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -332,6 +333,13 @@ def index_doc_batch_with_handler(
|
||||
except Exception as e:
|
||||
# don't log the batch directly, it's too much text
|
||||
document_ids = [doc.id for doc in document_batch]
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "indexing_pipeline")
|
||||
scope.set_tag("tenant_id", tenant_id)
|
||||
scope.set_tag("batch_size", str(len(document_batch)))
|
||||
scope.set_extra("document_ids", document_ids)
|
||||
scope.fingerprint = ["indexing-pipeline-failure", type(e).__name__]
|
||||
sentry_sdk.capture_exception(e)
|
||||
logger.exception(f"Failed to index document batch: {document_ids}")
|
||||
|
||||
index_pipeline_result = IndexingPipelineResult(
|
||||
@@ -543,7 +551,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
processed_sections=[
|
||||
Section(
|
||||
type=section.type,
|
||||
text=section.text if isinstance(section, TextSection) else "",
|
||||
text="" if isinstance(section, ImageSection) else section.text,
|
||||
link=section.link,
|
||||
image_file_id=(
|
||||
section.image_file_id
|
||||
@@ -609,7 +617,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
processed_sections.append(processed_section)
|
||||
|
||||
# For TextSection, create a base Section with text and link
|
||||
elif isinstance(section, TextSection):
|
||||
else:
|
||||
processed_section = Section(
|
||||
type=section.type,
|
||||
text=section.text or "", # Ensure text is always a string, not None
|
||||
|
||||
@@ -6,6 +6,7 @@ from itertools import chain
|
||||
from itertools import groupby
|
||||
|
||||
import httpx
|
||||
import sentry_sdk
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
@@ -88,6 +89,12 @@ def write_chunks_to_vector_db_with_backoff(
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "vector_db_write")
|
||||
scope.set_tag("doc_id", doc_id)
|
||||
scope.set_tag("tenant_id", index_batch_params.tenant_id)
|
||||
scope.fingerprint = ["vector-db-write-failure", type(e).__name__]
|
||||
sentry_sdk.capture_exception(e)
|
||||
logger.exception(
|
||||
f"Failed to write document chunks for '{doc_id}' to vector db"
|
||||
)
|
||||
|
||||
@@ -434,11 +434,14 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
lifespan=lifespan_override or lifespan,
|
||||
)
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -201,6 +201,33 @@ def count_tokens(
|
||||
return total
|
||||
|
||||
|
||||
def split_text_by_tokens(
|
||||
text: str,
|
||||
tokenizer: BaseTokenizer,
|
||||
max_tokens: int,
|
||||
) -> list[str]:
|
||||
"""Split ``text`` into pieces of ≤ ``max_tokens`` tokens each, via
|
||||
encode/decode at token-id boundaries.
|
||||
|
||||
Note: the returned pieces are not strictly guaranteed to re-tokenize to
|
||||
≤ max_tokens. BPE merges at window boundaries may drift by a few tokens,
|
||||
and cuts landing mid-multi-byte-UTF-8-character produce replacement
|
||||
characters on decode. Good enough for "best-effort" splitting of
|
||||
oversized content, not for hard limit enforcement.
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
token_ids: list[int] = []
|
||||
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
|
||||
token_ids.extend(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
|
||||
|
||||
return [
|
||||
tokenizer.decode(token_ids[start : start + max_tokens])
|
||||
for start in range(0, len(token_ids), max_tokens)
|
||||
]
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: BaseTokenizer
|
||||
) -> str:
|
||||
|
||||
@@ -125,6 +125,11 @@ class TenantRedis(redis.Redis):
|
||||
"sadd",
|
||||
"srem",
|
||||
"scard",
|
||||
"zadd",
|
||||
"zrangebyscore",
|
||||
"zremrangebyscore",
|
||||
"zscore",
|
||||
"zcard",
|
||||
"hexists",
|
||||
"hset",
|
||||
"hdel",
|
||||
|
||||
104
backend/onyx/redis/redis_tenant_work_gating.py
Normal file
104
backend/onyx/redis/redis_tenant_work_gating.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Redis helpers for the tenant work-gating feature.
|
||||
|
||||
One sorted set `active_tenants` under the cloud Redis tenant tracks the last
|
||||
time each tenant was observed doing work. The fanout generator reads the set
|
||||
(filtered to entries within a TTL window) and skips tenants that haven't been
|
||||
active recently.
|
||||
|
||||
All public functions no-op in single-tenant mode (`MULTI_TENANT=False`).
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from redis.client import Redis
|
||||
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# Unprefixed key. `TenantRedis._prefixed` prepends `cloud:` at call time so
|
||||
# the full rendered key is `cloud:active_tenants`.
|
||||
_SET_KEY = "active_tenants"
|
||||
|
||||
|
||||
def _now_ms() -> int:
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
def _client() -> Redis:
|
||||
return get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
|
||||
def mark_tenant_active(tenant_id: str) -> None:
|
||||
"""Record that `tenant_id` was just observed doing work (ZADD with the
|
||||
current timestamp as the score). Best-effort — a Redis failure is logged
|
||||
and swallowed so it never breaks a writer path.
|
||||
|
||||
Call sites:
|
||||
- Top of each gated beat-task consumer when its "is there work?" query
|
||||
returns a non-empty result.
|
||||
- cc_pair create lifecycle hook.
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
return
|
||||
|
||||
try:
|
||||
# `mapping={member: score}` syntax; ZADD overwrites the score on
|
||||
# existing members, which is exactly the refresh semantics we want.
|
||||
_client().zadd(_SET_KEY, mapping={tenant_id: _now_ms()})
|
||||
except Exception:
|
||||
logger.exception(f"mark_tenant_active failed: tenant_id={tenant_id}")
|
||||
|
||||
|
||||
def get_active_tenants(ttl_seconds: int) -> set[str] | None:
|
||||
"""Return tenants whose last-seen timestamp is within `ttl_seconds` of
|
||||
now.
|
||||
|
||||
Return values:
|
||||
- `set[str]` (possibly empty) — Redis read succeeded. Empty set means
|
||||
no tenants are currently marked active; callers should *skip* all
|
||||
tenants if the gate is enforcing.
|
||||
- `None` — Redis read failed *or* we are in single-tenant mode. Callers
|
||||
should fail open (dispatch to every tenant this cycle). Distinguishing
|
||||
failure from "genuinely empty" prevents a Redis outage from silently
|
||||
starving every tenant on every enforced cycle.
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
return None
|
||||
|
||||
cutoff_ms = _now_ms() - (ttl_seconds * 1000)
|
||||
try:
|
||||
raw = cast(
|
||||
list[bytes],
|
||||
_client().zrangebyscore(_SET_KEY, min=cutoff_ms, max="+inf"),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("get_active_tenants failed")
|
||||
return None
|
||||
|
||||
return {m.decode() if isinstance(m, bytes) else m for m in raw}
|
||||
|
||||
|
||||
def cleanup_expired(ttl_seconds: int) -> int:
|
||||
"""Remove members older than `ttl_seconds` from the set. Optional
|
||||
memory-hygiene helper — correctness does not depend on calling this, but
|
||||
without it the set grows unboundedly as old tenants accumulate. Returns
|
||||
the number of members removed."""
|
||||
if not MULTI_TENANT:
|
||||
return 0
|
||||
|
||||
cutoff_ms = _now_ms() - (ttl_seconds * 1000)
|
||||
try:
|
||||
removed = cast(
|
||||
int,
|
||||
_client().zremrangebyscore(_SET_KEY, min="-inf", max=f"({cutoff_ms}"),
|
||||
)
|
||||
return removed
|
||||
except Exception:
|
||||
logger.exception("cleanup_expired failed")
|
||||
return 0
|
||||
@@ -63,6 +63,7 @@ class DocumentSetCreationRequest(BaseModel):
|
||||
|
||||
class DocumentSetUpdateRequest(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
cc_pair_ids: list[int]
|
||||
is_public: bool
|
||||
|
||||
@@ -11,6 +11,9 @@ from onyx.db.notification import dismiss_notification
|
||||
from onyx.db.notification import get_notification_by_id
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.server.features.build.utils import ensure_build_mode_intro_notification
|
||||
from onyx.server.features.notifications.utils import (
|
||||
ensure_permissions_migration_notification,
|
||||
)
|
||||
from onyx.server.features.release_notes.utils import (
|
||||
ensure_release_notes_fresh_and_notify,
|
||||
)
|
||||
@@ -49,6 +52,13 @@ def get_notifications_api(
|
||||
except Exception:
|
||||
logger.exception("Failed to check for release notes in notifications endpoint")
|
||||
|
||||
try:
|
||||
ensure_permissions_migration_notification(user, db_session)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to create permissions_migration_v1 announcement in notifications endpoint"
|
||||
)
|
||||
|
||||
notifications = [
|
||||
NotificationModel.from_model(notif)
|
||||
for notif in get_notifications(user, db_session, include_dismissed=True)
|
||||
|
||||
21
backend/onyx/server/features/notifications/utils.py
Normal file
21
backend/onyx/server/features/notifications/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import create_notification
|
||||
|
||||
|
||||
def ensure_permissions_migration_notification(user: User, db_session: Session) -> None:
|
||||
# Feature id "permissions_migration_v1" must not change after shipping —
|
||||
# it is the dedup key on (user_id, notif_type, additional_data).
|
||||
create_notification(
|
||||
user_id=user.id,
|
||||
notif_type=NotificationType.FEATURE_ANNOUNCEMENT,
|
||||
db_session=db_session,
|
||||
title="Permissions are changing in Onyx",
|
||||
description="Roles are moving to group-based permissions. Click for details.",
|
||||
additional_data={
|
||||
"feature": "permissions_migration_v1",
|
||||
"link": "https://docs.onyx.app/admins/permissions/whats_changing",
|
||||
},
|
||||
)
|
||||
@@ -185,6 +185,10 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
for doc_set in persona.document_sets:
|
||||
for cc_pair in doc_set.connector_credential_pairs:
|
||||
sources.add(cc_pair.connector.source)
|
||||
for fed_ds in doc_set.federated_connectors:
|
||||
non_fed = fed_ds.federated_connector.source.to_non_federated_source()
|
||||
if non_fed is not None:
|
||||
sources.add(non_fed)
|
||||
|
||||
# Sources from hierarchy nodes
|
||||
for node in persona.hierarchy_nodes:
|
||||
@@ -195,6 +199,9 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
if doc.parent_hierarchy_node:
|
||||
sources.add(doc.parent_hierarchy_node.source)
|
||||
|
||||
if persona.user_files:
|
||||
sources.add(DocumentSource.USER_FILE)
|
||||
|
||||
return MinimalPersonaSnapshot(
|
||||
# Core fields actually used by ChatPage
|
||||
id=persona.id,
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.background.indexing.models import IndexAttemptErrorPydantic
|
||||
from onyx.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
|
||||
@@ -28,6 +29,7 @@ from onyx.db.feedback import fetch_docs_ranked_by_boost_for_user
|
||||
from onyx.db.feedback import update_document_boost_for_user
|
||||
from onyx.db.feedback import update_document_hidden_for_user
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_for_ccpair
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_across_connectors
|
||||
from onyx.db.models import User
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -35,6 +37,7 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.utils import test_llm
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.manage.models import BoostDoc
|
||||
from onyx.server.manage.models import BoostUpdateRequest
|
||||
from onyx.server.manage.models import HiddenUpdateRequest
|
||||
@@ -206,3 +209,40 @@ def create_deletion_attempt_for_connector_id(
|
||||
file_store = get_default_file_store()
|
||||
for file_id in connector.connector_specific_config.get("file_locations", []):
|
||||
file_store.delete_file(file_id)
|
||||
|
||||
|
||||
@router.get("/admin/indexing/failed-documents")
|
||||
def get_failed_documents(
|
||||
cc_pair_id: int | None = None,
|
||||
error_type: str | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
include_resolved: bool = False,
|
||||
page_num: int = 0,
|
||||
page_size: int = 25,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[IndexAttemptErrorPydantic]:
|
||||
"""Get indexing errors across all connectors with optional filters.
|
||||
|
||||
Provides a cross-connector view of document indexing failures.
|
||||
Defaults to last 30 days if no start_time is provided to avoid
|
||||
unbounded count queries.
|
||||
"""
|
||||
if start_time is None:
|
||||
start_time = datetime.now(tz=timezone.utc) - timedelta(days=30)
|
||||
|
||||
errors, total = get_index_attempt_errors_across_connectors(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
error_type=error_type,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
unresolved_only=not include_resolved,
|
||||
page=page_num,
|
||||
page_size=page_size,
|
||||
)
|
||||
return PaginatedReturn(
|
||||
items=[IndexAttemptErrorPydantic.from_model(e) for e in errors],
|
||||
total_items=total,
|
||||
)
|
||||
|
||||
@@ -183,6 +183,9 @@ def generate_ollama_display_name(model_name: str) -> str:
|
||||
"qwen2.5:7b" → "Qwen 2.5 7B"
|
||||
"mistral:latest" → "Mistral"
|
||||
"deepseek-r1:14b" → "DeepSeek R1 14B"
|
||||
"gemma4:e4b" → "Gemma 4 E4B"
|
||||
"deepseek-v3.1:671b-cloud" → "DeepSeek V3.1 671B Cloud"
|
||||
"qwen3-vl:235b-instruct-cloud" → "Qwen 3-vl 235B Instruct Cloud"
|
||||
"""
|
||||
# Split into base name and tag
|
||||
if ":" in model_name:
|
||||
@@ -209,13 +212,24 @@ def generate_ollama_display_name(model_name: str) -> str:
|
||||
# Default: Title case with dashes converted to spaces
|
||||
display_name = base.replace("-", " ").title()
|
||||
|
||||
# Process tag to extract size info (skip "latest")
|
||||
# Process tag (skip "latest")
|
||||
if tag and tag.lower() != "latest":
|
||||
# Extract size like "7b", "70b", "14b"
|
||||
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])", tag)
|
||||
# Check for size prefix like "7b", "70b", optionally followed by modifiers
|
||||
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])(-.+)?$", tag)
|
||||
if size_match:
|
||||
size = size_match.group(1).upper()
|
||||
display_name = f"{display_name} {size}"
|
||||
remainder = size_match.group(2)
|
||||
if remainder:
|
||||
# Format modifiers like "-cloud", "-instruct-cloud"
|
||||
modifiers = " ".join(
|
||||
p.title() for p in remainder.strip("-").split("-") if p
|
||||
)
|
||||
display_name = f"{display_name} {size} {modifiers}"
|
||||
else:
|
||||
display_name = f"{display_name} {size}"
|
||||
else:
|
||||
# Non-size tags like "e4b", "q4_0", "fp16", "cloud"
|
||||
display_name = f"{display_name} {tag.upper()}"
|
||||
|
||||
return display_name
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import json
|
||||
import secrets
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
@@ -113,28 +114,47 @@ async def transcribe_audio(
|
||||
) from exc
|
||||
|
||||
|
||||
def _extract_provider_error(exc: Exception) -> str:
|
||||
"""Extract a human-readable message from a provider exception.
|
||||
|
||||
Provider errors often embed JSON from upstream APIs (e.g. ElevenLabs).
|
||||
This tries to parse a readable ``message`` field out of common JSON
|
||||
error shapes; falls back to ``str(exc)`` if nothing better is found.
|
||||
"""
|
||||
raw = str(exc)
|
||||
try:
|
||||
# Many providers embed JSON after a prefix like "ElevenLabs TTS failed: {...}"
|
||||
json_start = raw.find("{")
|
||||
if json_start == -1:
|
||||
return raw
|
||||
parsed = json.loads(raw[json_start:])
|
||||
# Shape: {"detail": {"message": "..."}} (ElevenLabs)
|
||||
detail = parsed.get("detail", parsed)
|
||||
if isinstance(detail, dict):
|
||||
return detail.get("message") or detail.get("error") or raw
|
||||
if isinstance(detail, str):
|
||||
return detail
|
||||
except (json.JSONDecodeError, AttributeError, TypeError):
|
||||
pass
|
||||
return raw
|
||||
|
||||
|
||||
class SynthesizeRequest(BaseModel):
|
||||
text: str = Field(..., min_length=1)
|
||||
voice: str | None = None
|
||||
speed: float | None = Field(default=None, ge=0.5, le=2.0)
|
||||
|
||||
|
||||
@router.post("/synthesize")
|
||||
async def synthesize_speech(
|
||||
text: str | None = Query(
|
||||
default=None, description="Text to synthesize", max_length=4096
|
||||
),
|
||||
voice: str | None = Query(default=None, description="Voice ID to use"),
|
||||
speed: float | None = Query(
|
||||
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
|
||||
),
|
||||
body: SynthesizeRequest,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Synthesize text to speech using the default TTS provider.
|
||||
|
||||
Accepts parameters via query string for streaming compatibility.
|
||||
"""
|
||||
logger.info(
|
||||
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
|
||||
)
|
||||
|
||||
if not text:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
|
||||
"""Synthesize text to speech using the default TTS provider."""
|
||||
text = body.text
|
||||
voice = body.voice
|
||||
speed = body.speed
|
||||
logger.info(f"TTS request: text length={len(text)}, voice={voice}, speed={speed}")
|
||||
|
||||
# Use short-lived session to fetch provider config, then release connection
|
||||
# before starting the long-running streaming response
|
||||
@@ -177,31 +197,36 @@ async def synthesize_speech(
|
||||
logger.error(f"Failed to get voice provider: {exc}")
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
# Session is now closed - streaming response won't hold DB connection
|
||||
# Pull the first chunk before returning the StreamingResponse. If the
|
||||
# provider rejects the request (e.g. text too long), the error surfaces
|
||||
# as a proper HTTP error instead of a broken audio stream.
|
||||
stream_iter = provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
)
|
||||
try:
|
||||
first_chunk = await stream_iter.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "TTS provider returned no audio")
|
||||
except Exception as exc:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, _extract_provider_error(exc)
|
||||
) from exc
|
||||
|
||||
async def audio_stream() -> AsyncIterator[bytes]:
|
||||
try:
|
||||
chunk_count = 0
|
||||
async for chunk in provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
):
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
except NotImplementedError as exc:
|
||||
logger.error(f"TTS not implemented: {exc}")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Synthesis failed: {exc}")
|
||||
raise
|
||||
yield first_chunk
|
||||
chunk_count = 1
|
||||
async for chunk in stream_iter:
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
|
||||
return StreamingResponse(
|
||||
audio_stream(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Content-Disposition": "inline; filename=speech.mp3",
|
||||
# Allow streaming by not setting content-length
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Generic Celery task lifecycle Prometheus metrics.
|
||||
|
||||
Provides signal handlers that track task started/completed/failed counts,
|
||||
active task gauge, task duration histograms, and retry/reject/revoke counts.
|
||||
active task gauge, task duration histograms, queue wait time histograms,
|
||||
and retry/reject/revoke counts.
|
||||
These fire for ALL tasks on the worker — no per-connector enrichment
|
||||
(see indexing_task_metrics.py for that).
|
||||
|
||||
@@ -71,6 +72,32 @@ TASK_REJECTED = Counter(
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
TASK_QUEUE_WAIT = Histogram(
|
||||
"onyx_celery_task_queue_wait_seconds",
|
||||
"Time a Celery task spent waiting in the queue before execution started",
|
||||
["task_name", "queue"],
|
||||
buckets=[
|
||||
0.1,
|
||||
0.5,
|
||||
1,
|
||||
5,
|
||||
30,
|
||||
60,
|
||||
300,
|
||||
600,
|
||||
1800,
|
||||
3600,
|
||||
7200,
|
||||
14400,
|
||||
28800,
|
||||
43200,
|
||||
86400,
|
||||
172800,
|
||||
432000,
|
||||
864000,
|
||||
],
|
||||
)
|
||||
|
||||
# task_id → (monotonic start time, metric labels)
|
||||
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
|
||||
|
||||
@@ -133,6 +160,13 @@ def on_celery_task_prerun(
|
||||
with _task_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_task_start_times[task_id] = (time.monotonic(), labels)
|
||||
|
||||
headers = getattr(task.request, "headers", None) or {}
|
||||
enqueued_at = headers.get("enqueued_at")
|
||||
if isinstance(enqueued_at, (int, float)):
|
||||
TASK_QUEUE_WAIT.labels(**labels).observe(
|
||||
max(0.0, time.time() - enqueued_at)
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
123
backend/onyx/server/metrics/connector_health_metrics.py
Normal file
123
backend/onyx/server/metrics/connector_health_metrics.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Prometheus metrics for connector health and index attempts.
|
||||
|
||||
Emitted by docfetching and docprocessing workers when connector or
|
||||
index attempt state changes. All functions silently catch exceptions
|
||||
to avoid disrupting the caller's business logic.
|
||||
|
||||
Gauge metrics (error state, last success timestamp) are per-process.
|
||||
With multiple worker pods, use max() aggregation in PromQL to get the
|
||||
correct value across instances, e.g.:
|
||||
max by (cc_pair_id, connector_name) (onyx_connector_in_error_state)
|
||||
|
||||
Unlike the per-task counters in indexing_task_metrics.py, these metrics
|
||||
include connector_name because their cardinality is bounded by the number
|
||||
of connectors (one series per connector), not by the number of task
|
||||
executions.
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_CONNECTOR_LABELS = ["tenant_id", "source", "cc_pair_id", "connector_name"]
|
||||
|
||||
# --- Index attempt lifecycle ---
|
||||
|
||||
INDEX_ATTEMPT_STATUS = Counter(
|
||||
"onyx_index_attempt_transitions_total",
|
||||
"Index attempt status transitions",
|
||||
[*_CONNECTOR_LABELS, "status"],
|
||||
)
|
||||
|
||||
# --- Connector health ---
|
||||
|
||||
CONNECTOR_IN_ERROR_STATE = Gauge(
|
||||
"onyx_connector_in_error_state",
|
||||
"Whether the connector is in a repeated error state (1=yes, 0=no)",
|
||||
_CONNECTOR_LABELS,
|
||||
)
|
||||
|
||||
CONNECTOR_LAST_SUCCESS_TIMESTAMP = Gauge(
|
||||
"onyx_connector_last_success_timestamp_seconds",
|
||||
"Unix timestamp of last successful indexing for this connector",
|
||||
_CONNECTOR_LABELS,
|
||||
)
|
||||
|
||||
CONNECTOR_DOCS_INDEXED = Counter(
|
||||
"onyx_connector_docs_indexed_total",
|
||||
"Total documents indexed per connector (monotonic)",
|
||||
_CONNECTOR_LABELS,
|
||||
)
|
||||
|
||||
CONNECTOR_INDEXING_ERRORS = Counter(
|
||||
"onyx_connector_indexing_errors_total",
|
||||
"Total failed index attempts per connector (monotonic)",
|
||||
_CONNECTOR_LABELS,
|
||||
)
|
||||
|
||||
|
||||
def on_index_attempt_status_change(
|
||||
tenant_id: str,
|
||||
source: str,
|
||||
cc_pair_id: int,
|
||||
connector_name: str,
|
||||
status: str,
|
||||
) -> None:
|
||||
"""Called on any index attempt status transition."""
|
||||
try:
|
||||
labels = {
|
||||
"tenant_id": tenant_id,
|
||||
"source": source,
|
||||
"cc_pair_id": str(cc_pair_id),
|
||||
"connector_name": connector_name,
|
||||
}
|
||||
INDEX_ATTEMPT_STATUS.labels(**labels, status=status).inc()
|
||||
if status == "failed":
|
||||
CONNECTOR_INDEXING_ERRORS.labels(**labels).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record index attempt status metric", exc_info=True)
|
||||
|
||||
|
||||
def on_connector_error_state_change(
|
||||
tenant_id: str,
|
||||
source: str,
|
||||
cc_pair_id: int,
|
||||
connector_name: str,
|
||||
in_error: bool,
|
||||
) -> None:
|
||||
"""Called when a connector's in_repeated_error_state changes."""
|
||||
try:
|
||||
CONNECTOR_IN_ERROR_STATE.labels(
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
cc_pair_id=str(cc_pair_id),
|
||||
connector_name=connector_name,
|
||||
).set(1.0 if in_error else 0.0)
|
||||
except Exception:
|
||||
logger.debug("Failed to record connector error state metric", exc_info=True)
|
||||
|
||||
|
||||
def on_connector_indexing_success(
|
||||
tenant_id: str,
|
||||
source: str,
|
||||
cc_pair_id: int,
|
||||
connector_name: str,
|
||||
docs_indexed: int,
|
||||
success_timestamp: float,
|
||||
) -> None:
|
||||
"""Called when an indexing run completes successfully."""
|
||||
try:
|
||||
labels = {
|
||||
"tenant_id": tenant_id,
|
||||
"source": source,
|
||||
"cc_pair_id": str(cc_pair_id),
|
||||
"connector_name": connector_name,
|
||||
}
|
||||
CONNECTOR_LAST_SUCCESS_TIMESTAMP.labels(**labels).set(success_timestamp)
|
||||
if docs_indexed > 0:
|
||||
CONNECTOR_DOCS_INDEXED.labels(**labels).inc(docs_indexed)
|
||||
except Exception:
|
||||
logger.debug("Failed to record connector success metric", exc_info=True)
|
||||
104
backend/onyx/server/metrics/deletion_metrics.py
Normal file
104
backend/onyx/server/metrics/deletion_metrics.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Connector-deletion-specific Prometheus metrics.
|
||||
|
||||
Tracks the deletion lifecycle:
|
||||
1. Deletions started (taskset generated)
|
||||
2. Deletions completed (success or failure)
|
||||
3. Taskset duration (from taskset generation to completion or failure).
|
||||
Note: this measures the most recent taskset execution, NOT wall-clock
|
||||
time since the user triggered the deletion. When deletion is blocked by
|
||||
indexing/pruning/permissions, the fence is cleared and a fresh taskset
|
||||
is generated on each retry, resetting this timer.
|
||||
4. Deletion blocked by dependencies (indexing, pruning, permissions, etc.)
|
||||
5. Fence resets (stuck deletion recovery)
|
||||
|
||||
All metrics are labeled by tenant_id. cc_pair_id is intentionally excluded
|
||||
to avoid unbounded cardinality.
|
||||
|
||||
Usage:
|
||||
from onyx.server.metrics.deletion_metrics import (
|
||||
inc_deletion_started,
|
||||
inc_deletion_completed,
|
||||
observe_deletion_taskset_duration,
|
||||
inc_deletion_blocked,
|
||||
inc_deletion_fence_reset,
|
||||
)
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DELETION_STARTED = Counter(
|
||||
"onyx_deletion_started_total",
|
||||
"Connector deletions initiated (taskset generated)",
|
||||
["tenant_id"],
|
||||
)
|
||||
|
||||
DELETION_COMPLETED = Counter(
|
||||
"onyx_deletion_completed_total",
|
||||
"Connector deletions completed",
|
||||
["tenant_id", "outcome"],
|
||||
)
|
||||
|
||||
DELETION_TASKSET_DURATION = Histogram(
|
||||
"onyx_deletion_taskset_duration_seconds",
|
||||
"Duration of a connector deletion taskset, from taskset generation "
|
||||
"to completion or failure. Does not include time spent blocked on "
|
||||
"indexing/pruning/permissions before the taskset was generated.",
|
||||
["tenant_id", "outcome"],
|
||||
buckets=[10, 30, 60, 120, 300, 600, 1800, 3600, 7200, 21600],
|
||||
)
|
||||
|
||||
DELETION_BLOCKED = Counter(
|
||||
"onyx_deletion_blocked_total",
|
||||
"Times deletion was blocked by a dependency",
|
||||
["tenant_id", "blocker"],
|
||||
)
|
||||
|
||||
DELETION_FENCE_RESET = Counter(
|
||||
"onyx_deletion_fence_reset_total",
|
||||
"Deletion fences reset due to missing celery tasks",
|
||||
["tenant_id"],
|
||||
)
|
||||
|
||||
|
||||
def inc_deletion_started(tenant_id: str) -> None:
|
||||
try:
|
||||
DELETION_STARTED.labels(tenant_id=tenant_id).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion started", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_completed(tenant_id: str, outcome: str) -> None:
|
||||
try:
|
||||
DELETION_COMPLETED.labels(tenant_id=tenant_id, outcome=outcome).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion completed", exc_info=True)
|
||||
|
||||
|
||||
def observe_deletion_taskset_duration(
|
||||
tenant_id: str, outcome: str, duration_seconds: float
|
||||
) -> None:
|
||||
try:
|
||||
DELETION_TASKSET_DURATION.labels(tenant_id=tenant_id, outcome=outcome).observe(
|
||||
duration_seconds
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion taskset duration", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_blocked(tenant_id: str, blocker: str) -> None:
|
||||
try:
|
||||
DELETION_BLOCKED.labels(tenant_id=tenant_id, blocker=blocker).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion blocked", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_fence_reset(tenant_id: str) -> None:
|
||||
try:
|
||||
DELETION_FENCE_RESET.labels(tenant_id=tenant_id).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion fence reset", exc_info=True)
|
||||
@@ -1,25 +1,30 @@
|
||||
"""Prometheus collectors for Celery queue depths and indexing pipeline state.
|
||||
"""Prometheus collectors for Celery queue depths and infrastructure health.
|
||||
|
||||
These collectors query Redis and Postgres at scrape time (the Collector pattern),
|
||||
These collectors query Redis at scrape time (the Collector pattern),
|
||||
so metrics are always fresh when Prometheus scrapes /metrics. They run inside the
|
||||
monitoring celery worker which already has Redis and DB access.
|
||||
monitoring celery worker which already has Redis access.
|
||||
|
||||
To avoid hammering Redis/Postgres on every 15s scrape, results are cached with
|
||||
To avoid hammering Redis on every 15s scrape, results are cached with
|
||||
a configurable TTL (default 30s). This means metrics may be up to TTL seconds
|
||||
stale, which is fine for monitoring dashboards.
|
||||
|
||||
Note: connector health and index attempt metrics are push-based (emitted by
|
||||
workers at state-change time) and live in connector_health_metrics.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from prometheus_client.core import GaugeMetricFamily
|
||||
from prometheus_client.registry import Collector
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -31,6 +36,11 @@ logger = setup_logger()
|
||||
# the previous result without re-querying Redis/Postgres.
|
||||
_DEFAULT_CACHE_TTL = 30.0
|
||||
|
||||
# Maximum time (seconds) a single _collect_fresh() call may take before
|
||||
# the collector gives up and returns stale/empty results. Prevents the
|
||||
# /metrics endpoint from hanging indefinitely when a DB or Redis query stalls.
|
||||
_DEFAULT_COLLECT_TIMEOUT = 120.0
|
||||
|
||||
_QUEUE_LABEL_MAP: dict[str, str] = {
|
||||
OnyxCeleryQueues.PRIMARY: "primary",
|
||||
OnyxCeleryQueues.DOCPROCESSING: "docprocessing",
|
||||
@@ -62,18 +72,32 @@ _UNACKED_QUEUES: list[str] = [
|
||||
|
||||
|
||||
class _CachedCollector(Collector):
|
||||
"""Base collector with TTL-based caching.
|
||||
"""Base collector with TTL-based caching and timeout protection.
|
||||
|
||||
Subclasses implement ``_collect_fresh()`` to query the actual data source.
|
||||
The base ``collect()`` returns cached results if the TTL hasn't expired,
|
||||
avoiding repeated queries when Prometheus scrapes frequently.
|
||||
|
||||
A per-collection timeout prevents a slow DB or Redis query from blocking
|
||||
the /metrics endpoint indefinitely. If _collect_fresh() exceeds the
|
||||
timeout, stale cached results are returned instead.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
cache_ttl: float = _DEFAULT_CACHE_TTL,
|
||||
collect_timeout: float = _DEFAULT_COLLECT_TIMEOUT,
|
||||
) -> None:
|
||||
self._cache_ttl = cache_ttl
|
||||
self._collect_timeout = collect_timeout
|
||||
self._cached_result: list[GaugeMetricFamily] | None = None
|
||||
self._last_collect_time: float = 0.0
|
||||
self._lock = threading.Lock()
|
||||
self._executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=1,
|
||||
thread_name_prefix=type(self).__name__,
|
||||
)
|
||||
self._inflight: concurrent.futures.Future | None = None
|
||||
|
||||
def collect(self) -> list[GaugeMetricFamily]:
|
||||
with self._lock:
|
||||
@@ -84,12 +108,28 @@ class _CachedCollector(Collector):
|
||||
):
|
||||
return self._cached_result
|
||||
|
||||
# If a previous _collect_fresh() is still running, wait on it
|
||||
# rather than queuing another. This prevents unbounded task
|
||||
# accumulation in the executor during extended DB outages.
|
||||
if self._inflight is not None and not self._inflight.done():
|
||||
future = self._inflight
|
||||
else:
|
||||
future = self._executor.submit(self._collect_fresh)
|
||||
self._inflight = future
|
||||
|
||||
try:
|
||||
result = self._collect_fresh()
|
||||
result = future.result(timeout=self._collect_timeout)
|
||||
self._inflight = None
|
||||
self._cached_result = result
|
||||
self._last_collect_time = now
|
||||
return result
|
||||
except concurrent.futures.TimeoutError:
|
||||
logger.warning(
|
||||
f"{type(self).__name__}._collect_fresh() timed out after {self._collect_timeout}s, returning stale cache"
|
||||
)
|
||||
return self._cached_result if self._cached_result is not None else []
|
||||
except Exception:
|
||||
self._inflight = None
|
||||
logger.exception(f"Error in {type(self).__name__}.collect()")
|
||||
# Return stale cache on error rather than nothing — avoids
|
||||
# metrics disappearing during transient failures.
|
||||
@@ -117,8 +157,6 @@ class QueueDepthCollector(_CachedCollector):
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
depth = GaugeMetricFamily(
|
||||
@@ -194,208 +232,6 @@ class QueueDepthCollector(_CachedCollector):
|
||||
return None
|
||||
|
||||
|
||||
class IndexAttemptCollector(_CachedCollector):
|
||||
"""Queries Postgres for index attempt state on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._configured: bool = False
|
||||
self._terminal_statuses: list = []
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Call once DB engine is initialized."""
|
||||
from onyx.db.enums import IndexingStatus
|
||||
|
||||
self._terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
self._configured = True
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if not self._configured:
|
||||
return []
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.index_attempt import get_active_index_attempts_for_metrics
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
attempts_gauge = GaugeMetricFamily(
|
||||
"onyx_index_attempts_active",
|
||||
"Number of non-terminal index attempts",
|
||||
labels=[
|
||||
"status",
|
||||
"source",
|
||||
"tenant_id",
|
||||
"connector_name",
|
||||
"cc_pair_id",
|
||||
],
|
||||
)
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tid in tenant_ids:
|
||||
# Defensive guard — get_all_tenant_ids() should never yield None,
|
||||
# but we guard here for API stability in case the contract changes.
|
||||
if tid is None:
|
||||
continue
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
|
||||
try:
|
||||
with get_session_with_current_tenant() as session:
|
||||
rows = get_active_index_attempts_for_metrics(session)
|
||||
|
||||
for status, source, cc_id, cc_name, count in rows:
|
||||
name_val = cc_name or f"cc_pair_{cc_id}"
|
||||
attempts_gauge.add_metric(
|
||||
[
|
||||
status.value,
|
||||
source.value,
|
||||
tid,
|
||||
name_val,
|
||||
str(cc_id),
|
||||
],
|
||||
count,
|
||||
)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return [attempts_gauge]
|
||||
|
||||
|
||||
class ConnectorHealthCollector(_CachedCollector):
|
||||
"""Queries Postgres for connector health state on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._configured: bool = False
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Call once DB engine is initialized."""
|
||||
self._configured = True
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if not self._configured:
|
||||
return []
|
||||
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_health_for_metrics,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.index_attempt import get_docs_indexed_by_cc_pair
|
||||
from onyx.db.index_attempt import get_failed_attempt_counts_by_cc_pair
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
staleness_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_last_success_age_seconds",
|
||||
"Seconds since last successful index for this connector",
|
||||
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
|
||||
)
|
||||
error_state_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_in_error_state",
|
||||
"Whether the connector is in a repeated error state (1=yes, 0=no)",
|
||||
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
|
||||
)
|
||||
by_status_gauge = GaugeMetricFamily(
|
||||
"onyx_connectors_by_status",
|
||||
"Number of connectors grouped by status",
|
||||
labels=["tenant_id", "status"],
|
||||
)
|
||||
error_total_gauge = GaugeMetricFamily(
|
||||
"onyx_connectors_in_error_total",
|
||||
"Total number of connectors in repeated error state",
|
||||
labels=["tenant_id"],
|
||||
)
|
||||
per_connector_labels = [
|
||||
"tenant_id",
|
||||
"source",
|
||||
"cc_pair_id",
|
||||
"connector_name",
|
||||
]
|
||||
docs_success_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_docs_indexed",
|
||||
"Total new documents indexed (90-day rolling sum) per connector",
|
||||
labels=per_connector_labels,
|
||||
)
|
||||
docs_error_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_error_count",
|
||||
"Total number of failed index attempts per connector",
|
||||
labels=per_connector_labels,
|
||||
)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tid in tenant_ids:
|
||||
# Defensive guard — get_all_tenant_ids() should never yield None,
|
||||
# but we guard here for API stability in case the contract changes.
|
||||
if tid is None:
|
||||
continue
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
|
||||
try:
|
||||
with get_session_with_current_tenant() as session:
|
||||
pairs = get_connector_health_for_metrics(session)
|
||||
error_counts_by_cc = get_failed_attempt_counts_by_cc_pair(session)
|
||||
docs_by_cc = get_docs_indexed_by_cc_pair(session)
|
||||
|
||||
status_counts: dict[str, int] = {}
|
||||
error_count = 0
|
||||
|
||||
for (
|
||||
cc_id,
|
||||
status,
|
||||
in_error,
|
||||
last_success,
|
||||
cc_name,
|
||||
source,
|
||||
) in pairs:
|
||||
cc_id_str = str(cc_id)
|
||||
source_val = source.value
|
||||
name_val = cc_name or f"cc_pair_{cc_id}"
|
||||
label_vals = [tid, source_val, cc_id_str, name_val]
|
||||
|
||||
if last_success is not None:
|
||||
# Both `now` and `last_success` are timezone-aware
|
||||
# (the DB column uses DateTime(timezone=True)),
|
||||
# so subtraction is safe.
|
||||
age = (now - last_success).total_seconds()
|
||||
staleness_gauge.add_metric(label_vals, age)
|
||||
|
||||
error_state_gauge.add_metric(
|
||||
label_vals,
|
||||
1.0 if in_error else 0.0,
|
||||
)
|
||||
if in_error:
|
||||
error_count += 1
|
||||
|
||||
docs_success_gauge.add_metric(
|
||||
label_vals,
|
||||
docs_by_cc.get(cc_id, 0),
|
||||
)
|
||||
|
||||
docs_error_gauge.add_metric(
|
||||
label_vals,
|
||||
error_counts_by_cc.get(cc_id, 0),
|
||||
)
|
||||
|
||||
status_val = status.value
|
||||
status_counts[status_val] = status_counts.get(status_val, 0) + 1
|
||||
|
||||
for status_val, count in status_counts.items():
|
||||
by_status_gauge.add_metric([tid, status_val], count)
|
||||
|
||||
error_total_gauge.add_metric([tid], error_count)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return [
|
||||
staleness_gauge,
|
||||
error_state_gauge,
|
||||
by_status_gauge,
|
||||
error_total_gauge,
|
||||
docs_success_gauge,
|
||||
docs_error_gauge,
|
||||
]
|
||||
|
||||
|
||||
class RedisHealthCollector(_CachedCollector):
|
||||
"""Collects Redis server health metrics (memory, clients, etc.)."""
|
||||
|
||||
@@ -411,8 +247,6 @@ class RedisHealthCollector(_CachedCollector):
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
memory_used = GaugeMetricFamily(
|
||||
@@ -495,7 +329,9 @@ class WorkerHeartbeatMonitor:
|
||||
},
|
||||
)
|
||||
recv.capture(
|
||||
limit=None, timeout=self._HEARTBEAT_TIMEOUT_SECONDS, wakeup=True
|
||||
limit=None,
|
||||
timeout=self._HEARTBEAT_TIMEOUT_SECONDS,
|
||||
wakeup=True,
|
||||
)
|
||||
except Exception:
|
||||
if self._running:
|
||||
|
||||
@@ -6,8 +6,6 @@ Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
|
||||
@@ -21,8 +19,6 @@ logger = setup_logger()
|
||||
# module level ensures they survive the lifetime of the worker process and are
|
||||
# only registered with the Prometheus registry once.
|
||||
_queue_collector = QueueDepthCollector()
|
||||
_attempt_collector = IndexAttemptCollector()
|
||||
_connector_collector = ConnectorHealthCollector()
|
||||
_redis_health_collector = RedisHealthCollector()
|
||||
_worker_health_collector = WorkerHealthCollector()
|
||||
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
|
||||
@@ -34,6 +30,9 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
Args:
|
||||
celery_app: The Celery application instance. Used to obtain a
|
||||
broker Redis client on each scrape for queue depth metrics.
|
||||
|
||||
Note: connector health and index attempt metrics are push-based
|
||||
(see connector_health_metrics.py) and do not use collectors.
|
||||
"""
|
||||
_queue_collector.set_celery_app(celery_app)
|
||||
_redis_health_collector.set_celery_app(celery_app)
|
||||
@@ -47,13 +46,8 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
_heartbeat_monitor.start()
|
||||
_worker_health_collector.set_monitor(_heartbeat_monitor)
|
||||
|
||||
_attempt_collector.configure()
|
||||
_connector_collector.configure()
|
||||
|
||||
for collector in (
|
||||
_queue_collector,
|
||||
_attempt_collector,
|
||||
_connector_collector,
|
||||
_redis_health_collector,
|
||||
_worker_health_collector,
|
||||
):
|
||||
|
||||
@@ -7,6 +7,9 @@ from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEF
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT,
|
||||
)
|
||||
from onyx.configs.app_configs import ENABLE_TENANT_WORK_GATING
|
||||
from onyx.configs.app_configs import TENANT_WORK_GATING_FULL_FANOUT_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import TENANT_WORK_GATING_TTL_SECONDS
|
||||
from onyx.configs.constants import CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
|
||||
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
@@ -139,6 +142,87 @@ class OnyxRuntime:
|
||||
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _read_tenant_work_gating_flag(axis: str, default: bool) -> bool:
|
||||
"""Read `runtime:tenant_work_gating:{axis}` from Redis and interpret
|
||||
it as a bool. Returns `default` if the key is absent or unparseable.
|
||||
`axis` is either `enabled` (compute the gate) or `enforce` (actually
|
||||
skip)."""
|
||||
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:tenant_work_gating:{axis}")
|
||||
if raw is None:
|
||||
return default
|
||||
|
||||
try:
|
||||
return cast(bytes, raw).decode().strip().lower() == "true"
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_work_gating_enabled() -> bool:
|
||||
"""Should we *compute* the work gate? (read the Redis set, log how
|
||||
many tenants would be skipped). Env-var `ENABLE_TENANT_WORK_GATING`
|
||||
is the fallback default when no Redis override is set — it acts as
|
||||
the master switch that turns the feature on in shadow mode."""
|
||||
return OnyxRuntime._read_tenant_work_gating_flag(
|
||||
"enabled", default=ENABLE_TENANT_WORK_GATING
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_work_gating_enforce() -> bool:
|
||||
"""Should we *actually skip* tenants not in the work set?
|
||||
|
||||
Deliberately Redis-only with a hard-coded default of False: the env
|
||||
var `ENABLE_TENANT_WORK_GATING` only flips `enabled` (shadow mode),
|
||||
never `enforce`. Enforcement has to be turned on by an explicit
|
||||
`runtime:tenant_work_gating:enforce=true` write so ops can't
|
||||
accidentally skip real tenant traffic by flipping an env flag. Only
|
||||
meaningful when `get_tenant_work_gating_enabled()` is also True.
|
||||
"""
|
||||
return OnyxRuntime._read_tenant_work_gating_flag("enforce", default=False)
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_work_gating_ttl_seconds() -> int:
|
||||
"""Membership TTL for the `active_tenants` sorted set. Members older
|
||||
than this are treated as "no recent work" by the gate read path.
|
||||
Must be > (full-fanout cadence × base task schedule) so self-healing
|
||||
has time to refresh memberships before they expire."""
|
||||
default = TENANT_WORK_GATING_TTL_SECONDS
|
||||
|
||||
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:tenant_work_gating:ttl_seconds")
|
||||
if raw is None:
|
||||
return default
|
||||
|
||||
try:
|
||||
value = int(cast(bytes, raw).decode())
|
||||
return value if value > 0 else default
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_work_gating_full_fanout_interval_seconds() -> int:
|
||||
"""Minimum wall-clock interval between full-fanout cycles. When at
|
||||
least this many seconds have elapsed since the last bypass, the
|
||||
generator ignores the gate on its next invocation and dispatches to
|
||||
every non-gated tenant, letting consumers re-populate the active
|
||||
set. Schedule-independent so beat drift or backlog can't skew the
|
||||
self-heal cadence."""
|
||||
default = TENANT_WORK_GATING_FULL_FANOUT_INTERVAL_SECONDS
|
||||
|
||||
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
raw = r.get(
|
||||
f"{ONYX_CLOUD_REDIS_RUNTIME}:tenant_work_gating:full_fanout_interval_seconds"
|
||||
)
|
||||
if raw is None:
|
||||
return default
|
||||
|
||||
try:
|
||||
value = int(cast(bytes, raw).decode())
|
||||
return value if value > 0 else default
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def get_build_fence_lookup_table_interval() -> int:
|
||||
"""We maintain an active fence table to make lookups of existing fences efficient.
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.mcp import get_all_mcp_tools_for_server
|
||||
@@ -113,10 +114,10 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
|
||||
|
||||
def construct_tools(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
emitter: Emitter,
|
||||
user: User,
|
||||
llm: LLM,
|
||||
db_session: Session | None = None,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
file_reader_tool_config: FileReaderToolConfig | None = None,
|
||||
@@ -131,6 +132,33 @@ def construct_tools(
|
||||
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
|
||||
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
|
||||
to avoid lazy SQL queries after the session may have been flushed."""
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
return _construct_tools_impl(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
custom_tool_config=custom_tool_config,
|
||||
file_reader_tool_config=file_reader_tool_config,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
search_usage_forcing_setting=search_usage_forcing_setting,
|
||||
)
|
||||
|
||||
|
||||
def _construct_tools_impl(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
emitter: Emitter,
|
||||
user: User,
|
||||
llm: LLM,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
file_reader_tool_config: FileReaderToolConfig | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
search_usage_forcing_setting: SearchToolUsage = SearchToolUsage.AUTO,
|
||||
) -> dict[int, list[Tool]]:
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Log which tools are attached to the persona for debugging
|
||||
|
||||
@@ -26,7 +26,7 @@ aiolimiter==1.2.1
|
||||
# via voyageai
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
alembic==1.10.4
|
||||
alembic==1.18.4
|
||||
amqp==5.3.1
|
||||
# via kombu
|
||||
annotated-doc==0.0.4
|
||||
@@ -299,7 +299,7 @@ h11==0.16.0
|
||||
# uvicorn
|
||||
h2==4.3.0
|
||||
# via httpx
|
||||
hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
||||
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
||||
# via huggingface-hub
|
||||
hpack==4.1.0
|
||||
# via h2
|
||||
@@ -322,6 +322,7 @@ httpx==0.28.1
|
||||
# fastmcp
|
||||
# google-genai
|
||||
# httpx-oauth
|
||||
# huggingface-hub
|
||||
# langfuse
|
||||
# langsmith
|
||||
# litellm
|
||||
@@ -334,7 +335,7 @@ httpx-sse==0.4.3
|
||||
# cohere
|
||||
# mcp
|
||||
hubspot-api-client==11.1.0
|
||||
huggingface-hub==0.35.3
|
||||
huggingface-hub==1.10.2
|
||||
# via tokenizers
|
||||
humanfriendly==10.0
|
||||
# via coloredlogs
|
||||
@@ -408,7 +409,7 @@ kombu==5.5.4
|
||||
# via celery
|
||||
kubernetes==31.0.0
|
||||
# via onyx
|
||||
langchain-core==1.2.22
|
||||
langchain-core==1.2.28
|
||||
langdetect==1.0.9
|
||||
# via unstructured
|
||||
langfuse==3.10.0
|
||||
@@ -589,7 +590,7 @@ platformdirs==4.5.0
|
||||
# via
|
||||
# fastmcp
|
||||
# zeep
|
||||
playwright==1.55.0
|
||||
playwright==1.58.0
|
||||
# via pytest-playwright
|
||||
pluggy==1.6.0
|
||||
# via pytest
|
||||
@@ -784,7 +785,6 @@ requests==2.33.0
|
||||
# google-api-core
|
||||
# google-genai
|
||||
# hubspot-api-client
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# jsonschema-path
|
||||
# kubernetes
|
||||
@@ -911,7 +911,7 @@ tiktoken==0.7.0
|
||||
timeago==1.0.16
|
||||
tld==0.13.1
|
||||
# via courlan
|
||||
tokenizers==0.21.4
|
||||
tokenizers==0.22.2
|
||||
# via
|
||||
# chonkie
|
||||
# cohere
|
||||
@@ -933,7 +933,9 @@ tqdm==4.67.1
|
||||
# unstructured
|
||||
trafilatura==1.12.2
|
||||
typer==0.20.0
|
||||
# via mcp
|
||||
# via
|
||||
# huggingface-hub
|
||||
# mcp
|
||||
types-awscrt==0.28.4
|
||||
# via botocore-stubs
|
||||
types-openpyxl==3.0.4.7
|
||||
|
||||
@@ -22,7 +22,7 @@ aiolimiter==1.2.1
|
||||
# via voyageai
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
alembic==1.10.4
|
||||
alembic==1.18.4
|
||||
# via pytest-alembic
|
||||
annotated-doc==0.0.4
|
||||
# via fastapi
|
||||
@@ -82,6 +82,7 @@ click==8.3.1
|
||||
# via
|
||||
# black
|
||||
# litellm
|
||||
# typer
|
||||
# uvicorn
|
||||
cohere==5.6.1
|
||||
# via onyx
|
||||
@@ -153,7 +154,7 @@ h11==0.16.0
|
||||
# httpcore
|
||||
# uvicorn
|
||||
hatchling==1.28.0
|
||||
hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
||||
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
||||
# via huggingface-hub
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
@@ -161,6 +162,7 @@ httpx==0.28.1
|
||||
# via
|
||||
# cohere
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# litellm
|
||||
# mcp
|
||||
# openai
|
||||
@@ -168,7 +170,7 @@ httpx-sse==0.4.3
|
||||
# via
|
||||
# cohere
|
||||
# mcp
|
||||
huggingface-hub==0.35.3
|
||||
huggingface-hub==1.10.2
|
||||
# via tokenizers
|
||||
identify==2.6.15
|
||||
# via pre-commit
|
||||
@@ -219,6 +221,8 @@ litellm==1.81.6
|
||||
mako==1.2.4
|
||||
# via alembic
|
||||
manygo==0.2.0
|
||||
markdown-it-py==4.0.0
|
||||
# via rich
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# jinja2
|
||||
@@ -230,6 +234,8 @@ matplotlib-inline==0.2.1
|
||||
# ipython
|
||||
mcp==1.26.0
|
||||
# via claude-agent-sdk
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -340,6 +346,7 @@ pygments==2.20.0
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
# pytest
|
||||
# rich
|
||||
pyjwt==2.12.0
|
||||
# via mcp
|
||||
pyparsing==3.2.5
|
||||
@@ -395,7 +402,6 @@ requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
# tiktoken
|
||||
@@ -404,6 +410,8 @@ requests-oauthlib==1.3.1
|
||||
# via kubernetes
|
||||
retry==0.9.2
|
||||
# via onyx
|
||||
rich==14.2.0
|
||||
# via typer
|
||||
rpds-py==0.29.0
|
||||
# via
|
||||
# jsonschema
|
||||
@@ -415,6 +423,8 @@ s3transfer==0.13.1
|
||||
# via boto3
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.17.0
|
||||
# via
|
||||
# kubernetes
|
||||
@@ -442,7 +452,7 @@ tenacity==9.1.2
|
||||
# voyageai
|
||||
tiktoken==0.7.0
|
||||
# via litellm
|
||||
tokenizers==0.21.4
|
||||
tokenizers==0.22.2
|
||||
# via
|
||||
# cohere
|
||||
# litellm
|
||||
@@ -463,6 +473,8 @@ traitlets==5.14.3
|
||||
# matplotlib-inline
|
||||
trove-classifiers==2025.12.1.14
|
||||
# via hatchling
|
||||
typer==0.20.0
|
||||
# via huggingface-hub
|
||||
types-beautifulsoup4==4.12.0.3
|
||||
types-html5lib==1.1.11.13
|
||||
# via types-beautifulsoup4
|
||||
@@ -500,6 +512,7 @@ typing-extensions==4.15.0
|
||||
# referencing
|
||||
# sqlalchemy
|
||||
# starlette
|
||||
# typer
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
|
||||
@@ -69,6 +69,7 @@ claude-agent-sdk==0.1.19
|
||||
click==8.3.1
|
||||
# via
|
||||
# litellm
|
||||
# typer
|
||||
# uvicorn
|
||||
cohere==5.6.1
|
||||
# via onyx
|
||||
@@ -112,7 +113,7 @@ h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
||||
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
||||
# via huggingface-hub
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
@@ -120,6 +121,7 @@ httpx==0.28.1
|
||||
# via
|
||||
# cohere
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# litellm
|
||||
# mcp
|
||||
# openai
|
||||
@@ -127,7 +129,7 @@ httpx-sse==0.4.3
|
||||
# via
|
||||
# cohere
|
||||
# mcp
|
||||
huggingface-hub==0.35.3
|
||||
huggingface-hub==1.10.2
|
||||
# via tokenizers
|
||||
idna==3.11
|
||||
# via
|
||||
@@ -156,10 +158,14 @@ kubernetes==31.0.0
|
||||
# via onyx
|
||||
litellm==1.81.6
|
||||
# via onyx
|
||||
markdown-it-py==4.0.0
|
||||
# via rich
|
||||
markupsafe==3.0.3
|
||||
# via jinja2
|
||||
mcp==1.26.0
|
||||
# via claude-agent-sdk
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
monotonic==1.6
|
||||
# via posthog
|
||||
multidict==6.7.0
|
||||
@@ -217,6 +223,8 @@ pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.12.0
|
||||
# via mcp
|
||||
pygments==2.20.0
|
||||
# via rich
|
||||
pyjwt==2.12.0
|
||||
# via mcp
|
||||
python-dateutil==2.8.2
|
||||
@@ -247,7 +255,6 @@ requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
# posthog
|
||||
# requests-oauthlib
|
||||
@@ -257,6 +264,8 @@ requests-oauthlib==1.3.1
|
||||
# via kubernetes
|
||||
retry==0.9.2
|
||||
# via onyx
|
||||
rich==14.2.0
|
||||
# via typer
|
||||
rpds-py==0.29.0
|
||||
# via
|
||||
# jsonschema
|
||||
@@ -267,6 +276,8 @@ s3transfer==0.13.1
|
||||
# via boto3
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.17.0
|
||||
# via
|
||||
# kubernetes
|
||||
@@ -289,7 +300,7 @@ tenacity==9.1.2
|
||||
# voyageai
|
||||
tiktoken==0.7.0
|
||||
# via litellm
|
||||
tokenizers==0.21.4
|
||||
tokenizers==0.22.2
|
||||
# via
|
||||
# cohere
|
||||
# litellm
|
||||
@@ -297,6 +308,8 @@ tqdm==4.67.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# openai
|
||||
typer==0.20.0
|
||||
# via huggingface-hub
|
||||
types-requests==2.32.0.20250328
|
||||
# via cohere
|
||||
typing-extensions==4.15.0
|
||||
@@ -313,6 +326,7 @@ typing-extensions==4.15.0
|
||||
# pydantic-core
|
||||
# referencing
|
||||
# starlette
|
||||
# typer
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
|
||||
@@ -78,6 +78,7 @@ click==8.3.1
|
||||
# click-plugins
|
||||
# click-repl
|
||||
# litellm
|
||||
# typer
|
||||
# uvicorn
|
||||
click-didyoumean==0.3.1
|
||||
# via celery
|
||||
@@ -116,7 +117,6 @@ filelock==3.20.3
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
# transformers
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
@@ -135,7 +135,7 @@ h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
||||
hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
||||
# via huggingface-hub
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
@@ -143,6 +143,7 @@ httpx==0.28.1
|
||||
# via
|
||||
# cohere
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# litellm
|
||||
# mcp
|
||||
# openai
|
||||
@@ -150,7 +151,7 @@ httpx-sse==0.4.3
|
||||
# via
|
||||
# cohere
|
||||
# mcp
|
||||
huggingface-hub==0.35.3
|
||||
huggingface-hub==1.10.2
|
||||
# via
|
||||
# accelerate
|
||||
# sentence-transformers
|
||||
@@ -189,10 +190,14 @@ kubernetes==31.0.0
|
||||
# via onyx
|
||||
litellm==1.81.6
|
||||
# via onyx
|
||||
markdown-it-py==4.0.0
|
||||
# via rich
|
||||
markupsafe==3.0.3
|
||||
# via jinja2
|
||||
mcp==1.26.0
|
||||
# via claude-agent-sdk
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
multidict==6.7.0
|
||||
@@ -207,6 +212,7 @@ numpy==2.4.1
|
||||
# accelerate
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# sentence-transformers
|
||||
# transformers
|
||||
# voyageai
|
||||
nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
@@ -264,8 +270,6 @@ packaging==24.2
|
||||
# transformers
|
||||
parameterized==0.9.0
|
||||
# via cohere
|
||||
pillow==12.2.0
|
||||
# via sentence-transformers
|
||||
prometheus-client==0.23.1
|
||||
# via
|
||||
# onyx
|
||||
@@ -305,6 +309,8 @@ pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.12.0
|
||||
# via mcp
|
||||
pygments==2.20.0
|
||||
# via rich
|
||||
pyjwt==2.12.0
|
||||
# via mcp
|
||||
python-dateutil==2.8.2
|
||||
@@ -339,16 +345,16 @@ requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
# google-genai
|
||||
# huggingface-hub
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
# tiktoken
|
||||
# transformers
|
||||
# voyageai
|
||||
requests-oauthlib==1.3.1
|
||||
# via kubernetes
|
||||
retry==0.9.2
|
||||
# via onyx
|
||||
rich==14.2.0
|
||||
# via typer
|
||||
rpds-py==0.29.0
|
||||
# via
|
||||
# jsonschema
|
||||
@@ -367,11 +373,13 @@ scipy==1.16.3
|
||||
# via
|
||||
# scikit-learn
|
||||
# sentence-transformers
|
||||
sentence-transformers==4.0.2
|
||||
sentence-transformers==5.4.1
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
setuptools==80.9.0 ; python_full_version >= '3.12'
|
||||
# via torch
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.17.0
|
||||
# via
|
||||
# kubernetes
|
||||
@@ -398,7 +406,7 @@ threadpoolctl==3.6.0
|
||||
# via scikit-learn
|
||||
tiktoken==0.7.0
|
||||
# via litellm
|
||||
tokenizers==0.21.4
|
||||
tokenizers==0.22.2
|
||||
# via
|
||||
# cohere
|
||||
# litellm
|
||||
@@ -413,10 +421,14 @@ tqdm==4.67.1
|
||||
# openai
|
||||
# sentence-transformers
|
||||
# transformers
|
||||
transformers==4.53.0
|
||||
transformers==5.5.4
|
||||
# via sentence-transformers
|
||||
triton==3.5.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
typer==0.20.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
types-requests==2.32.0.20250328
|
||||
# via cohere
|
||||
typing-extensions==4.15.0
|
||||
@@ -435,6 +447,7 @@ typing-extensions==4.15.0
|
||||
# sentence-transformers
|
||||
# starlette
|
||||
# torch
|
||||
# typer
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import patch
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import _pick
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
|
||||
@@ -699,3 +700,43 @@ def test_specific_user_email_shared_with_me(
|
||||
|
||||
doc_titles = set(doc.semantic_identifier for doc in output.documents)
|
||||
assert doc_titles == set(expected)
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
)
|
||||
def test_slim_retrieval_does_not_call_permissions_list(
|
||||
mock_get_api_key: MagicMock, # noqa: ARG001
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
"""retrieve_all_slim_docs() must not call permissions().list for any file.
|
||||
|
||||
Pruning only needs file IDs — fetching permissions per file causes O(N) API
|
||||
calls that time out for tenants with large numbers of externally-owned files.
|
||||
"""
|
||||
connector = google_drive_service_acct_connector_factory(
|
||||
primary_admin_email=ADMIN_EMAIL,
|
||||
include_shared_drives=True,
|
||||
include_my_drives=True,
|
||||
include_files_shared_with_me=False,
|
||||
shared_folder_urls=None,
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"onyx.connectors.google_drive.connector.execute_paginated_retrieval",
|
||||
wraps=execute_paginated_retrieval,
|
||||
) as mock_paginated:
|
||||
for batch in connector.retrieve_all_slim_docs():
|
||||
pass
|
||||
|
||||
permissions_calls = [
|
||||
c
|
||||
for c in mock_paginated.call_args_list
|
||||
if "permissions" in str(c.kwargs.get("retrieval_function", ""))
|
||||
]
|
||||
assert (
|
||||
len(permissions_calls) == 0
|
||||
), f"permissions().list was called {len(permissions_calls)} time(s) during pruning"
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.connectors.models import TextSection
|
||||
|
||||
_ITERATION_LIMIT = 100_000
|
||||
@@ -141,13 +142,15 @@ def load_all_from_connector(
|
||||
|
||||
def to_sections(
|
||||
documents: list[Document],
|
||||
) -> Iterator[TextSection | ImageSection]:
|
||||
) -> Iterator[TextSection | ImageSection | TabularSection]:
|
||||
for doc in documents:
|
||||
for section in doc.sections:
|
||||
yield section
|
||||
|
||||
|
||||
def to_text_sections(sections: Iterator[TextSection | ImageSection]) -> Iterator[str]:
|
||||
def to_text_sections(
|
||||
sections: Iterator[TextSection | ImageSection | TabularSection],
|
||||
) -> Iterator[str]:
|
||||
for section in sections:
|
||||
if isinstance(section, TextSection):
|
||||
yield section.text
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
"""Tests for the tenant work-gating Redis helpers.
|
||||
|
||||
Requires a running Redis instance. Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest \
|
||||
backend/tests/external_dependency_unit/tenant_work_gating/test_tenant_work_gating.py
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.redis import redis_tenant_work_gating as twg
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_tenant_work_gating import _SET_KEY
|
||||
from onyx.redis.redis_tenant_work_gating import cleanup_expired
|
||||
from onyx.redis.redis_tenant_work_gating import get_active_tenants
|
||||
from onyx.redis.redis_tenant_work_gating import mark_tenant_active
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _multi_tenant_true() -> Generator[None, None, None]:
|
||||
"""Force MULTI_TENANT=True for the helper module so public functions are
|
||||
not no-ops during tests."""
|
||||
with patch.object(twg, "MULTI_TENANT", True):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_set() -> Generator[None, None, None]:
|
||||
"""Clear the active_tenants sorted set before and after each test."""
|
||||
client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
client.delete(_SET_KEY)
|
||||
yield
|
||||
client.delete(_SET_KEY)
|
||||
|
||||
|
||||
def test_mark_adds_tenant_to_set() -> None:
|
||||
mark_tenant_active("tenant_a")
|
||||
|
||||
assert get_active_tenants(ttl_seconds=60) == {"tenant_a"}
|
||||
|
||||
|
||||
def test_mark_refreshes_timestamp() -> None:
|
||||
"""ZADD overwrites the score on existing members. Without a refresh,
|
||||
reading with a TTL that excludes the first write should return empty;
|
||||
after a second mark_tenant_active at a newer timestamp, the same TTL
|
||||
read should include the tenant. Pins `_now_ms` so the test is
|
||||
deterministic."""
|
||||
base_ms = int(time.time() * 1000)
|
||||
|
||||
# First write at t=0.
|
||||
with patch.object(twg, "_now_ms", return_value=base_ms):
|
||||
mark_tenant_active("tenant_a")
|
||||
|
||||
# Read 5s later with a 1s TTL — first write is outside the window.
|
||||
with patch.object(twg, "_now_ms", return_value=base_ms + 5000):
|
||||
assert get_active_tenants(ttl_seconds=1) == set()
|
||||
|
||||
# Refresh at t=5s.
|
||||
with patch.object(twg, "_now_ms", return_value=base_ms + 5000):
|
||||
mark_tenant_active("tenant_a")
|
||||
|
||||
# Read at t=5s with a 1s TTL — refreshed write is inside the window.
|
||||
with patch.object(twg, "_now_ms", return_value=base_ms + 5000):
|
||||
assert get_active_tenants(ttl_seconds=1) == {"tenant_a"}
|
||||
|
||||
|
||||
def test_get_active_tenants_filters_by_ttl() -> None:
|
||||
"""Tenant marked in the past, read with a TTL short enough to exclude it."""
|
||||
# Pin _now_ms so the write happens at t=0 and the read cutoff is
|
||||
# well after that.
|
||||
base_ms = int(time.time() * 1000)
|
||||
with patch.object(twg, "_now_ms", return_value=base_ms):
|
||||
mark_tenant_active("tenant_old")
|
||||
|
||||
# Read 5 seconds later with a 1-second TTL — tenant_old is outside.
|
||||
with patch.object(twg, "_now_ms", return_value=base_ms + 5000):
|
||||
assert get_active_tenants(ttl_seconds=1) == set()
|
||||
|
||||
# Read 5 seconds later with a 10-second TTL — tenant_old is inside.
|
||||
with patch.object(twg, "_now_ms", return_value=base_ms + 5000):
|
||||
assert get_active_tenants(ttl_seconds=10) == {"tenant_old"}
|
||||
|
||||
|
||||
def test_get_active_tenants_multiple_members() -> None:
|
||||
mark_tenant_active("tenant_a")
|
||||
mark_tenant_active("tenant_b")
|
||||
mark_tenant_active("tenant_c")
|
||||
|
||||
assert get_active_tenants(ttl_seconds=60) == {"tenant_a", "tenant_b", "tenant_c"}
|
||||
|
||||
|
||||
def test_get_active_tenants_empty_set() -> None:
|
||||
"""Genuinely-empty set returns an empty set (not None)."""
|
||||
assert get_active_tenants(ttl_seconds=60) == set()
|
||||
|
||||
|
||||
def test_get_active_tenants_returns_none_on_redis_error() -> None:
|
||||
"""Callers need to distinguish Redis failure from "no tenants active" so
|
||||
they can fail open. Simulate failure by patching the client to raise."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
failing_client = MagicMock()
|
||||
failing_client.zrangebyscore.side_effect = RuntimeError("simulated outage")
|
||||
|
||||
with patch.object(twg, "_client", return_value=failing_client):
|
||||
assert get_active_tenants(ttl_seconds=60) is None
|
||||
|
||||
|
||||
def test_get_active_tenants_returns_none_in_single_tenant_mode() -> None:
|
||||
"""Single-tenant mode returns None so callers can skip the gate entirely
|
||||
(same fail-open handling as Redis unavailability)."""
|
||||
with patch.object(twg, "MULTI_TENANT", False):
|
||||
assert get_active_tenants(ttl_seconds=60) is None
|
||||
|
||||
|
||||
def test_cleanup_expired_removes_only_stale_members() -> None:
|
||||
"""Seed one stale and one fresh member directly; cleanup should drop only
|
||||
the stale one."""
|
||||
now_ms = int(time.time() * 1000)
|
||||
|
||||
client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
client.zadd(_SET_KEY, mapping={"tenant_old": now_ms - 10 * 60 * 1000})
|
||||
client.zadd(_SET_KEY, mapping={"tenant_new": now_ms})
|
||||
|
||||
removed = cleanup_expired(ttl_seconds=60)
|
||||
|
||||
assert removed == 1
|
||||
assert get_active_tenants(ttl_seconds=60 * 60) == {"tenant_new"}
|
||||
|
||||
|
||||
def test_cleanup_expired_empty_set_noop() -> None:
|
||||
assert cleanup_expired(ttl_seconds=60) == 0
|
||||
|
||||
|
||||
def test_noop_when_multi_tenant_false() -> None:
|
||||
with patch.object(twg, "MULTI_TENANT", False):
|
||||
mark_tenant_active("tenant_a")
|
||||
assert get_active_tenants(ttl_seconds=60) is None
|
||||
assert cleanup_expired(ttl_seconds=60) == 0
|
||||
|
||||
# Verify nothing was written while MULTI_TENANT was False.
|
||||
assert get_active_tenants(ttl_seconds=60) == set()
|
||||
|
||||
|
||||
def test_rendered_key_is_cloud_prefixed() -> None:
|
||||
"""Exercises TenantRedis auto-prefixing on sorted-set ops. The rendered
|
||||
Redis key should be `cloud:active_tenants`, not bare `active_tenants`."""
|
||||
mark_tenant_active("tenant_a")
|
||||
|
||||
from onyx.redis.redis_pool import RedisPool
|
||||
|
||||
raw = RedisPool().get_raw_client()
|
||||
assert raw.zscore("cloud:active_tenants", "tenant_a") is not None
|
||||
assert raw.zscore("active_tenants", "tenant_a") is None
|
||||
@@ -38,38 +38,41 @@ class TestAddMemory:
|
||||
def test_add_memory_creates_row(self, db_session: Session, test_user: User) -> None:
|
||||
"""Verify that add_memory inserts a new Memory row."""
|
||||
user_id = test_user.id
|
||||
memory = add_memory(
|
||||
memory_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="User prefers dark mode",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert memory.id is not None
|
||||
assert memory.user_id == user_id
|
||||
assert memory.memory_text == "User prefers dark mode"
|
||||
assert memory_id is not None
|
||||
|
||||
# Verify it persists
|
||||
fetched = db_session.get(Memory, memory.id)
|
||||
fetched = db_session.get(Memory, memory_id)
|
||||
assert fetched is not None
|
||||
assert fetched.user_id == user_id
|
||||
assert fetched.memory_text == "User prefers dark mode"
|
||||
|
||||
def test_add_multiple_memories(self, db_session: Session, test_user: User) -> None:
|
||||
"""Verify that multiple memories can be added for the same user."""
|
||||
user_id = test_user.id
|
||||
m1 = add_memory(
|
||||
m1_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Favorite color is blue",
|
||||
db_session=db_session,
|
||||
)
|
||||
m2 = add_memory(
|
||||
m2_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Works in engineering",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert m1.id != m2.id
|
||||
assert m1.memory_text == "Favorite color is blue"
|
||||
assert m2.memory_text == "Works in engineering"
|
||||
assert m1_id != m2_id
|
||||
fetched_m1 = db_session.get(Memory, m1_id)
|
||||
fetched_m2 = db_session.get(Memory, m2_id)
|
||||
assert fetched_m1 is not None
|
||||
assert fetched_m2 is not None
|
||||
assert fetched_m1.memory_text == "Favorite color is blue"
|
||||
assert fetched_m2.memory_text == "Works in engineering"
|
||||
|
||||
|
||||
class TestUpdateMemoryAtIndex:
|
||||
@@ -82,15 +85,17 @@ class TestUpdateMemoryAtIndex:
|
||||
add_memory(user_id=user_id, memory_text="Memory 1", db_session=db_session)
|
||||
add_memory(user_id=user_id, memory_text="Memory 2", db_session=db_session)
|
||||
|
||||
updated = update_memory_at_index(
|
||||
updated_id = update_memory_at_index(
|
||||
user_id=user_id,
|
||||
index=1,
|
||||
new_text="Updated Memory 1",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert updated is not None
|
||||
assert updated.memory_text == "Updated Memory 1"
|
||||
assert updated_id is not None
|
||||
fetched = db_session.get(Memory, updated_id)
|
||||
assert fetched is not None
|
||||
assert fetched.memory_text == "Updated Memory 1"
|
||||
|
||||
def test_update_memory_at_out_of_range_index(
|
||||
self, db_session: Session, test_user: User
|
||||
@@ -167,7 +172,7 @@ class TestMemoryCap:
|
||||
assert len(rows_before) == MAX_MEMORIES_PER_USER
|
||||
|
||||
# Add one more — should evict the oldest
|
||||
new_memory = add_memory(
|
||||
new_memory_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="New memory after cap",
|
||||
db_session=db_session,
|
||||
@@ -181,7 +186,7 @@ class TestMemoryCap:
|
||||
# Oldest ("Memory 0") should be gone; "Memory 1" is now the oldest
|
||||
assert rows_after[0].memory_text == "Memory 1"
|
||||
# Newest should be the one we just added
|
||||
assert rows_after[-1].id == new_memory.id
|
||||
assert rows_after[-1].id == new_memory_id
|
||||
assert rows_after[-1].memory_text == "New memory after cap"
|
||||
|
||||
|
||||
@@ -221,22 +226,26 @@ class TestGetMemoriesWithUserId:
|
||||
user_id = test_user_no_memories.id
|
||||
|
||||
# Add a memory
|
||||
memory = add_memory(
|
||||
memory_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Memory with use_memories off",
|
||||
db_session=db_session,
|
||||
)
|
||||
assert memory.memory_text == "Memory with use_memories off"
|
||||
fetched = db_session.get(Memory, memory_id)
|
||||
assert fetched is not None
|
||||
assert fetched.memory_text == "Memory with use_memories off"
|
||||
|
||||
# Update that memory
|
||||
updated = update_memory_at_index(
|
||||
updated_id = update_memory_at_index(
|
||||
user_id=user_id,
|
||||
index=0,
|
||||
new_text="Updated memory with use_memories off",
|
||||
db_session=db_session,
|
||||
)
|
||||
assert updated is not None
|
||||
assert updated.memory_text == "Updated memory with use_memories off"
|
||||
assert updated_id is not None
|
||||
fetched_updated = db_session.get(Memory, updated_id)
|
||||
assert fetched_updated is not None
|
||||
assert fetched_updated.memory_text == "Updated memory with use_memories off"
|
||||
|
||||
# Verify get_memories returns the updated memory
|
||||
context = get_memories(test_user_no_memories, db_session)
|
||||
|
||||
@@ -12,7 +12,7 @@ from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.managers.api_key import DATestAPIKey
|
||||
from tests.integration.common_utils.managers.cc_pair import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.test_models import SimpleTestDocument
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
@@ -62,6 +62,7 @@ class DocumentSetManager:
|
||||
) -> bool:
|
||||
doc_set_update_request = {
|
||||
"id": document_set.id,
|
||||
"name": document_set.name,
|
||||
"description": document_set.description,
|
||||
"cc_pair_ids": document_set.cc_pair_ids,
|
||||
"is_public": document_set.is_public,
|
||||
|
||||
@@ -14,7 +14,6 @@ from onyx.db.search_settings import get_current_search_settings
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.image_generation import (
|
||||
ImageGenerationConfigManager,
|
||||
@@ -196,6 +195,9 @@ def image_generation_config(
|
||||
|
||||
@pytest.fixture
|
||||
def document_builder(admin_user: DATestUser) -> DocumentBuilderType:
|
||||
# HACK: Avoid importing generated OpenAPI client modules unless this fixture is used.
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -108,12 +108,12 @@ def current_head_rev() -> str:
|
||||
["alembic", "heads", "--resolve-dependencies"],
|
||||
cwd=_BACKEND_DIR,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
assert (
|
||||
result.returncode == 0
|
||||
), f"alembic heads failed (exit {result.returncode}):\n{result.stdout}"
|
||||
), f"alembic heads failed (exit {result.returncode}):\n{result.stdout}\n{result.stderr}"
|
||||
# Output looks like "d5c86e2c6dc6 (head)\n"
|
||||
rev = result.stdout.strip().split()[0]
|
||||
assert len(rev) > 0
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
@@ -159,3 +161,58 @@ def test_removing_connector(
|
||||
doc_set_names=[],
|
||||
doc_creating_user=admin_user,
|
||||
)
|
||||
|
||||
|
||||
def test_renaming_document_set(
|
||||
reset: None, # noqa: ARG001
|
||||
vespa_client: vespa_fixture,
|
||||
) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.INGESTION_API,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
cc_pair.documents = DocumentManager.seed_dummy_docs(
|
||||
cc_pair=cc_pair,
|
||||
num_docs=NUM_DOCS,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
original_name = f"original_doc_set_{uuid4()}"
|
||||
doc_set = DocumentSetManager.create(
|
||||
name=original_name,
|
||||
cc_pair_ids=[cc_pair.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
DocumentSetManager.wait_for_sync(user_performing_action=admin_user)
|
||||
DocumentSetManager.verify(
|
||||
document_set=doc_set,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
new_name = f"renamed_doc_set_{uuid4()}"
|
||||
doc_set.name = new_name
|
||||
DocumentSetManager.edit(
|
||||
doc_set,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
DocumentSetManager.wait_for_sync(user_performing_action=admin_user)
|
||||
DocumentSetManager.verify(
|
||||
document_set=doc_set,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
DocumentManager.verify(
|
||||
vespa_client=vespa_client,
|
||||
cc_pair=cc_pair,
|
||||
doc_set_names=[new_name],
|
||||
doc_creating_user=admin_user,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Integration tests verifying the knowledge_sources field on MinimalPersonaSnapshot.
|
||||
|
||||
The GET /persona endpoint returns MinimalPersonaSnapshot, which includes a
|
||||
knowledge_sources list derived from the persona's document sets, hierarchy
|
||||
nodes, attached documents, and user files. These tests verify that the
|
||||
field is populated correctly.
|
||||
"""
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.file import FileManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.test_file_utils import create_test_text_file
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _get_minimal_persona(
|
||||
persona_id: int,
|
||||
user: DATestUser,
|
||||
) -> dict:
|
||||
"""Fetch personas from the list endpoint and find the one with the given id."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona",
|
||||
params={"persona_ids": persona_id},
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
personas = response.json()
|
||||
matches = [p for p in personas if p["id"] == persona_id]
|
||||
assert (
|
||||
len(matches) == 1
|
||||
), f"Expected 1 persona with id={persona_id}, got {len(matches)}"
|
||||
return matches[0]
|
||||
|
||||
|
||||
def test_persona_with_user_files_includes_user_file_source(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When a persona has user files attached, knowledge_sources includes 'user_file'."""
|
||||
text_file = create_test_text_file("test content for knowledge source verification")
|
||||
file_descriptors, error = FileManager.upload_files(
|
||||
files=[("test_ks.txt", text_file)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert not error, f"File upload failed: {error}"
|
||||
|
||||
user_file_id = file_descriptors[0]["user_file_id"] or ""
|
||||
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="KS User File Agent",
|
||||
description="Agent with user files for knowledge_sources test",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
user_file_ids=[user_file_id],
|
||||
)
|
||||
|
||||
minimal = _get_minimal_persona(persona.id, admin_user)
|
||||
assert (
|
||||
DocumentSource.USER_FILE.value in minimal["knowledge_sources"]
|
||||
), f"Expected 'user_file' in knowledge_sources, got: {minimal['knowledge_sources']}"
|
||||
|
||||
|
||||
def test_persona_without_user_files_excludes_user_file_source(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When a persona has no user files, knowledge_sources should not contain 'user_file'."""
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="KS No Files Agent",
|
||||
description="Agent without files for knowledge_sources test",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
minimal = _get_minimal_persona(persona.id, admin_user)
|
||||
assert (
|
||||
DocumentSource.USER_FILE.value not in minimal["knowledge_sources"]
|
||||
), f"Unexpected 'user_file' in knowledge_sources: {minimal['knowledge_sources']}"
|
||||
@@ -301,7 +301,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -332,7 +331,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -363,7 +361,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -391,7 +388,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -423,7 +419,6 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -456,7 +451,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -497,7 +491,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -519,7 +512,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -542,7 +534,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -596,7 +587,6 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
@@ -653,7 +643,6 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
@@ -706,7 +695,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -736,7 +724,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
|
||||
0
backend/tests/unit/onyx/configs/__init__.py
Normal file
0
backend/tests/unit/onyx/configs/__init__.py
Normal file
88
backend/tests/unit/onyx/configs/test_sentry.py
Normal file
88
backend/tests/unit/onyx/configs/test_sentry.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from sentry_sdk.types import Event
|
||||
|
||||
import onyx.configs.sentry as sentry_module
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
|
||||
def _event(data: dict) -> Event:
|
||||
"""Helper to create a Sentry Event from a plain dict for testing."""
|
||||
return cast(Event, data)
|
||||
|
||||
|
||||
def _reset_state() -> None:
|
||||
"""Reset the module-level resolved flag between tests."""
|
||||
sentry_module._instance_id_resolved = False
|
||||
|
||||
|
||||
class TestAddInstanceTags:
|
||||
def setup_method(self) -> None:
|
||||
_reset_state()
|
||||
|
||||
@patch("onyx.utils.telemetry.get_or_generate_uuid", return_value="test-uuid-1234")
|
||||
@patch("sentry_sdk.set_tag")
|
||||
def test_first_event_sets_instance_id(
|
||||
self, mock_set_tag: MagicMock, mock_uuid: MagicMock
|
||||
) -> None:
|
||||
result = _add_instance_tags(_event({"message": "test error"}), {})
|
||||
|
||||
assert result is not None
|
||||
assert result["tags"]["instance_id"] == "test-uuid-1234"
|
||||
mock_set_tag.assert_called_once_with("instance_id", "test-uuid-1234")
|
||||
mock_uuid.assert_called_once()
|
||||
|
||||
@patch("onyx.utils.telemetry.get_or_generate_uuid", return_value="test-uuid-1234")
|
||||
@patch("sentry_sdk.set_tag")
|
||||
def test_second_event_skips_resolution(
|
||||
self, _mock_set_tag: MagicMock, mock_uuid: MagicMock
|
||||
) -> None:
|
||||
_add_instance_tags(_event({"message": "first"}), {})
|
||||
result = _add_instance_tags(_event({"message": "second"}), {})
|
||||
|
||||
assert result is not None
|
||||
assert "tags" not in result # second event not modified
|
||||
mock_uuid.assert_called_once() # only resolved once
|
||||
|
||||
@patch(
|
||||
"onyx.utils.telemetry.get_or_generate_uuid",
|
||||
side_effect=Exception("DB unavailable"),
|
||||
)
|
||||
@patch("sentry_sdk.set_tag")
|
||||
def test_resolution_failure_still_returns_event(
|
||||
self, _mock_set_tag: MagicMock, _mock_uuid: MagicMock
|
||||
) -> None:
|
||||
result = _add_instance_tags(_event({"message": "test error"}), {})
|
||||
|
||||
assert result is not None
|
||||
assert result["message"] == "test error"
|
||||
assert "tags" not in result or "instance_id" not in result.get("tags", {})
|
||||
|
||||
@patch(
|
||||
"onyx.utils.telemetry.get_or_generate_uuid",
|
||||
side_effect=Exception("DB unavailable"),
|
||||
)
|
||||
@patch("sentry_sdk.set_tag")
|
||||
def test_resolution_failure_retries_on_next_event(
|
||||
self, _mock_set_tag: MagicMock, mock_uuid: MagicMock
|
||||
) -> None:
|
||||
"""If resolution fails (e.g. DB not ready), retry on the next event."""
|
||||
_add_instance_tags(_event({"message": "first"}), {})
|
||||
_add_instance_tags(_event({"message": "second"}), {})
|
||||
|
||||
assert mock_uuid.call_count == 2 # retried on second event
|
||||
|
||||
@patch("onyx.utils.telemetry.get_or_generate_uuid", return_value="test-uuid-1234")
|
||||
@patch("sentry_sdk.set_tag")
|
||||
def test_preserves_existing_tags(
|
||||
self, _mock_set_tag: MagicMock, _mock_uuid: MagicMock
|
||||
) -> None:
|
||||
result = _add_instance_tags(
|
||||
_event({"message": "test", "tags": {"existing": "tag"}}), {}
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["tags"]["existing"] == "tag"
|
||||
assert result["tags"]["instance_id"] == "test-uuid-1234"
|
||||
@@ -8,14 +8,23 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.canvas.client import CanvasApiClient
|
||||
from onyx.connectors.canvas.connector import _in_time_window
|
||||
from onyx.connectors.canvas.connector import _parse_canvas_dt
|
||||
from onyx.connectors.canvas.connector import _unix_to_canvas_time
|
||||
from onyx.connectors.canvas.connector import CanvasConnector
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.canvas.connector import CanvasConnectorCheckpoint
|
||||
from onyx.connectors.canvas.connector import CanvasStage
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -111,6 +120,56 @@ def _mock_response(
|
||||
return resp
|
||||
|
||||
|
||||
def _make_url_dispatcher(
|
||||
courses: list[dict[str, Any]] | None = None,
|
||||
pages: list[dict[str, Any]] | None = None,
|
||||
assignments: list[dict[str, Any]] | None = None,
|
||||
announcements: list[dict[str, Any]] | None = None,
|
||||
page_error: bool = False,
|
||||
) -> Any:
|
||||
"""Return a callable that dispatches mock responses based on the request URL.
|
||||
|
||||
Meant to be assigned to ``mock_requests.get.side_effect``.
|
||||
"""
|
||||
api_prefix = f"{FAKE_BASE_URL}/api/v1"
|
||||
|
||||
def _dispatcher(url: str, **_kwargs: Any) -> MagicMock:
|
||||
if page_error:
|
||||
return _mock_response(500, {})
|
||||
if url == f"{api_prefix}/courses":
|
||||
return _mock_response(json_data=courses or [])
|
||||
if "/pages" in url:
|
||||
return _mock_response(json_data=pages or [])
|
||||
if "/assignments" in url:
|
||||
return _mock_response(json_data=assignments or [])
|
||||
if "announcements" in url:
|
||||
return _mock_response(json_data=announcements or [])
|
||||
return _mock_response(json_data=[])
|
||||
|
||||
return _dispatcher
|
||||
|
||||
|
||||
def _run_checkpoint(
|
||||
connector: CanvasConnector,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
start: float = 0.0,
|
||||
end: float = datetime(2099, 1, 1, tzinfo=timezone.utc).timestamp(),
|
||||
) -> tuple[
|
||||
list[Document | HierarchyNode | ConnectorFailure], CanvasConnectorCheckpoint
|
||||
]:
|
||||
"""Run load_from_checkpoint once and collect yielded items + returned checkpoint."""
|
||||
gen = connector.load_from_checkpoint(start, end, checkpoint)
|
||||
items: list[Document | HierarchyNode | ConnectorFailure] = []
|
||||
new_checkpoint: CanvasConnectorCheckpoint | None = None
|
||||
try:
|
||||
while True:
|
||||
items.append(next(gen))
|
||||
except StopIteration as e:
|
||||
new_checkpoint = e.value
|
||||
assert new_checkpoint is not None
|
||||
return items, new_checkpoint
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient.__init__ tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -269,15 +328,6 @@ class TestGet:
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_raises_on_429(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(429, {})
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_skips_params_when_using_full_url(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
@@ -454,6 +504,149 @@ class TestPaginate:
|
||||
|
||||
assert pages == []
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_extracts_message_from_error_dict(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Shape 1: {"error": {"message": "Not authorized"}}"""
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
403, {"error": {"message": "Not authorized"}}
|
||||
)
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Not authorized"
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_extracts_message_from_error_string(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Shape 2: {"error": "Invalid access token"}"""
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
401, {"error": "Invalid access token"}
|
||||
)
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Invalid access token"
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_extracts_message_from_errors_list(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Shape 3: {"errors": [{"message": "Invalid query"}]}"""
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
400, {"errors": [{"message": "Invalid query"}]}
|
||||
)
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Invalid query"
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_dict_takes_priority_over_errors_list(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""When both error shapes are present, error dict wins."""
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
403, {"error": "Specific error", "errors": [{"message": "Generic"}]}
|
||||
)
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Specific error"
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_falls_back_to_reason_when_no_json_message(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Empty error body falls back to response.reason."""
|
||||
mock_requests.get.return_value = _mock_response(500, {})
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Error" # from _mock_response's reason for >= 300
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_invalid_json_on_success_raises(self, mock_requests: MagicMock) -> None:
|
||||
"""Invalid JSON on a 2xx response raises OnyxError."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.side_effect = ValueError("No JSON")
|
||||
resp.headers = {"Link": ""}
|
||||
mock_requests.get.return_value = resp
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError, match="Invalid JSON"):
|
||||
client.get("courses")
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_invalid_json_on_error_falls_back_to_reason(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Invalid JSON on a 4xx response falls back to response.reason."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = 500
|
||||
resp.reason = "Internal Server Error"
|
||||
resp.json.side_effect = ValueError("No JSON")
|
||||
resp.headers = {"Link": ""}
|
||||
mock_requests.get.return_value = resp
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Internal Server Error"
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient._parse_next_link tests
|
||||
@@ -588,6 +781,16 @@ class TestConnectorUrlNormalization:
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_load_credentials_insufficient_permissions(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
mock_requests.get.return_value = _mock_response(403, {})
|
||||
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
|
||||
|
||||
with pytest.raises(InsufficientPermissionsError):
|
||||
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasConnector — document conversion
|
||||
@@ -766,10 +969,6 @@ class TestValidateConnectorSettings:
|
||||
def test_validate_insufficient_permissions(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_validate_raises(403, InsufficientPermissionsError, mock_requests)
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_validate_rate_limited(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_validate_raises(429, ConnectorValidationError, mock_requests)
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_validate_unexpected_error(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_validate_raises(500, UnexpectedValidationError, mock_requests)
|
||||
@@ -874,3 +1073,652 @@ class TestListAnnouncements:
|
||||
result = connector._list_announcements(course_id=1)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestCheckpoint:
|
||||
def test_build_dummy_checkpoint(self) -> None:
|
||||
connector = _build_connector()
|
||||
|
||||
cp = connector.build_dummy_checkpoint()
|
||||
|
||||
assert cp.has_more is True
|
||||
assert cp.course_ids == []
|
||||
assert cp.current_course_index == 0
|
||||
assert cp.stage == CanvasStage.PAGES
|
||||
|
||||
def test_validate_checkpoint_json(self) -> None:
|
||||
connector = _build_connector()
|
||||
cp = CanvasConnectorCheckpoint(
|
||||
has_more=True,
|
||||
course_ids=[1, 2],
|
||||
current_course_index=1,
|
||||
stage=CanvasStage.ASSIGNMENTS,
|
||||
)
|
||||
|
||||
json_str = cp.model_dump_json()
|
||||
restored = connector.validate_checkpoint_json(json_str)
|
||||
|
||||
assert restored.course_ids == [1, 2]
|
||||
assert restored.current_course_index == 1
|
||||
assert restored.stage == CanvasStage.ASSIGNMENTS
|
||||
assert restored.has_more is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_from_checkpoint tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadFromCheckpoint:
    """Tests for the Canvas connector's checkpointed load: course
    materialization, stage progression, time-window filtering, and the
    failure-handling contract (stage vs. course vs. fatal errors)."""

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_first_call_materializes_courses(self, mock_requests: MagicMock) -> None:
        """First call should populate course_ids and yield no documents."""
        mock_requests.get.side_effect = _make_url_dispatcher(
            courses=[_mock_course(1), _mock_course(2, "Data Structures", "CS201")]
        )
        connector = _build_connector()
        cp = connector.build_dummy_checkpoint()

        items, new_cp = _run_checkpoint(connector, cp)

        assert items == []
        assert new_cp.course_ids == [1, 2]
        assert new_cp.current_course_index == 0
        assert new_cp.stage == CanvasStage.PAGES
        assert new_cp.has_more is True

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_processes_pages_stage(self, mock_requests: MagicMock) -> None:
        """Pages stage yields page documents within the time window."""
        mock_requests.get.side_effect = _make_url_dispatcher(
            pages=[_mock_page(10, "Syllabus", "2025-06-15T12:00:00Z")]
        )
        connector = _build_connector()
        cp = CanvasConnectorCheckpoint(
            has_more=True,
            course_ids=[1],
            current_course_index=0,
            stage=CanvasStage.PAGES,
        )
        start = datetime(2025, 6, 1, 0, 0, tzinfo=timezone.utc).timestamp()
        end = datetime(2025, 6, 30, 0, 0, tzinfo=timezone.utc).timestamp()

        items, new_cp = _run_checkpoint(connector, cp, start, end)

        expected_count = 1
        expected_id = "canvas-page-1-10"
        assert len(items) == expected_count
        assert isinstance(items[0], Document)
        assert items[0].id == expected_id
        assert new_cp.stage == CanvasStage.ASSIGNMENTS

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_advances_through_all_stages(self, mock_requests: MagicMock) -> None:
        """Calling checkpoint 3 times advances pages -> assignments -> announcements -> next course."""
        page = _mock_page(10, updated_at="2025-06-15T12:00:00Z")
        assignment = _mock_assignment(20, updated_at="2025-06-15T12:00:00Z")
        announcement = _mock_announcement(30, posted_at="2025-06-15T12:00:00Z")
        mock_requests.get.side_effect = _make_url_dispatcher(
            pages=[page], assignments=[assignment], announcements=[announcement]
        )
        connector = _build_connector()
        start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
        end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
        cp = CanvasConnectorCheckpoint(
            has_more=True,
            course_ids=[1],
            current_course_index=0,
            stage=CanvasStage.PAGES,
        )

        # Stage 1: pages
        items1, cp = _run_checkpoint(connector, cp, start, end)

        assert cp.stage == CanvasStage.ASSIGNMENTS
        assert len(items1) == 1

        # Stage 2: assignments
        mock_requests.get.side_effect = _make_url_dispatcher(assignments=[assignment])

        items2, cp = _run_checkpoint(connector, cp, start, end)

        assert cp.stage == CanvasStage.ANNOUNCEMENTS
        assert len(items2) == 1

        # Stage 3: announcements -> advances course index
        mock_requests.get.side_effect = _make_url_dispatcher(
            announcements=[announcement]
        )

        items3, cp = _run_checkpoint(connector, cp, start, end)

        assert cp.current_course_index == 1
        assert cp.stage == CanvasStage.PAGES
        assert cp.has_more is False

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_filters_by_time_window(self, mock_requests: MagicMock) -> None:
        """Only documents within (start, end] are yielded."""
        old_page = _mock_page(10, updated_at="2025-01-01T00:00:00Z")
        new_page = _mock_page(11, title="New Page", updated_at="2025-06-15T12:00:00Z")
        mock_requests.get.side_effect = _make_url_dispatcher(pages=[new_page, old_page])
        connector = _build_connector()
        cp = CanvasConnectorCheckpoint(
            has_more=True,
            course_ids=[1],
            current_course_index=0,
            stage=CanvasStage.PAGES,
        )
        start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
        end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()

        items, _ = _run_checkpoint(connector, cp, start, end)

        expected_count = 1
        expected_id = "canvas-page-1-11"

        assert len(items) == expected_count
        assert isinstance(items[0], Document)
        assert items[0].id == expected_id

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_skips_announcement_without_posted_at(
        self, mock_requests: MagicMock
    ) -> None:
        """An announcement with posted_at=None is silently skipped (no document,
        no failure)."""
        announcement = _mock_announcement()
        announcement["posted_at"] = None
        mock_requests.get.side_effect = _make_url_dispatcher(
            announcements=[announcement]
        )
        connector = _build_connector()
        cp = CanvasConnectorCheckpoint(
            has_more=True,
            course_ids=[1],
            current_course_index=0,
            stage=CanvasStage.ANNOUNCEMENTS,
        )

        items, _ = _run_checkpoint(connector, cp)

        assert len(items) == 0

    def test_stage_failure_advances_stage_and_yields_failure(self) -> None:
        """A 500 on a stage fetch yields a stage-level ConnectorFailure and
        advances to the next stage, so the framework doesn't loop on the
        same failing state forever."""
        connector = _build_connector()
        cp = CanvasConnectorCheckpoint(
            has_more=True,
            course_ids=[1, 2],
            current_course_index=0,
            stage=CanvasStage.PAGES,
        )

        with patch.object(
            connector,
            "_fetch_stage_page",
            side_effect=OnyxError(
                OnyxErrorCode.INTERNAL_ERROR,
                "boom",
                status_code_override=500,
            ),
        ):
            items, new_cp = _run_checkpoint(connector, cp)

        expected_entity_id = "canvas-pages-1"
        assert len(items) == 1
        assert isinstance(items[0], ConnectorFailure)
        assert items[0].failed_entity is not None
        assert items[0].failed_entity.entity_id == expected_entity_id
        assert new_cp.stage == CanvasStage.ASSIGNMENTS
        assert new_cp.current_course_index == 0
        assert new_cp.next_url is None
        assert new_cp.has_more is True

    def test_course_404_advances_course_and_yields_failure(self) -> None:
        """A 404 on a stage fetch means the whole course is inaccessible —
        yield a course-level ConnectorFailure and skip to the next course
        instead of burning API calls on every stage of a missing course."""
        connector = _build_connector()
        cp = CanvasConnectorCheckpoint(
            has_more=True,
            course_ids=[1, 2],
            current_course_index=0,
            stage=CanvasStage.PAGES,
        )

        with patch.object(
            connector,
            "_fetch_stage_page",
            side_effect=OnyxError(
                OnyxErrorCode.NOT_FOUND,
                "course gone",
                status_code_override=404,
            ),
        ):
            items, new_cp = _run_checkpoint(connector, cp)

        expected_entity_id = "canvas-course-1"
        expected_next_course_index = 1
        assert len(items) == 1
        assert isinstance(items[0], ConnectorFailure)
        assert items[0].failed_entity is not None
        assert items[0].failed_entity.entity_id == expected_entity_id
        assert new_cp.current_course_index == expected_next_course_index
        assert new_cp.stage == CanvasStage.PAGES
        assert new_cp.next_url is None
        assert new_cp.has_more is True

    def test_fatal_auth_failure_during_stage_fetch_propagates(self) -> None:
        """A 401 during a stage fetch is fatal: it raises instead of being
        swallowed into a ConnectorFailure."""
        connector = _build_connector()
        cp = CanvasConnectorCheckpoint(
            has_more=True,
            course_ids=[1],
            current_course_index=0,
            stage=CanvasStage.PAGES,
        )

        with patch("onyx.connectors.canvas.client.rl_requests") as mock_requests:
            mock_requests.get.return_value = _mock_response(401, {})
            with pytest.raises(CredentialExpiredError):
                _run_checkpoint(connector, cp)

    def test_security_failure_during_stage_fetch_propagates(self) -> None:
        """Non-404/500 OnyxErrors (e.g. BAD_GATEWAY) propagate to the caller."""
        connector = _build_connector()
        cp = CanvasConnectorCheckpoint(
            has_more=True,
            course_ids=[1],
            current_course_index=0,
            stage=CanvasStage.PAGES,
        )

        with patch.object(
            connector,
            "_fetch_stage_page",
            side_effect=OnyxError(OnyxErrorCode.BAD_GATEWAY, "bad next link"),
        ):
            with pytest.raises(OnyxError, match="bad next link"):
                _run_checkpoint(connector, cp)

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_per_document_conversion_failure_yields_connector_failure(
        self, mock_requests: MagicMock
    ) -> None:
        """Bad data for one page yields ConnectorFailure, doesn't stop processing."""
        bad_page = {
            "page_id": 10,
            "url": "test",
            "title": "Test",
            "body": None,
            "created_at": "2025-06-15T12:00:00Z",
            "updated_at": "bad-date",
        }
        mock_requests.get.side_effect = _make_url_dispatcher(pages=[bad_page])
        connector = _build_connector()
        cp = CanvasConnectorCheckpoint(
            has_more=True,
            course_ids=[1],
            current_course_index=0,
            stage=CanvasStage.PAGES,
        )

        items, new_cp = _run_checkpoint(connector, cp)

        assert len(items) == 1
        assert isinstance(items[0], ConnectorFailure)
        assert new_cp.stage == CanvasStage.ASSIGNMENTS

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_all_courses_done_sets_has_more_false(
        self, mock_requests: MagicMock
    ) -> None:
        """When current_course_index is past the end, the run completes with
        no items and has_more=False."""
        mock_requests.get.side_effect = _make_url_dispatcher()
        connector = _build_connector()
        cp = CanvasConnectorCheckpoint(
            has_more=True, course_ids=[1], current_course_index=1
        )

        items, new_cp = _run_checkpoint(connector, cp)

        assert items == []
        assert new_cp.has_more is False

    def test_invalid_stage_raises_value_error(self) -> None:
        """An unrecognized checkpoint stage value raises ValueError."""
        connector = _build_connector()
        cp = CanvasConnectorCheckpoint(
            has_more=True,
            course_ids=[1],
            current_course_index=0,
            stage=CanvasStage.PAGES,
        )
        cp.stage = "invalid"  # type: ignore[assignment]

        with pytest.raises(ValueError, match="Invalid checkpoint stage"):
            _run_checkpoint(connector, cp)
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_from_checkpoint_with_perm_sync tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadFromCheckpointWithPermSync:
    """Tests for the permission-aware checkpoint load path."""

    @patch("onyx.connectors.canvas.connector.get_course_permissions")
    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_documents_have_external_access(
        self, mock_requests: MagicMock, mock_perms: MagicMock
    ) -> None:
        """load_from_checkpoint_with_perm_sync attaches ExternalAccess to documents."""
        expected_access = ExternalAccess(
            external_user_emails={"student@school.edu"},
            external_user_group_ids=set(),
            is_public=False,
        )
        mock_perms.return_value = expected_access
        mock_requests.get.side_effect = _make_url_dispatcher(
            pages=[_mock_page(10, "Syllabus", "2025-06-15T12:00:00Z")]
        )
        connector = _build_connector()
        cp = CanvasConnectorCheckpoint(
            has_more=True,
            course_ids=[1],
            current_course_index=0,
            stage=CanvasStage.PAGES,
        )
        start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
        end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()

        # Drain the generator manually so we can capture the checkpoint it
        # returns via StopIteration.value.
        stream = connector.load_from_checkpoint_with_perm_sync(start, end, cp)
        collected: list[Document | HierarchyNode | ConnectorFailure] = []
        final_cp: CanvasConnectorCheckpoint | None = None
        while True:
            try:
                collected.append(next(stream))
            except StopIteration as stop:
                final_cp = stop.value
                break
        assert final_cp is not None

        assert len(collected) == 1
        assert isinstance(collected[0], Document)
        assert collected[0].external_access == expected_access
        assert final_cp.stage == CanvasStage.ASSIGNMENTS
        mock_perms.assert_called_once()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper function tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseCanvasDt:
    """Tests for the Canvas ISO-8601 timestamp parser."""

    def test_z_suffix_parsed_as_utc(self) -> None:
        expected = datetime(2025, 6, 15, 12, 0, 0, tzinfo=timezone.utc)
        assert _parse_canvas_dt("2025-06-15T12:00:00Z") == expected

    def test_plus_offset_parsed_as_utc(self) -> None:
        expected = datetime(2025, 6, 15, 12, 0, 0, tzinfo=timezone.utc)
        assert _parse_canvas_dt("2025-06-15T12:00:00+00:00") == expected

    def test_result_is_timezone_aware(self) -> None:
        parsed = _parse_canvas_dt("2025-01-01T00:00:00Z")
        assert parsed.tzinfo is not None
|
||||
|
||||
class TestUnixToCanvasTime:
    """Tests for epoch-seconds -> Canvas timestamp-string conversion."""

    def test_known_epoch_produces_expected_string(self) -> None:
        moment = datetime(2025, 6, 15, 12, 0, 0, tzinfo=timezone.utc)
        assert _unix_to_canvas_time(moment.timestamp()) == "2025-06-15T12:00:00Z"

    def test_round_trips_with_parse_canvas_dt(self) -> None:
        moment = datetime(2025, 3, 10, 8, 30, 0, tzinfo=timezone.utc)
        round_tripped = _parse_canvas_dt(_unix_to_canvas_time(moment.timestamp()))
        assert round_tripped == moment
||||
|
||||
|
||||
class TestInTimeWindow:
    """Tests that _in_time_window treats the window as (start, end] —
    start exclusive, end inclusive."""

    @staticmethod
    def _june_window() -> tuple[float, float]:
        # Shared June-2025 window used by every test in this class.
        lo = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
        hi = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
        return lo, hi

    def test_inside_window(self) -> None:
        lo, hi = self._june_window()
        assert _in_time_window("2025-06-15T12:00:00Z", lo, hi) is True

    def test_before_window(self) -> None:
        lo, hi = self._june_window()
        assert _in_time_window("2025-05-01T12:00:00Z", lo, hi) is False

    def test_after_window(self) -> None:
        lo, hi = self._june_window()
        assert _in_time_window("2025-07-15T12:00:00Z", lo, hi) is False

    def test_start_boundary_is_exclusive(self) -> None:
        lo, hi = self._june_window()
        assert _in_time_window("2025-06-01T00:00:00Z", lo, hi) is False

    def test_end_boundary_is_inclusive(self) -> None:
        lo, hi = self._june_window()
        assert _in_time_window("2025-06-30T00:00:00Z", lo, hi) is True
||||
|
||||
|
||||
class TestFetchStagePage:
    """Tests for CanvasConnector._fetch_stage_page URL selection."""

    def test_uses_full_url_when_next_url_set(self) -> None:
        # A pagination link takes precedence over endpoint + params.
        connector = _build_connector()
        pagination_url = "https://myschool.instructure.com/api/v1/courses?page=2"
        with patch.object(
            connector.canvas_client, "get", return_value=([{"id": 1}], None)
        ) as fake_get:
            payload, _link = connector._fetch_stage_page(
                next_url=pagination_url,
                endpoint="courses/1/pages",
                params={"per_page": "100"},
            )

        fake_get.assert_called_once_with(full_url=pagination_url)
        assert payload == [{"id": 1}]

    def test_uses_endpoint_and_params_when_no_next_url(self) -> None:
        # Without a pagination link, the endpoint/params form is used.
        connector = _build_connector()
        with patch.object(
            connector.canvas_client, "get", return_value=([{"id": 1}], None)
        ) as fake_get:
            _payload, _link = connector._fetch_stage_page(
                next_url=None,
                endpoint="courses/1/pages",
                params={"per_page": "100"},
            )

        fake_get.assert_called_once_with(
            endpoint="courses/1/pages", params={"per_page": "100"}
        )

    def test_returns_empty_list_for_none_response(self) -> None:
        # A None body from the client is normalized to an empty list.
        connector = _build_connector()
        with patch.object(connector.canvas_client, "get", return_value=(None, None)):
            payload, link = connector._fetch_stage_page(
                next_url=None,
                endpoint="courses/1/pages",
                params={},
            )

        assert payload == []
        assert link is None
||||
|
||||
|
||||
class TestProcessItems:
    """Tests for CanvasConnector._process_items: per-stage conversion,
    time-window filtering, per-item failure handling, and the early-exit
    signal for descending-sorted feeds."""

    def test_pages_in_window_converted(self) -> None:
        """A page inside the window converts to a Document; no early exit."""
        connector = _build_connector()
        start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
        end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()

        results, early_exit = connector._process_items(
            response=[_mock_page(10, "Syllabus", "2025-06-15T12:00:00Z")],
            stage=CanvasStage.PAGES,
            course_id=1,
            start=start,
            end=end,
            include_permissions=False,
        )

        assert len(results) == 1
        assert isinstance(results[0], Document)
        assert early_exit is False

    def test_pages_outside_window_skipped(self) -> None:
        """A page older than the window yields nothing and triggers early exit."""
        connector = _build_connector()
        start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
        end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()

        results, early_exit = connector._process_items(
            response=[_mock_page(10, "Old", "2025-01-01T12:00:00Z")],
            stage=CanvasStage.PAGES,
            course_id=1,
            start=start,
            end=end,
            include_permissions=False,
        )

        assert results == []
        assert early_exit is True

    def test_assignments_in_window_converted(self) -> None:
        """An assignment inside the window converts to a Document."""
        connector = _build_connector()
        start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
        end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()

        results, early_exit = connector._process_items(
            response=[_mock_assignment(20, "HW1", 1, "2025-06-15T12:00:00Z")],
            stage=CanvasStage.ASSIGNMENTS,
            course_id=1,
            start=start,
            end=end,
            include_permissions=False,
        )

        assert len(results) == 1
        assert isinstance(results[0], Document)
        assert early_exit is False

    def test_announcements_in_window_converted(self) -> None:
        """An announcement inside the window converts to a Document."""
        connector = _build_connector()
        start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
        end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()

        results, early_exit = connector._process_items(
            response=[_mock_announcement(30, "News", 1, "2025-06-15T12:00:00Z")],
            stage=CanvasStage.ANNOUNCEMENTS,
            course_id=1,
            start=start,
            end=end,
            include_permissions=False,
        )

        assert len(results) == 1
        assert isinstance(results[0], Document)
        assert early_exit is False

    def test_bad_item_yields_connector_failure(self) -> None:
        """An item with unparsable data becomes a ConnectorFailure, not a crash."""
        connector = _build_connector()
        start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
        end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()
        bad_page = {
            "page_id": 10,
            "url": "test",
            "title": "Test",
            "body": None,
            "created_at": "2025-06-15T12:00:00Z",
            "updated_at": "bad-date",
        }

        results, early_exit = connector._process_items(
            response=[bad_page],
            stage=CanvasStage.PAGES,
            course_id=1,
            start=start,
            end=end,
            include_permissions=False,
        )

        assert len(results) == 1
        assert isinstance(results[0], ConnectorFailure)

    def test_page_early_exit_on_old_item(self) -> None:
        """Pages sorted desc — item before start triggers early exit."""
        connector = _build_connector()
        start = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
        end = datetime(2025, 6, 30, tzinfo=timezone.utc).timestamp()

        results, early_exit = connector._process_items(
            response=[
                _mock_page(10, "New", "2025-06-15T12:00:00Z"),
                _mock_page(11, "Old", "2025-05-01T12:00:00Z"),
                _mock_page(12, "Older", "2025-04-01T12:00:00Z"),
            ],
            stage=CanvasStage.PAGES,
            course_id=1,
            start=start,
            end=end,
            include_permissions=False,
        )

        # Only the in-window page is converted; the first too-old page stops
        # the scan before the third item is even considered.
        assert len(results) == 1
        assert early_exit is True
||||
|
||||
|
||||
class TestMaybeAttachPermissions:
    """Tests for CanvasConnector._maybe_attach_permissions gating."""

    def test_attaches_permissions_when_true(self) -> None:
        connector = _build_connector()
        document = MagicMock(spec=Document)
        document.external_access = None
        granted = ExternalAccess(
            external_user_emails={"student@school.edu"},
            external_user_group_ids=set(),
            is_public=False,
        )

        with patch.object(
            connector, "_get_course_permissions", return_value=granted
        ):
            updated = connector._maybe_attach_permissions(
                document, course_id=1, include_permissions=True
            )

        assert updated.external_access == granted

    def test_no_op_when_false(self) -> None:
        connector = _build_connector()
        document = MagicMock(spec=Document)
        document.external_access = None

        # With include_permissions=False the document passes through untouched.
        updated = connector._maybe_attach_permissions(
            document, course_id=1, include_permissions=False
        )

        assert updated.external_access is None
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Tests verifying that GithubConnector implements SlimConnector and SlimConnectorWithPermSync
|
||||
correctly, and that pruning uses the cheap slim path (no lazy loading).
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import SlimDocument
|
||||
|
||||
|
||||
def _make_pr(html_url: str) -> MagicMock:
    """Build a mock PR whose lazy-loaded attributes raise on access.

    Slim retrieval must never touch ``commits`` or ``changed_files``; reading
    either property fails the test immediately.
    """
    mock_pr = MagicMock()
    mock_pr.html_url = html_url
    mock_pr.pull_request = None
    type(mock_pr).commits = PropertyMock(
        side_effect=AssertionError("lazy load triggered")
    )
    type(mock_pr).changed_files = PropertyMock(
        side_effect=AssertionError("lazy load triggered")
    )
    return mock_pr
||||
|
||||
|
||||
def _make_issue(html_url: str) -> MagicMock:
    """Build a mock GitHub issue (pull_request=None marks it as a real issue)."""
    mock_issue = MagicMock()
    mock_issue.html_url = html_url
    mock_issue.pull_request = None
    return mock_issue
||||
|
||||
|
||||
def _make_connector(include_issues: bool = False) -> GithubConnector:
    """Create a GithubConnector for test-org/test-repo with a mocked client."""
    gh_connector = GithubConnector(
        repo_owner="test-org",
        repositories="test-repo",
        include_prs=True,
        include_issues=include_issues,
    )
    gh_connector.github_client = MagicMock()
    return gh_connector
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def patch_deserialize_repository(mock_repo: MagicMock) -> Generator[None, None, None]:
    """Route repository deserialization to the shared mock repo for every test."""
    patcher = patch(
        "onyx.connectors.github.connector.deserialize_repository",
        return_value=mock_repo,
    )
    with patcher:
        yield
||||
|
||||
|
||||
@pytest.fixture
def mock_repo() -> MagicMock:
    """A mocked repository exposing a single page containing three PRs."""
    repo = MagicMock()
    repo.name = "test-repo"
    repo.id = 123
    repo.raw_headers = {"x-github-request-id": "test"}
    repo.raw_data = {"id": 123, "name": "test-repo", "full_name": "test-org/test-repo"}

    pull_requests = [
        _make_pr(f"https://github.com/test-org/test-repo/pull/{num}")
        for num in range(1, 4)
    ]
    paginated = MagicMock()
    # Page 0 holds all PRs; any later page is empty, ending pagination.
    paginated.get_page.side_effect = (
        lambda page_num: pull_requests if page_num == 0 else []
    )
    repo.get_pulls.return_value = paginated
    return repo
||||
|
||||
|
||||
def test_github_connector_implements_slim_connector() -> None:
    """GithubConnector must satisfy the SlimConnector protocol."""
    assert isinstance(_make_connector(), SlimConnector)
||||
|
||||
|
||||
def test_github_connector_implements_slim_connector_with_perm_sync() -> None:
    """GithubConnector must satisfy the SlimConnectorWithPermSync protocol."""
    assert isinstance(_make_connector(), SlimConnectorWithPermSync)
||||
|
||||
|
||||
def test_retrieve_all_slim_docs_returns_pr_urls(mock_repo: MagicMock) -> None:
    """Slim retrieval yields one SlimDocument per PR, keyed by its html_url."""
    connector = _make_connector()
    with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
        batches = list(connector.retrieve_all_slim_docs())

    docs = [doc for batch in batches for doc in batch]
    assert len(docs) == 3
    assert all(isinstance(doc, SlimDocument) for doc in docs)
    expected_ids = {
        f"https://github.com/test-org/test-repo/pull/{num}" for num in (1, 2, 3)
    }
    assert {doc.id for doc in docs if isinstance(doc, SlimDocument)} == expected_ids
||||
|
||||
|
||||
def test_retrieve_all_slim_docs_has_no_external_access(mock_repo: MagicMock) -> None:
    """Pruning does not need permissions — external_access should be None.

    Also asserts the expected doc count: without it, an empty result would
    make the ``all(...)`` check pass vacuously and hide a broken retrieval.
    """
    connector = _make_connector()
    with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
        batches = list(connector.retrieve_all_slim_docs())

    all_docs = [doc for batch in batches for doc in batch]
    # Guard against a vacuous pass on an empty list (mock_repo exposes 3 PRs).
    assert len(all_docs) == 3
    assert all(doc.external_access is None for doc in all_docs)
||||
|
||||
|
||||
def test_retrieve_all_slim_docs_perm_sync_populates_external_access(
    mock_repo: MagicMock,
) -> None:
    """Perm-sync slim retrieval attaches the repo's ExternalAccess to every doc.

    Asserts a non-empty result first: without it, an empty result would make
    the ``all(...)`` check pass vacuously and hide a broken retrieval.
    """
    connector = _make_connector()
    mock_access = MagicMock(spec=ExternalAccess)

    with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
        with patch(
            "onyx.connectors.github.connector.get_external_access_permission",
            return_value=mock_access,
        ) as mock_perm:
            batches = list(connector.retrieve_all_slim_docs_perm_sync())

    # permission fetched at least once per repo (once per page in checkpoint-based flow)
    mock_perm.assert_called_with(mock_repo, connector.github_client)

    all_docs = [doc for batch in batches for doc in batch]
    # Guard against a vacuous pass on an empty list (mock_repo exposes 3 PRs).
    assert len(all_docs) == 3
    assert all(doc.external_access is mock_access for doc in all_docs)
||||
|
||||
|
||||
def test_retrieve_all_slim_docs_skips_pr_issues(mock_repo: MagicMock) -> None:
    """Issues that are actually PRs should be skipped when include_issues=True."""
    connector = _make_connector(include_issues=True)

    disguised_pr = MagicMock()
    disguised_pr.html_url = "https://github.com/test-org/test-repo/pull/99"
    disguised_pr.pull_request = MagicMock()  # non-None means it's a PR

    genuine_issue = _make_issue("https://github.com/test-org/test-repo/issues/1")
    issue_page = [disguised_pr, genuine_issue]
    issues_paginated = MagicMock()
    issues_paginated.get_page.side_effect = (
        lambda page_num: issue_page if page_num == 0 else []
    )
    mock_repo.get_issues.return_value = issues_paginated

    with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
        batches = list(connector.retrieve_all_slim_docs())

    issue_ids = {
        doc.id
        for batch in batches
        for doc in batch
        if isinstance(doc, SlimDocument) and "issues" in doc.id
    }
    assert issue_ids == {"https://github.com/test-org/test-repo/issues/1"}
||||
|
||||
|
||||
def test_pruning_routes_to_slim_connector_path(mock_repo: MagicMock) -> None:
    """extract_ids_from_runnable_connector must use SlimConnector, not CheckpointedConnector."""
    connector = _make_connector()

    with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
        # If the CheckpointedConnector fallback were used instead, it would call
        # load_from_checkpoint which hits _convert_pr_to_document and lazy loads.
        # We verify the slim path is taken by checking load_from_checkpoint is NOT called.
        with patch.object(connector, "load_from_checkpoint") as checkpoint_spy:
            id_result = extract_ids_from_runnable_connector(connector)
            checkpoint_spy.assert_not_called()

    assert len(id_result.raw_id_to_parent) == 3
    assert "https://github.com/test-org/test-repo/pull/1" in id_result.raw_id_to_parent
||||
@@ -0,0 +1,200 @@
|
||||
"""Unit tests for GoogleDriveConnector slim retrieval routing.
|
||||
|
||||
Verifies that:
|
||||
- GoogleDriveConnector implements SlimConnector so pruning takes the ID-only path
|
||||
- retrieve_all_slim_docs() calls _extract_slim_docs_from_google_drive with include_permissions=False
|
||||
- retrieve_all_slim_docs_perm_sync() calls _extract_slim_docs_from_google_drive with include_permissions=True
|
||||
- celery_utils routing picks retrieve_all_slim_docs() for GoogleDriveConnector
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_drive.models import DriveRetrievalStage
|
||||
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.utils.threadpool_concurrency import ThreadSafeDict
|
||||
|
||||
|
||||
def _make_done_checkpoint() -> GoogleDriveCheckpoint:
    """A checkpoint already at DONE, so slim retrieval exits immediately."""
    return GoogleDriveCheckpoint(
        completion_stage=DriveRetrievalStage.DONE,
        retrieved_folder_and_drive_ids=set(),
        completion_map=ThreadSafeDict(),
        all_retrieved_file_ids=set(),
        has_more=False,
    )
||||
|
||||
|
||||
def _make_connector() -> GoogleDriveConnector:
    """A GoogleDriveConnector with mocked credentials and admin email set."""
    drive_connector = GoogleDriveConnector(include_my_drives=True)
    drive_connector._creds = MagicMock()
    drive_connector._primary_admin_email = "admin@example.com"
    return drive_connector
||||
|
||||
|
||||
class TestGoogleDriveSlimConnectorInterface:
    """Interface / MRO checks that drive pruning routing."""

    def test_implements_slim_connector(self) -> None:
        assert isinstance(_make_connector(), SlimConnector)

    def test_implements_slim_connector_with_perm_sync(self) -> None:
        assert isinstance(_make_connector(), SlimConnectorWithPermSync)

    def test_slim_connector_checked_before_perm_sync(self) -> None:
        """SlimConnector must appear before SlimConnectorWithPermSync in MRO
        so celery_utils isinstance check routes to retrieve_all_slim_docs()."""
        resolution_order = GoogleDriveConnector.__mro__
        assert resolution_order.index(SlimConnector) < resolution_order.index(
            SlimConnectorWithPermSync
        )
||||
|
||||
|
||||
class TestRetrieveAllSlimDocs:
    """Tests for GoogleDriveConnector.retrieve_all_slim_docs: the loop is
    driven by the checkpoint's completion_stage, and the non-perm-sync path
    must pass include_permissions=False to the extraction helper."""

    def test_does_not_call_extract_when_checkpoint_is_done(self) -> None:
        """A DONE checkpoint means the retrieval loop never runs."""
        connector = _make_connector()
        slim_doc = MagicMock(
            spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
        )

        with patch.object(
            connector, "build_dummy_checkpoint", return_value=_make_done_checkpoint()
        ):
            with patch.object(
                connector,
                "_extract_slim_docs_from_google_drive",
                return_value=iter([[slim_doc]]),
            ) as mock_extract:
                list(connector.retrieve_all_slim_docs())

        mock_extract.assert_not_called()  # loop exits immediately since checkpoint is DONE

    def test_calls_extract_with_include_permissions_false_non_done_checkpoint(
        self,
    ) -> None:
        """A non-DONE checkpoint triggers extraction with include_permissions=False."""
        connector = _make_connector()
        slim_doc = MagicMock(
            spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
        )
        # Checkpoint starts at START, _extract advances it to DONE
        with patch.object(connector, "build_dummy_checkpoint") as mock_build:
            start_checkpoint = GoogleDriveCheckpoint(
                retrieved_folder_and_drive_ids=set(),
                completion_stage=DriveRetrievalStage.START,
                completion_map=ThreadSafeDict(),
                all_retrieved_file_ids=set(),
                has_more=False,
            )
            mock_build.return_value = start_checkpoint

            # Side-effect generator: marks the checkpoint DONE so the outer
            # retrieval loop terminates after one iteration.
            def _advance_checkpoint(**_kwargs: object) -> object:
                start_checkpoint.completion_stage = DriveRetrievalStage.DONE
                yield [slim_doc]

            with patch.object(
                connector,
                "_extract_slim_docs_from_google_drive",
                side_effect=_advance_checkpoint,
            ) as mock_extract:
                list(connector.retrieve_all_slim_docs())

            mock_extract.assert_called_once()
            _, kwargs = mock_extract.call_args
            assert kwargs.get("include_permissions") is False

    def test_yields_slim_documents(self) -> None:
        """Batches produced by the extraction helper are passed through."""
        connector = _make_connector()
        slim_doc = MagicMock(
            spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
        )
        start_checkpoint = GoogleDriveCheckpoint(
            retrieved_folder_and_drive_ids=set(),
            completion_stage=DriveRetrievalStage.START,
            completion_map=ThreadSafeDict(),
            all_retrieved_file_ids=set(),
            has_more=False,
        )

        with patch.object(
            connector, "build_dummy_checkpoint", return_value=start_checkpoint
        ):

            # Marks the checkpoint DONE so the retrieval loop stops after
            # yielding one batch.
            def _advance_and_yield(**_kwargs: object) -> object:
                start_checkpoint.completion_stage = DriveRetrievalStage.DONE
                yield [slim_doc]

            with patch.object(
                connector,
                "_extract_slim_docs_from_google_drive",
                side_effect=_advance_and_yield,
            ):
                batches = list(connector.retrieve_all_slim_docs())

        assert len(batches) == 1
        assert batches[0][0] is slim_doc
||||
|
||||
|
||||
class TestRetrieveAllSlimDocsPermSync:
|
||||
def test_calls_extract_with_include_permissions_true(self) -> None:
|
||||
connector = _make_connector()
|
||||
slim_doc = MagicMock(
|
||||
spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
|
||||
)
|
||||
start_checkpoint = GoogleDriveCheckpoint(
|
||||
retrieved_folder_and_drive_ids=set(),
|
||||
completion_stage=DriveRetrievalStage.START,
|
||||
completion_map=ThreadSafeDict(),
|
||||
all_retrieved_file_ids=set(),
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
connector, "build_dummy_checkpoint", return_value=start_checkpoint
|
||||
):
|
||||
|
||||
def _advance_and_yield(**_kwargs: object) -> object:
|
||||
start_checkpoint.completion_stage = DriveRetrievalStage.DONE
|
||||
yield [slim_doc]
|
||||
|
||||
with patch.object(
|
||||
connector,
|
||||
"_extract_slim_docs_from_google_drive",
|
||||
side_effect=_advance_and_yield,
|
||||
) as mock_extract:
|
||||
list(connector.retrieve_all_slim_docs_perm_sync())
|
||||
|
||||
mock_extract.assert_called_once()
|
||||
_, kwargs = mock_extract.call_args
|
||||
assert (
|
||||
kwargs.get("include_permissions") is None
|
||||
or kwargs.get("include_permissions") is True
|
||||
)
|
||||
|
||||
|
||||
class TestCeleryUtilsRouting:
|
||||
def test_pruning_uses_retrieve_all_slim_docs(self) -> None:
|
||||
"""extract_ids_from_runnable_connector must call retrieve_all_slim_docs,
|
||||
not retrieve_all_slim_docs_perm_sync, for GoogleDriveConnector."""
|
||||
connector = _make_connector()
|
||||
slim_doc = MagicMock(
|
||||
spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
|
||||
)
|
||||
with (
|
||||
patch.object(
|
||||
connector, "retrieve_all_slim_docs", return_value=iter([[slim_doc]])
|
||||
) as mock_slim,
|
||||
patch.object(
|
||||
connector, "retrieve_all_slim_docs_perm_sync"
|
||||
) as mock_perm_sync,
|
||||
):
|
||||
extract_ids_from_runnable_connector(
|
||||
connector, connector_type="google_drive"
|
||||
)
|
||||
|
||||
mock_slim.assert_called_once()
|
||||
mock_perm_sync.assert_not_called()
|
||||
@@ -0,0 +1,182 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from onyx.connectors.google_utils.google_kv import get_auth_url
|
||||
from onyx.connectors.google_utils.google_kv import get_google_app_cred
|
||||
from onyx.connectors.google_utils.google_kv import get_service_account_key
|
||||
from onyx.connectors.google_utils.google_kv import upsert_google_app_cred
|
||||
from onyx.connectors.google_utils.google_kv import upsert_service_account_key
|
||||
from onyx.server.documents.models import GoogleAppCredentials
|
||||
from onyx.server.documents.models import GoogleAppWebCredentials
|
||||
from onyx.server.documents.models import GoogleServiceAccountKey
|
||||
|
||||
|
||||
def _make_app_creds() -> GoogleAppCredentials:
|
||||
return GoogleAppCredentials(
|
||||
web=GoogleAppWebCredentials(
|
||||
client_id="client-id.apps.googleusercontent.com",
|
||||
project_id="test-project",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_secret="secret",
|
||||
redirect_uris=["https://example.com/callback"],
|
||||
javascript_origins=["https://example.com"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _make_service_account_key() -> GoogleServiceAccountKey:
|
||||
return GoogleServiceAccountKey(
|
||||
type="service_account",
|
||||
project_id="test-project",
|
||||
private_key_id="private-key-id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
|
||||
client_email="test@test-project.iam.gserviceaccount.com",
|
||||
client_id="123",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test",
|
||||
universe_domain="googleapis.com",
|
||||
)
|
||||
|
||||
|
||||
def test_upsert_google_app_cred_stores_dict(monkeypatch: Any) -> None:
|
||||
stored: dict[str, Any] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored["key"] = key
|
||||
stored["value"] = value
|
||||
stored["encrypt"] = encrypt
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
upsert_google_app_cred(_make_app_creds(), DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert stored["key"] == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
assert stored["encrypt"] is True
|
||||
assert isinstance(stored["value"], dict)
|
||||
assert stored["value"]["web"]["client_id"] == "client-id.apps.googleusercontent.com"
|
||||
|
||||
|
||||
def test_upsert_service_account_key_stores_dict(monkeypatch: Any) -> None:
|
||||
stored: dict[str, Any] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored["key"] = key
|
||||
stored["value"] = value
|
||||
stored["encrypt"] = encrypt
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
upsert_service_account_key(_make_service_account_key(), DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert stored["key"] == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
assert stored["encrypt"] is True
|
||||
assert isinstance(stored["value"], dict)
|
||||
assert stored["value"]["project_id"] == "test-project"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_google_app_cred_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
payload: dict[str, Any] = _make_app_creds().model_dump(mode="json")
|
||||
stored_value: object = (
|
||||
payload if not legacy_string else _make_app_creds().model_dump_json()
|
||||
)
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
return stored_value
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
creds = get_google_app_cred(DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert creds.web.client_id == "client-id.apps.googleusercontent.com"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_service_account_key_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
stored_value: object = (
|
||||
_make_service_account_key().model_dump(mode="json")
|
||||
if not legacy_string
|
||||
else _make_service_account_key().model_dump_json()
|
||||
)
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
return stored_value
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
key = get_service_account_key(DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert key.client_email == "test@test-project.iam.gserviceaccount.com"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_auth_url_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
payload = _make_app_creds().model_dump(mode="json")
|
||||
stored_value: object = (
|
||||
payload if not legacy_string else _make_app_creds().model_dump_json()
|
||||
)
|
||||
stored_state: dict[str, object] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
return stored_value
|
||||
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored_state["key"] = key
|
||||
stored_state["value"] = value
|
||||
stored_state["encrypt"] = encrypt
|
||||
|
||||
class _StubFlow:
|
||||
def authorization_url(self, prompt: str) -> tuple[str, None]:
|
||||
assert prompt == "consent"
|
||||
return "https://accounts.google.com/o/oauth2/auth?state=test-state", None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
def _from_client_config(
|
||||
_app_config: object, *, scopes: object, redirect_uri: object
|
||||
) -> _StubFlow:
|
||||
del scopes, redirect_uri
|
||||
return _StubFlow()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.InstalledAppFlow.from_client_config",
|
||||
_from_client_config,
|
||||
)
|
||||
|
||||
auth_url = get_auth_url(42, DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert auth_url.startswith("https://accounts.google.com")
|
||||
assert stored_state["value"] == {"value": "test-state"}
|
||||
assert stored_state["encrypt"] is True
|
||||
86
backend/tests/unit/onyx/db/test_index_attempt_errors.py
Normal file
86
backend/tests/unit/onyx/db/test_index_attempt_errors.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Tests for get_index_attempt_errors_across_connectors."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_across_connectors
|
||||
from onyx.db.models import IndexAttemptError
|
||||
|
||||
|
||||
def _make_error(
|
||||
id: int = 1,
|
||||
cc_pair_id: int = 1,
|
||||
error_type: str | None = "TimeoutError",
|
||||
is_resolved: bool = False,
|
||||
) -> IndexAttemptError:
|
||||
"""Create a mock IndexAttemptError."""
|
||||
error = MagicMock(spec=IndexAttemptError)
|
||||
error.id = id
|
||||
error.connector_credential_pair_id = cc_pair_id
|
||||
error.error_type = error_type
|
||||
error.is_resolved = is_resolved
|
||||
return error
|
||||
|
||||
|
||||
class TestGetIndexAttemptErrorsAcrossConnectors:
|
||||
def test_returns_errors_and_count(self) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_errors = [_make_error(id=1), _make_error(id=2)]
|
||||
mock_session.scalar.return_value = 2
|
||||
mock_session.scalars.return_value.all.return_value = mock_errors
|
||||
|
||||
errors, total = get_index_attempt_errors_across_connectors(
|
||||
db_session=mock_session,
|
||||
)
|
||||
|
||||
assert total == 2
|
||||
assert len(errors) == 2
|
||||
|
||||
def test_returns_empty_when_no_errors(self) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = 0
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
errors, total = get_index_attempt_errors_across_connectors(
|
||||
db_session=mock_session,
|
||||
)
|
||||
|
||||
assert total == 0
|
||||
assert errors == []
|
||||
|
||||
def test_null_count_returns_zero(self) -> None:
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = None
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
errors, total = get_index_attempt_errors_across_connectors(
|
||||
db_session=mock_session,
|
||||
)
|
||||
|
||||
assert total == 0
|
||||
|
||||
def test_passes_filters_to_query(self) -> None:
|
||||
"""Verify that filter parameters result in .where() calls on the statement."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = 0
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
|
||||
end = datetime(2026, 12, 31, tzinfo=timezone.utc)
|
||||
|
||||
# Should not raise — just verifying the function accepts all filter params
|
||||
get_index_attempt_errors_across_connectors(
|
||||
db_session=mock_session,
|
||||
cc_pair_id=42,
|
||||
error_type="TimeoutError",
|
||||
start_time=start,
|
||||
end_time=end,
|
||||
unresolved_only=True,
|
||||
page=2,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# The function should have called scalar (for count) and scalars (for results)
|
||||
assert mock_session.scalar.called
|
||||
assert mock_session.scalars.called
|
||||
@@ -1,9 +1,13 @@
|
||||
import io
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
|
||||
from onyx.file_processing.extract_file_text import _clean_worksheet_matrix
|
||||
from onyx.file_processing.extract_file_text import _worksheet_to_matrix
|
||||
from onyx.file_processing.extract_file_text import xlsx_sheet_extraction
|
||||
from onyx.file_processing.extract_file_text import xlsx_to_text
|
||||
|
||||
|
||||
@@ -196,3 +200,182 @@ class TestXlsxToText:
|
||||
assert "r1c1" in lines[0] and "r1c2" in lines[0]
|
||||
assert "r2c1" in lines[1] and "r2c2" in lines[1]
|
||||
assert "r3c1" in lines[2] and "r3c2" in lines[2]
|
||||
|
||||
|
||||
class TestWorksheetToMatrixJaggedRows:
|
||||
"""openpyxl read_only mode can yield rows of differing widths when
|
||||
trailing cells are empty. The matrix must be padded to a rectangle
|
||||
so downstream column cleanup can index safely."""
|
||||
|
||||
def test_pads_shorter_trailing_rows(self) -> None:
|
||||
ws = MagicMock()
|
||||
ws.iter_rows.return_value = iter(
|
||||
[
|
||||
("A", "B", "C"),
|
||||
("X", "Y"),
|
||||
("P",),
|
||||
]
|
||||
)
|
||||
matrix = _worksheet_to_matrix(ws)
|
||||
assert matrix == [["A", "B", "C"], ["X", "Y", ""], ["P", "", ""]]
|
||||
|
||||
def test_pads_when_first_row_is_shorter(self) -> None:
|
||||
ws = MagicMock()
|
||||
ws.iter_rows.return_value = iter(
|
||||
[
|
||||
("A",),
|
||||
("X", "Y", "Z"),
|
||||
]
|
||||
)
|
||||
matrix = _worksheet_to_matrix(ws)
|
||||
assert matrix == [["A", "", ""], ["X", "Y", "Z"]]
|
||||
|
||||
def test_clean_worksheet_matrix_no_index_error_on_jagged_rows(self) -> None:
|
||||
"""Regression: previously raised IndexError when a later row was
|
||||
shorter than the first row and the out-of-range column on the
|
||||
first row was empty (so the short-circuit in `all()` did not
|
||||
save us)."""
|
||||
ws = MagicMock()
|
||||
ws.iter_rows.return_value = iter(
|
||||
[
|
||||
("A", "", "", "B"),
|
||||
("X", "Y"),
|
||||
]
|
||||
)
|
||||
matrix = _worksheet_to_matrix(ws)
|
||||
# Must not raise.
|
||||
cleaned = _clean_worksheet_matrix(matrix)
|
||||
assert cleaned == [["A", "", "", "B"], ["X", "Y", "", ""]]
|
||||
|
||||
|
||||
class TestXlsxSheetExtraction:
|
||||
def test_one_tuple_per_sheet(self) -> None:
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Revenue": [["Month", "Amount"], ["Jan", "100"]],
|
||||
"Expenses": [["Category", "Cost"], ["Rent", "500"]],
|
||||
}
|
||||
)
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 2
|
||||
# Order preserved from workbook sheet order
|
||||
titles = [title for _csv, title in sheets]
|
||||
assert titles == ["Revenue", "Expenses"]
|
||||
# Content present in the right tuple
|
||||
revenue_csv, _ = sheets[0]
|
||||
expenses_csv, _ = sheets[1]
|
||||
assert "Month" in revenue_csv
|
||||
assert "Jan" in revenue_csv
|
||||
assert "Category" in expenses_csv
|
||||
assert "Rent" in expenses_csv
|
||||
|
||||
def test_tuple_structure_is_csv_text_then_title(self) -> None:
|
||||
"""The tuple order is (csv_text, sheet_title) — pin it so callers
|
||||
that unpack positionally don't silently break."""
|
||||
xlsx = _make_xlsx({"MySheet": [["a", "b"]]})
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 1
|
||||
csv_text, title = sheets[0]
|
||||
assert title == "MySheet"
|
||||
assert "a" in csv_text
|
||||
assert "b" in csv_text
|
||||
|
||||
def test_empty_sheet_is_skipped(self) -> None:
|
||||
"""A sheet whose CSV output is empty/whitespace-only should NOT
|
||||
appear in the result — the `if csv_text.strip():` guard filters
|
||||
it out."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Data": [["a", "b"]],
|
||||
"Empty": [],
|
||||
}
|
||||
)
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 1
|
||||
assert sheets[0][1] == "Data"
|
||||
|
||||
def test_empty_workbook_returns_empty_list(self) -> None:
|
||||
"""All sheets empty → empty list (not a list of empty tuples)."""
|
||||
xlsx = _make_xlsx({"Sheet1": [], "Sheet2": []})
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert sheets == []
|
||||
|
||||
def test_single_sheet(self) -> None:
|
||||
xlsx = _make_xlsx({"Only": [["x", "y"], ["1", "2"]]})
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 1
|
||||
csv_text, title = sheets[0]
|
||||
assert title == "Only"
|
||||
assert "x" in csv_text
|
||||
assert "1" in csv_text
|
||||
|
||||
def test_bad_zip_returns_empty_list(self) -> None:
|
||||
bad_file = io.BytesIO(b"not a zip file")
|
||||
sheets = xlsx_sheet_extraction(bad_file, file_name="test.xlsx")
|
||||
assert sheets == []
|
||||
|
||||
def test_bad_zip_tilde_file_returns_empty_list(self) -> None:
|
||||
"""`~$`-prefixed files are Excel lock files; failure should log
|
||||
at debug (not warning) and still return []."""
|
||||
bad_file = io.BytesIO(b"not a zip file")
|
||||
sheets = xlsx_sheet_extraction(bad_file, file_name="~$temp.xlsx")
|
||||
assert sheets == []
|
||||
|
||||
def test_csv_content_matches_xlsx_to_text_per_sheet(self) -> None:
|
||||
"""For a single-sheet workbook, xlsx_to_text output should equal
|
||||
the csv_text from xlsx_sheet_extraction — they share the same
|
||||
per-sheet CSV-ification logic."""
|
||||
single_sheet_data = [["Name", "Age"], ["Alice", "30"]]
|
||||
expected_text = xlsx_to_text(_make_xlsx({"People": single_sheet_data}))
|
||||
|
||||
sheets = xlsx_sheet_extraction(_make_xlsx({"People": single_sheet_data}))
|
||||
assert len(sheets) == 1
|
||||
csv_text, title = sheets[0]
|
||||
assert title == "People"
|
||||
assert csv_text.strip() == expected_text.strip()
|
||||
|
||||
def test_commas_in_cells_are_quoted(self) -> None:
|
||||
xlsx = _make_xlsx({"S1": [["hello, world", "normal"]]})
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 1
|
||||
csv_text, _ = sheets[0]
|
||||
assert '"hello, world"' in csv_text
|
||||
|
||||
def test_long_empty_row_run_capped_within_sheet(self) -> None:
|
||||
"""The matrix cleanup applies per-sheet: >2 empty rows collapse
|
||||
to 2, which keeps the sheet non-empty and it still appears in
|
||||
the result."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"S1": [
|
||||
["header"],
|
||||
[""],
|
||||
[""],
|
||||
[""],
|
||||
[""],
|
||||
["data"],
|
||||
]
|
||||
}
|
||||
)
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 1
|
||||
csv_text, _ = sheets[0]
|
||||
lines = csv_text.strip().split("\n")
|
||||
# header + 2 empty (capped) + data = 4 lines
|
||||
assert len(lines) == 4
|
||||
assert "header" in lines[0]
|
||||
assert "data" in lines[-1]
|
||||
|
||||
def test_sheet_title_with_special_chars_preserved(self) -> None:
|
||||
"""Spaces, punctuation, unicode in sheet titles are preserved
|
||||
verbatim — the title is used as a link anchor downstream."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Q1 Revenue (USD)": [["a", "b"]],
|
||||
"Données": [["c", "d"]],
|
||||
}
|
||||
)
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
titles = [title for _csv, title in sheets]
|
||||
assert "Q1 Revenue (USD)" in titles
|
||||
assert "Données" in titles
|
||||
|
||||
558
backend/tests/unit/onyx/indexing/test_tabular_section_chunker.py
Normal file
558
backend/tests/unit/onyx/indexing/test_tabular_section_chunker.py
Normal file
@@ -0,0 +1,558 @@
|
||||
"""End-to-end tests for `TabularChunker.chunk_section`.
|
||||
|
||||
Each test is structured as:
|
||||
INPUT — the CSV text passed to the chunker + token budget + link
|
||||
EXPECTED — the exact chunk texts the chunker should emit
|
||||
ACT — a single call to `chunk_section`
|
||||
ASSERT — literal equality against the expected chunk texts
|
||||
|
||||
A character-level tokenizer (1 char == 1 token) is used so token-budget
|
||||
arithmetic is deterministic and expected chunks can be spelled out
|
||||
exactly.
|
||||
"""
|
||||
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.tabular_section_chunker import TabularChunker
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
|
||||
|
||||
class CharTokenizer(BaseTokenizer):
|
||||
def encode(self, string: str) -> list[int]:
|
||||
return [ord(c) for c in string]
|
||||
|
||||
def tokenize(self, string: str) -> list[str]:
|
||||
return list(string)
|
||||
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
return "".join(chr(t) for t in tokens)
|
||||
|
||||
|
||||
def _make_chunker() -> TabularChunker:
|
||||
return TabularChunker(tokenizer=CharTokenizer())
|
||||
|
||||
|
||||
_DEFAULT_LINK = "https://example.com/doc"
|
||||
|
||||
|
||||
def _tabular_section(
|
||||
text: str,
|
||||
link: str = _DEFAULT_LINK,
|
||||
heading: str | None = "sheet:Test",
|
||||
) -> Section:
|
||||
return TabularSection(text=text, link=link, heading=heading)
|
||||
|
||||
|
||||
class TestTabularChunkerChunkSection:
|
||||
def test_simple_csv_all_rows_fit_one_chunk(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
csv_text = "Name,Age,City\n" "Alice,30,NYC\n" "Bob,25,SF\n"
|
||||
heading = "sheet:People"
|
||||
content_token_limit = 500
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected_texts = [
|
||||
(
|
||||
"sheet:People\n"
|
||||
"Columns: Name, Age, City\n"
|
||||
"Name=Alice, Age=30, City=NYC\n"
|
||||
"Name=Bob, Age=25, City=SF"
|
||||
),
|
||||
]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
assert [p.is_continuation for p in out.payloads] == [False]
|
||||
assert all(p.links == {0: _DEFAULT_LINK} for p in out.payloads)
|
||||
assert out.accumulator.is_empty()
|
||||
|
||||
def test_overflow_splits_into_two_deterministic_chunks(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# prelude = "sheet:S\nColumns: col, val" (25 chars = 25 tokens)
|
||||
# At content_token_limit=57, row_budget = max(16, 57-31-1) = 25.
|
||||
# Each row "col=a, val=1" is 12 tokens; two rows + \n = 25 (fits),
|
||||
# three rows + 2×\n = 38 (overflows) → split after 2 rows.
|
||||
csv_text = "col,val\n" "a,1\n" "b,2\n" "c,3\n" "d,4\n"
|
||||
heading = "sheet:S"
|
||||
content_token_limit = 57
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected_texts = [
|
||||
("sheet:S\n" "Columns: col, val\n" "col=a, val=1\n" "col=b, val=2"),
|
||||
("sheet:S\n" "Columns: col, val\n" "col=c, val=3\n" "col=d, val=4"),
|
||||
]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
# First chunk is fresh; subsequent chunks mark as continuations.
|
||||
assert [p.is_continuation for p in out.payloads] == [False, True]
|
||||
# Link carries through every chunk.
|
||||
assert all(p.links == {0: _DEFAULT_LINK} for p in out.payloads)
|
||||
|
||||
# Add back in shortly
|
||||
# def test_header_only_csv_produces_single_prelude_chunk(self) -> None:
|
||||
# # --- INPUT -----------------------------------------------------
|
||||
# csv_text = "col1,col2\n"
|
||||
# link = "sheet:Headers"
|
||||
|
||||
# # --- EXPECTED --------------------------------------------------
|
||||
# expected_texts = [
|
||||
# "sheet:Headers\nColumns: col1, col2",
|
||||
# ]
|
||||
|
||||
# # --- ACT -------------------------------------------------------
|
||||
# out = _make_chunker().chunk_section(
|
||||
# _tabular_section(csv_text, link=link),
|
||||
# AccumulatorState(),
|
||||
# content_token_limit=500,
|
||||
# )
|
||||
|
||||
# # --- ASSERT ----------------------------------------------------
|
||||
# assert [p.text for p in out.payloads] == expected_texts
|
||||
|
||||
def test_empty_cells_dropped_from_chunk_text(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Alice's Age is empty; Bob's City is empty. Empty cells should
|
||||
# not appear as `field=` pairs in the output.
|
||||
csv_text = "Name,Age,City\n" "Alice,,NYC\n" "Bob,25,\n"
|
||||
heading = "sheet:P"
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected_texts = [
|
||||
(
|
||||
"sheet:P\n"
|
||||
"Columns: Name, Age, City\n"
|
||||
"Name=Alice, City=NYC\n"
|
||||
"Name=Bob, Age=25"
|
||||
),
|
||||
]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=500,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
|
||||
def test_quoted_commas_in_csv_preserved_as_one_field(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# "Hello, world" is quoted in the CSV, so csv.reader parses it as
|
||||
# a single field. The surrounding quotes are stripped during
|
||||
# decoding, so the chunk text carries the bare value.
|
||||
csv_text = "Name,Notes\n" 'Alice,"Hello, world"\n'
|
||||
heading = "sheet:P"
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected_texts = [
|
||||
("sheet:P\n" "Columns: Name, Notes\n" "Name=Alice, Notes=Hello, world"),
|
||||
]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=500,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
|
||||
def test_blank_rows_in_csv_are_skipped(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Stray blank rows in the CSV (e.g. export artifacts) shouldn't
|
||||
# produce ghost rows in the output.
|
||||
csv_text = "A,B\n" "\n" "1,2\n" "\n" "\n" "3,4\n"
|
||||
heading = "sheet:S"
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected_texts = [
|
||||
("sheet:S\n" "Columns: A, B\n" "A=1, B=2\n" "A=3, B=4"),
|
||||
]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=500,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
|
||||
def test_accumulator_flushes_before_tabular_chunks(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# A text accumulator was populated by the prior text section.
|
||||
# Tabular sections are structural boundaries, so the pending
|
||||
# text is flushed as its own chunk before the tabular content.
|
||||
pending_text = "prior paragraph from an earlier text section"
|
||||
pending_link = "prev-link"
|
||||
|
||||
csv_text = "a,b\n" "1,2\n"
|
||||
heading = "sheet:S"
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected_texts = [
|
||||
pending_text, # flushed accumulator
|
||||
("sheet:S\n" "Columns: a, b\n" "a=1, b=2"),
|
||||
]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(
|
||||
text=pending_text,
|
||||
link_offsets={0: pending_link},
|
||||
),
|
||||
content_token_limit=500,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
# Flushed chunk keeps the prior text's link; tabular chunk uses
|
||||
# the tabular section's link.
|
||||
assert out.payloads[0].links == {0: pending_link}
|
||||
assert out.payloads[1].links == {0: _DEFAULT_LINK}
|
||||
# Accumulator resets — tabular section is a structural boundary.
|
||||
assert out.accumulator.is_empty()
|
||||
|
||||
def test_multi_row_packing_under_budget_emits_single_chunk(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Three small rows (20 tokens each) under a generous
|
||||
# content_token_limit=100 should pack into ONE chunk — prelude
|
||||
# emitted once, rows stacked beneath it.
|
||||
csv_text = (
|
||||
"x\n" "aaaaaaaaaaaaaaaaaa\n" "bbbbbbbbbbbbbbbbbb\n" "cccccccccccccccccc\n"
|
||||
)
|
||||
heading = "S"
|
||||
content_token_limit = 100
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
# Each formatted row "x=<18-char value>" = 20 tokens.
|
||||
# Full chunk with sheet + Columns + 3 rows =
|
||||
# 1 + 1 + 10 + 1 + (20 + 1 + 20 + 1 + 20) = 75 tokens ≤ 100.
|
||||
# Single chunk carries all three rows.
|
||||
expected_texts = [
|
||||
"S\n"
|
||||
"Columns: x\n"
|
||||
"x=aaaaaaaaaaaaaaaaaa\n"
|
||||
"x=bbbbbbbbbbbbbbbbbb\n"
|
||||
"x=cccccccccccccccccc"
|
||||
]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
assert [p.is_continuation for p in out.payloads] == [False]
|
||||
assert all(len(p.text) <= content_token_limit for p in out.payloads)
|
||||
|
||||
def test_packing_reserves_prelude_budget_so_every_chunk_has_full_prelude(
|
||||
self,
|
||||
) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Budget (30) is large enough for all 5 bare rows (row_block =
|
||||
# 24 tokens) to pack as one chunk if the prelude were optional,
|
||||
# but [sheet] + Columns + 5_rows would be 41 tokens > 30. The
|
||||
# packing logic reserves space for the prelude: only 2 rows
|
||||
# pack per chunk (17 prelude overhead + 9 rows = 26 ≤ 30).
|
||||
# Every emitted chunk therefore carries its full prelude rather
|
||||
# than dropping Columns at emit time.
|
||||
csv_text = "x\n" "aa\n" "bb\n" "cc\n" "dd\n" "ee\n"
|
||||
heading = "S"
|
||||
content_token_limit = 30
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
# Prelude overhead = 'S\nColumns: x\n' = 1+1+10+1 = 13.
|
||||
# Each row "x=XX" = 4 tokens, row separator "\n" = 1.
|
||||
# 3 rows: 13 + (4+1+4+1+4) = 27 ≤ 30 ✓
|
||||
# 4 rows: 13 + (4+1+4+1+4+1+4) = 32 > 30 ✗
|
||||
# → 3 rows in the first chunk, 2 rows in the second.
|
||||
expected_texts = [
|
||||
"S\nColumns: x\nx=aa\nx=bb\nx=cc",
|
||||
"S\nColumns: x\nx=dd\nx=ee",
|
||||
]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
# Every chunk fits under the budget AND carries its full
|
||||
# prelude — that's the whole point of this check.
|
||||
assert all(len(p.text) <= content_token_limit for p in out.payloads)
|
||||
assert all("Columns: x" in p.text for p in out.payloads)
|
||||
|
||||
def test_oversized_row_splits_into_field_pieces_no_prelude(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Single-row CSV whose formatted form ("field 1=1, ..." = 53
|
||||
# tokens) exceeds content_token_limit (20). Per the chunker's
|
||||
# rules, oversized rows are split at field boundaries into
|
||||
# pieces each ≤ max_tokens, and no prelude is added to split
|
||||
# pieces (they already consume the full budget). A 53-token row
|
||||
# packs into 3 field-boundary pieces under a 20-token budget.
|
||||
csv_text = "field 1,field 2,field 3,field 4,field 5\n" "1,2,3,4,5\n"
|
||||
heading = "S"
|
||||
content_token_limit = 20
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
# Row = "field 1=1, field 2=2, field 3=3, field 4=4, field 5=5"
|
||||
# Fields @ 9 tokens each, ", " sep = 2 tokens.
|
||||
# "field 1=1, field 2=2" = 9+2+9 = 20 tokens ≤ 20 ✓
|
||||
# + ", field 3=3" = 20+2+9 = 31 > 20 → flush, start new
|
||||
# "field 3=3, field 4=4" = 9+2+9 = 20 ≤ 20 ✓
|
||||
# + ", field 5=5" = 20+2+9 = 31 > 20 → flush, start new
|
||||
# "field 5=5" = 9 ≤ 20 ✓
|
||||
# ceil(53 / 20) = 3 chunks.
|
||||
expected_texts = [
|
||||
"field 1=1, field 2=2",
|
||||
"field 3=3, field 4=4",
|
||||
"field 5=5",
|
||||
]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
# Invariant: no chunk exceeds max_tokens.
|
||||
assert all(len(p.text) <= content_token_limit for p in out.payloads)
|
||||
# is_continuation: first chunk False, rest True.
|
||||
assert [p.is_continuation for p in out.payloads] == [False, True, True]
|
||||
|
||||
def test_empty_tabular_section_flushes_accumulator_and_resets_it(
|
||||
self,
|
||||
) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Tabular sections are structural boundaries, so any pending text
|
||||
# buffer is flushed to a chunk before parsing the tabular content
|
||||
# — even if the tabular section itself is empty. The accumulator
|
||||
# is then reset.
|
||||
pending_text = "prior paragraph"
|
||||
pending_link_offsets = {0: "prev-link"}
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected_texts = [pending_text]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section("", heading="sheet:Empty"),
|
||||
AccumulatorState(
|
||||
text=pending_text,
|
||||
link_offsets=pending_link_offsets,
|
||||
),
|
||||
content_token_limit=500,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
assert out.accumulator.is_empty()
|
||||
|
||||
def test_single_oversized_field_token_splits_at_id_boundaries(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# A single `field=value` pair that itself exceeds max_tokens can't
|
||||
# be split at field boundaries — there's only one field. The
|
||||
# chunker falls back to encoding the pair to token ids and
|
||||
# slicing at max-token-sized windows.
|
||||
#
|
||||
# CSV has one column "x" with a 50-char value. Formatted pair =
|
||||
# "x=" + 50 a's = 52 tokens. Budget = 10.
|
||||
csv_text = "x\n" + ("a" * 50) + "\n"
|
||||
heading = "S"
|
||||
content_token_limit = 10
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
# 52-char pair at 10 tokens per window = 6 pieces:
|
||||
# [0:10) "x=aaaaaaaa" (10)
|
||||
# [10:20) "aaaaaaaaaa" (10)
|
||||
# [20:30) "aaaaaaaaaa" (10)
|
||||
# [30:40) "aaaaaaaaaa" (10)
|
||||
# [40:50) "aaaaaaaaaa" (10)
|
||||
# [50:52) "aa" (2)
|
||||
# Split pieces carry no prelude (they already consume the budget).
|
||||
expected_texts = [
|
||||
"x=aaaaaaaa",
|
||||
"aaaaaaaaaa",
|
||||
"aaaaaaaaaa",
|
||||
"aaaaaaaaaa",
|
||||
"aaaaaaaaaa",
|
||||
"aa",
|
||||
]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
# Every piece is ≤ max_tokens — the invariant the token-level
|
||||
# fallback exists to enforce.
|
||||
assert all(len(p.text) <= content_token_limit for p in out.payloads)
|
||||
|
||||
def test_underscored_column_gets_friendly_alias_in_parens(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Column headers with underscores get a space-substituted friendly
|
||||
# alias appended in parens on the `Columns:` line. Plain headers
|
||||
# pass through untouched.
|
||||
csv_text = "MTTR_hours,id,owner_name\n" "3,42,Alice\n"
|
||||
heading = "sheet:M"
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected_texts = [
|
||||
(
|
||||
"sheet:M\n"
|
||||
"Columns: MTTR_hours (MTTR hours), id, owner_name (owner name)\n"
|
||||
"MTTR_hours=3, id=42, owner_name=Alice"
|
||||
),
|
||||
]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=500,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
|
||||
def test_oversized_row_between_small_rows_preserves_flanking_chunks(
|
||||
self,
|
||||
) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# State-machine check: small row, oversized row, small row. The
|
||||
# first small row should become a preluded chunk; the oversized
|
||||
# row flushes it and emits split fragments without prelude; then
|
||||
# the last small row picks up from wherever the split left off.
|
||||
#
|
||||
# Headers a,b,c,d. Row 1 and row 3 each have only column `a`
|
||||
# populated (tiny). Row 2 is a "fat" row with all four columns
|
||||
# populated.
|
||||
csv_text = "a,b,c,d\n" "1,,,\n" "xxx,yyy,zzz,www\n" "2,,,\n"
|
||||
heading = "S"
|
||||
content_token_limit = 20
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
# Prelude = 'S\nColumns: a, b, c, d\n' = 1+1+19+1 = 22 > 20, so
|
||||
# sheet fits with the row but full Columns header does not.
|
||||
# Row 1 formatted = "a=1" (3). build_chunk_from_scratch:
|
||||
# cols+row = 20+3 = 23 > 20 → skip cols. sheet+row = 1+1+3 = 5
|
||||
# ≤ 20 → chunk = "S\na=1".
|
||||
# Row 2 formatted = "a=xxx, b=yyy, c=zzz, d=www" (26 > 20) →
|
||||
# flush "S\na=1" and split at pair boundaries:
|
||||
# "a=xxx, b=yyy, c=zzz" (19 ≤ 20 ✓)
|
||||
# "d=www" (5)
|
||||
# Row 3 formatted = "a=2" (3). can_pack onto "d=www" (5):
|
||||
# 5 + 3 + 1 = 9 ≤ 20 ✓ → packs. Trailing fragment from the
|
||||
# split absorbs the next small row, which is the current v2
|
||||
# behavior (the fragment becomes `current_chunk` and the next
|
||||
# small row is appended with the standard packing rules).
|
||||
expected_texts = [
|
||||
"S\na=1",
|
||||
"a=xxx, b=yyy, c=zzz",
|
||||
"d=www\na=2",
|
||||
]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
assert all(len(p.text) <= content_token_limit for p in out.payloads)
|
||||
|
||||
def test_prelude_layering_column_header_fits_but_sheet_header_does_not(
|
||||
self,
|
||||
) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Budget lets `Columns: x\nx=y` fit but not the additional sheet
|
||||
# header on top. The chunker should add the column header and
|
||||
# drop the sheet header.
|
||||
#
|
||||
# sheet = "LongSheetName" (13), cols = "Columns: x" (10),
|
||||
# row = "x=y" (3). Budget = 15.
|
||||
# cols + row: 10+1+3 = 14 ≤ 15 ✓
|
||||
# sheet + cols + row: 13+1+10+1+3 = 28 > 15 ✗
|
||||
csv_text = "x\n" "y\n"
|
||||
heading = "LongSheetName"
|
||||
content_token_limit = 15
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected_texts = ["Columns: x\nx=y"]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
|
||||
def test_prelude_layering_sheet_header_fits_but_column_header_does_not(
|
||||
self,
|
||||
) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Budget is too small for the column header but leaves room for
|
||||
# the short sheet header. The chunker should fall back to just
|
||||
# sheet + row (its layered "try cols, then try sheet on top of
|
||||
# whatever we have" logic means sheet is attempted on the bare
|
||||
# row when cols didn't fit).
|
||||
#
|
||||
# sheet = "S" (1), cols = "Columns: ABC, DEF" (17),
|
||||
# row = "ABC=1, DEF=2" (12). Budget = 20.
|
||||
# cols + row: 17+1+12 = 30 > 20 ✗
|
||||
# sheet + row: 1+1+12 = 14 ≤ 20 ✓
|
||||
csv_text = "ABC,DEF\n" "1,2\n"
|
||||
heading = "S"
|
||||
content_token_limit = 20
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected_texts = ["S\nABC=1, DEF=2"]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker().chunk_section(
|
||||
_tabular_section(csv_text, heading=heading),
|
||||
AccumulatorState(),
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
@@ -0,0 +1,188 @@
|
||||
"""Unit tests for MinimalPersonaSnapshot.from_model knowledge_sources aggregation."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.server.features.document_set.models import DocumentSetSummary
|
||||
from onyx.server.features.persona.models import MinimalPersonaSnapshot
|
||||
|
||||
|
||||
_STUB_DS_SUMMARY = DocumentSetSummary(
|
||||
id=1,
|
||||
name="stub",
|
||||
description=None,
|
||||
cc_pair_summaries=[],
|
||||
is_up_to_date=True,
|
||||
is_public=True,
|
||||
users=[],
|
||||
groups=[],
|
||||
)
|
||||
|
||||
|
||||
def _make_persona(**overrides: object) -> MagicMock:
|
||||
"""Build a mock Persona with sensible defaults.
|
||||
|
||||
Every relationship defaults to empty so tests only need to set the
|
||||
fields they care about.
|
||||
"""
|
||||
p = MagicMock()
|
||||
p.id = 1
|
||||
p.name = "test"
|
||||
p.description = ""
|
||||
p.tools = []
|
||||
p.starter_messages = None
|
||||
p.document_sets = []
|
||||
p.hierarchy_nodes = []
|
||||
p.attached_documents = []
|
||||
p.user_files = []
|
||||
p.llm_model_version_override = None
|
||||
p.llm_model_provider_override = None
|
||||
p.uploaded_image_id = None
|
||||
p.icon_name = None
|
||||
p.is_public = True
|
||||
p.is_listed = True
|
||||
p.display_priority = None
|
||||
p.is_featured = False
|
||||
p.builtin_persona = False
|
||||
p.labels = []
|
||||
p.user = None
|
||||
|
||||
for k, v in overrides.items():
|
||||
setattr(p, k, v)
|
||||
return p
|
||||
|
||||
|
||||
def _make_cc_pair(source: DocumentSource) -> MagicMock:
|
||||
cc = MagicMock()
|
||||
cc.connector.source = source
|
||||
cc.name = source.value
|
||||
cc.id = 1
|
||||
cc.access_type = "PUBLIC"
|
||||
return cc
|
||||
|
||||
|
||||
def _make_doc_set(
|
||||
cc_pairs: list[MagicMock] | None = None,
|
||||
fed_connectors: list[MagicMock] | None = None,
|
||||
) -> MagicMock:
|
||||
ds = MagicMock()
|
||||
ds.id = 1
|
||||
ds.name = "ds"
|
||||
ds.description = None
|
||||
ds.is_up_to_date = True
|
||||
ds.is_public = True
|
||||
ds.users = []
|
||||
ds.groups = []
|
||||
ds.connector_credential_pairs = cc_pairs or []
|
||||
ds.federated_connectors = fed_connectors or []
|
||||
return ds
|
||||
|
||||
|
||||
def _make_federated_ds_mapping(
|
||||
source: FederatedConnectorSource,
|
||||
) -> MagicMock:
|
||||
mapping = MagicMock()
|
||||
mapping.federated_connector.source = source
|
||||
mapping.federated_connector_id = 1
|
||||
mapping.entities = {}
|
||||
return mapping
|
||||
|
||||
|
||||
def _make_hierarchy_node(source: DocumentSource) -> MagicMock:
|
||||
node = MagicMock()
|
||||
node.source = source
|
||||
return node
|
||||
|
||||
|
||||
def _make_attached_document(source: DocumentSource) -> MagicMock:
|
||||
doc = MagicMock()
|
||||
doc.parent_hierarchy_node = MagicMock()
|
||||
doc.parent_hierarchy_node.source = source
|
||||
return doc
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
|
||||
return_value=_STUB_DS_SUMMARY,
|
||||
)
|
||||
def test_empty_persona_has_no_knowledge_sources(_mock_ds: MagicMock) -> None:
|
||||
persona = _make_persona()
|
||||
snapshot = MinimalPersonaSnapshot.from_model(persona)
|
||||
assert snapshot.knowledge_sources == []
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
|
||||
return_value=_STUB_DS_SUMMARY,
|
||||
)
|
||||
def test_user_files_adds_user_file_source(_mock_ds: MagicMock) -> None:
|
||||
persona = _make_persona(user_files=[MagicMock()])
|
||||
snapshot = MinimalPersonaSnapshot.from_model(persona)
|
||||
assert DocumentSource.USER_FILE in snapshot.knowledge_sources
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
|
||||
return_value=_STUB_DS_SUMMARY,
|
||||
)
|
||||
def test_no_user_files_excludes_user_file_source(_mock_ds: MagicMock) -> None:
|
||||
cc = _make_cc_pair(DocumentSource.CONFLUENCE)
|
||||
ds = _make_doc_set(cc_pairs=[cc])
|
||||
persona = _make_persona(document_sets=[ds])
|
||||
snapshot = MinimalPersonaSnapshot.from_model(persona)
|
||||
assert DocumentSource.USER_FILE not in snapshot.knowledge_sources
|
||||
assert DocumentSource.CONFLUENCE in snapshot.knowledge_sources
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
|
||||
return_value=_STUB_DS_SUMMARY,
|
||||
)
|
||||
def test_federated_connector_in_doc_set(_mock_ds: MagicMock) -> None:
|
||||
fed = _make_federated_ds_mapping(FederatedConnectorSource.FEDERATED_SLACK)
|
||||
ds = _make_doc_set(fed_connectors=[fed])
|
||||
persona = _make_persona(document_sets=[ds])
|
||||
snapshot = MinimalPersonaSnapshot.from_model(persona)
|
||||
assert DocumentSource.SLACK in snapshot.knowledge_sources
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
|
||||
return_value=_STUB_DS_SUMMARY,
|
||||
)
|
||||
def test_hierarchy_nodes_and_attached_documents(_mock_ds: MagicMock) -> None:
|
||||
node = _make_hierarchy_node(DocumentSource.GOOGLE_DRIVE)
|
||||
doc = _make_attached_document(DocumentSource.SHAREPOINT)
|
||||
persona = _make_persona(hierarchy_nodes=[node], attached_documents=[doc])
|
||||
snapshot = MinimalPersonaSnapshot.from_model(persona)
|
||||
assert DocumentSource.GOOGLE_DRIVE in snapshot.knowledge_sources
|
||||
assert DocumentSource.SHAREPOINT in snapshot.knowledge_sources
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
|
||||
return_value=_STUB_DS_SUMMARY,
|
||||
)
|
||||
def test_all_source_types_combined(_mock_ds: MagicMock) -> None:
|
||||
cc = _make_cc_pair(DocumentSource.CONFLUENCE)
|
||||
fed = _make_federated_ds_mapping(FederatedConnectorSource.FEDERATED_SLACK)
|
||||
ds = _make_doc_set(cc_pairs=[cc], fed_connectors=[fed])
|
||||
node = _make_hierarchy_node(DocumentSource.GOOGLE_DRIVE)
|
||||
doc = _make_attached_document(DocumentSource.SHAREPOINT)
|
||||
persona = _make_persona(
|
||||
document_sets=[ds],
|
||||
hierarchy_nodes=[node],
|
||||
attached_documents=[doc],
|
||||
user_files=[MagicMock()],
|
||||
)
|
||||
snapshot = MinimalPersonaSnapshot.from_model(persona)
|
||||
sources = set(snapshot.knowledge_sources)
|
||||
assert sources == {
|
||||
DocumentSource.CONFLUENCE,
|
||||
DocumentSource.SLACK,
|
||||
DocumentSource.GOOGLE_DRIVE,
|
||||
DocumentSource.SHAREPOINT,
|
||||
DocumentSource.USER_FILE,
|
||||
}
|
||||
@@ -100,6 +100,39 @@ class TestGenerateOllamaDisplayName:
|
||||
result = generate_ollama_display_name("llama3.3:70b")
|
||||
assert "3.3" in result or "3 3" in result # Either format is acceptable
|
||||
|
||||
def test_non_size_tag_shown(self) -> None:
|
||||
"""Test that non-size tags like 'e4b' are included in the display name."""
|
||||
result = generate_ollama_display_name("gemma4:e4b")
|
||||
assert "Gemma" in result
|
||||
assert "4" in result
|
||||
assert "E4B" in result
|
||||
|
||||
def test_size_with_cloud_modifier(self) -> None:
|
||||
"""Test size tag with cloud modifier."""
|
||||
result = generate_ollama_display_name("deepseek-v3.1:671b-cloud")
|
||||
assert "DeepSeek" in result
|
||||
assert "671B" in result
|
||||
assert "Cloud" in result
|
||||
|
||||
def test_size_with_multiple_modifiers(self) -> None:
|
||||
"""Test size tag with multiple modifiers."""
|
||||
result = generate_ollama_display_name("qwen3-vl:235b-instruct-cloud")
|
||||
assert "Qwen" in result
|
||||
assert "235B" in result
|
||||
assert "Instruct" in result
|
||||
assert "Cloud" in result
|
||||
|
||||
def test_quantization_tag_shown(self) -> None:
|
||||
"""Test that quantization tags are included in the display name."""
|
||||
result = generate_ollama_display_name("llama3:q4_0")
|
||||
assert "Llama" in result
|
||||
assert "Q4_0" in result
|
||||
|
||||
def test_cloud_only_tag(self) -> None:
|
||||
"""Test standalone cloud tag."""
|
||||
result = generate_ollama_display_name("glm-4.6:cloud")
|
||||
assert "CLOUD" in result
|
||||
|
||||
|
||||
class TestStripOpenrouterVendorPrefix:
|
||||
"""Tests for OpenRouter vendor prefix stripping."""
|
||||
|
||||
@@ -95,9 +95,9 @@ class TestForceAddSearchToolGuard:
|
||||
without a vector DB."""
|
||||
import inspect
|
||||
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import _construct_tools_impl
|
||||
|
||||
source = inspect.getsource(construct_tools)
|
||||
source = inspect.getsource(_construct_tools_impl)
|
||||
assert (
|
||||
"DISABLE_VECTOR_DB" in source
|
||||
), "construct_tools should reference DISABLE_VECTOR_DB to suppress force-adding SearchTool"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user