mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-15 11:52:53 +00:00
Compare commits
71 Commits
nightly-la
...
bo/pruning
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5a15ac7bda | ||
|
|
20c5107ba6 | ||
|
|
357bc91aee | ||
|
|
09653872a2 | ||
|
|
ff01a53f83 | ||
|
|
03ddd5ca9b | ||
|
|
8c49e4573c | ||
|
|
f1696ffa16 | ||
|
|
a427cb5b0c | ||
|
|
f7e4be18dd | ||
|
|
0f31c490fa | ||
|
|
c9a4a6e42b | ||
|
|
558c9df3c7 | ||
|
|
30003036d3 | ||
|
|
4b2f18c239 | ||
|
|
4290b097f5 | ||
|
|
b0f621a08b | ||
|
|
112edf41c5 | ||
|
|
74eb1d7212 | ||
|
|
e62d592b11 | ||
|
|
57a0d25321 | ||
|
|
887f79d7a5 | ||
|
|
65fd1c3ec8 | ||
|
|
6e3ee287b9 | ||
|
|
dee0b7867e | ||
|
|
77beb8044e | ||
|
|
750d3ac4ed | ||
|
|
6c02087ba4 | ||
|
|
0425283ed0 | ||
|
|
da97a57c58 | ||
|
|
8087ddb97c | ||
|
|
d9d5943dc4 | ||
|
|
97a7fa6f7f | ||
|
|
8027e62446 | ||
|
|
571e860d4f | ||
|
|
89b91ac384 | ||
|
|
069b1f3efb | ||
|
|
ef2fffcd6e | ||
|
|
925be18424 | ||
|
|
38fffc8ad8 | ||
|
|
3e9e2f08d5 | ||
|
|
243d93ecd8 | ||
|
|
4effe77225 | ||
|
|
ef2df458a3 | ||
|
|
d3000da3d0 | ||
|
|
a5c703f9ca | ||
|
|
d10c901c43 | ||
|
|
f1ac555c57 | ||
|
|
ed52384c21 | ||
|
|
cb10376a0d | ||
|
|
5a25b70b9c | ||
|
|
8cbc37f281 | ||
|
|
9d78f71f23 | ||
|
|
fbf3179d84 | ||
|
|
779470b553 | ||
|
|
151e189898 | ||
|
|
72e08f81a4 | ||
|
|
65792a8ad8 | ||
|
|
497b700b3d | ||
|
|
c3ed2135f1 | ||
|
|
a969d56818 | ||
|
|
a31d862f48 | ||
|
|
a4e6d4cf43 | ||
|
|
1e6f94e00d | ||
|
|
a769b87a9d | ||
|
|
278fc7e9b1 | ||
|
|
eb34df470f | ||
|
|
9d1785273f | ||
|
|
ef69b17d26 | ||
|
|
787c961802 | ||
|
|
62bc4fa2a3 |
@@ -2,6 +2,7 @@ FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0a
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
default-jre \
|
||||
fd-find \
|
||||
fzf \
|
||||
git \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "Onyx Dev Sandbox",
|
||||
"image": "onyxdotapp/onyx-devcontainer@sha256:12184169c5bcc9cca0388286d5ffe504b569bc9c37bfa631b76ee8eee2064055",
|
||||
"image": "onyxdotapp/onyx-devcontainer@sha256:0f02d9299928849c7b15f3b348dcfdcdcb64411ff7a4580cbc026a6ee7aa1554",
|
||||
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW"],
|
||||
"mounts": [
|
||||
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
|
||||
|
||||
4
.github/workflows/deployment.yml
vendored
4
.github/workflows/deployment.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
fetch-tags: true
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
version: "0.9.9"
|
||||
enable-cache: false
|
||||
@@ -165,7 +165,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
|
||||
@@ -114,7 +114,7 @@ jobs:
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
4
.github/workflows/pr-playwright-tests.yml
vendored
4
.github/workflows/pr-playwright-tests.yml
vendored
@@ -471,7 +471,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -710,7 +710,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
|
||||
- uses: j178/prek-action@cbc2f23eb5539cf20d82d1aabd0d0ecbcc56f4e3
|
||||
with:
|
||||
prek-version: '0.3.4'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
|
||||
2
.github/workflows/release-cli.yml
vendored
2
.github/workflows/release-cli.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/zizmor.yml
vendored
2
.github/workflows/zizmor.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
@@ -1,64 +1,57 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": ["logic", "syntax", "style"],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": ["dependabot[bot]", "renovate[bot]"],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": false,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -49,12 +49,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
|
||||
4. **Light Worker** (`light`)
|
||||
- Handles lightweight, fast operations
|
||||
- Tasks: vespa operations, document permissions sync, external group sync
|
||||
- Tasks: vespa metadata sync, connector deletion, doc permissions upsert, checkpoint cleanup, index attempt cleanup
|
||||
- Higher concurrency for quick tasks
|
||||
|
||||
5. **Heavy Worker** (`heavy`)
|
||||
- Handles resource-intensive operations
|
||||
- Primary task: document pruning operations
|
||||
- Tasks: connector pruning, document permissions sync, external group sync, CSV generation
|
||||
- Runs with 4 threads concurrency
|
||||
|
||||
6. **KG Processing Worker** (`kg_processing`)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47
|
||||
|
||||
LABEL com.danswer.maintainer="founders@onyx.app"
|
||||
LABEL com.danswer.description="This image is the web/frontend container of Onyx which \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Base stage with dependencies
|
||||
FROM python:3.11.7-slim-bookworm AS base
|
||||
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47 AS base
|
||||
|
||||
ENV DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
HF_HOME=/app/.cache/huggingface
|
||||
|
||||
@@ -208,7 +208,7 @@ def do_run_migrations(
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
@@ -380,7 +380,7 @@ def run_migrations_offline() -> None:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
@@ -421,7 +421,7 @@ def run_migrations_offline() -> None:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
@@ -464,7 +464,7 @@ def run_migrations_online() -> None:
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
|
||||
@@ -25,7 +25,7 @@ def upgrade() -> None:
|
||||
|
||||
# Use batch mode to modify the enum type
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
batch_op.alter_column(
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC",
|
||||
@@ -71,7 +71,7 @@ def downgrade() -> None:
|
||||
op.drop_column("user__user_group", "is_curator")
|
||||
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
batch_op.alter_column(
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
|
||||
|
||||
@@ -63,7 +63,7 @@ def upgrade() -> None:
|
||||
"time_created",
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
existing_server_default=sa.text("now()"), # type: ignore
|
||||
existing_server_default=sa.text("now()"),
|
||||
)
|
||||
op.alter_column(
|
||||
"index_attempt",
|
||||
@@ -85,7 +85,7 @@ def downgrade() -> None:
|
||||
"time_created",
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=True,
|
||||
existing_server_default=sa.text("now()"), # type: ignore
|
||||
existing_server_default=sa.text("now()"),
|
||||
)
|
||||
op.drop_index(op.f("ix_accesstoken_created_at"), table_name="accesstoken")
|
||||
op.drop_table("accesstoken")
|
||||
|
||||
@@ -19,7 +19,7 @@ depends_on: None = None
|
||||
|
||||
def upgrade() -> None:
|
||||
sequence = Sequence("connector_credential_pair_id_seq")
|
||||
op.execute(CreateSequence(sequence)) # type: ignore
|
||||
op.execute(CreateSequence(sequence))
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
"""add_error_tracking_fields_to_index_attempt_errors
|
||||
|
||||
Revision ID: d129f37b3d87
|
||||
Revises: 503883791c39
|
||||
Create Date: 2026-04-06 19:11:18.261800
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d129f37b3d87"
|
||||
down_revision = "503883791c39"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"index_attempt_errors",
|
||||
sa.Column("error_type", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("index_attempt_errors", "error_type")
|
||||
@@ -49,7 +49,7 @@ def run_migrations_offline() -> None:
|
||||
url = build_connection_string()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
@@ -61,7 +61,7 @@ def run_migrations_offline() -> None:
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
target_metadata=target_metadata,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
@@ -96,11 +96,14 @@ def get_model_app() -> FastAPI:
|
||||
title="Onyx Model Server", version=__version__, lifespan=lifespan
|
||||
)
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -10,6 +10,7 @@ from celery import bootsteps # type: ignore
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import before_task_publish
|
||||
from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from celery.states import READY_STATES
|
||||
@@ -62,11 +63,14 @@ logger = setup_logger()
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
@@ -94,6 +98,17 @@ class TenantAwareTask(Task):
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@before_task_publish.connect
|
||||
def on_before_task_publish(
|
||||
headers: dict[str, Any] | None = None,
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Stamp the current wall-clock time into the task message headers so that
|
||||
workers can compute queue wait time (time between publish and execution)."""
|
||||
if headers is not None:
|
||||
headers["enqueued_at"] = time.time()
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None, # noqa: ARG001
|
||||
|
||||
@@ -16,6 +16,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -36,6 +42,7 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -50,6 +57,31 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -90,6 +122,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("light")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -59,6 +59,11 @@ from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_started
|
||||
from onyx.server.metrics.deletion_metrics import observe_deletion_taskset_duration
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
@@ -102,7 +107,7 @@ def revoke_tasks_blocking_deletion(
|
||||
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking pruning task")
|
||||
task_logger.exception("Exception while revoking permissions sync task")
|
||||
|
||||
try:
|
||||
prune_payload = redis_connector.prune.payload
|
||||
@@ -110,7 +115,7 @@ def revoke_tasks_blocking_deletion(
|
||||
app.control.revoke(prune_payload.celery_task_id)
|
||||
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking permissions sync task")
|
||||
task_logger.exception("Exception while revoking pruning task")
|
||||
|
||||
try:
|
||||
external_group_sync_payload = redis_connector.external_group_sync.payload
|
||||
@@ -300,6 +305,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
recent_index_attempts
|
||||
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
inc_deletion_blocked(tenant_id, "indexing")
|
||||
raise TaskDependencyError(
|
||||
"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -307,11 +313,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
inc_deletion_blocked(tenant_id, "pruning")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
if redis_connector.permissions.fenced:
|
||||
inc_deletion_blocked(tenant_id, "permissions")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (permissions in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
@@ -359,6 +367,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# set this only after all tasks have been added
|
||||
fence_payload.num_tasks = tasks_generated
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
inc_deletion_started(tenant_id)
|
||||
|
||||
return tasks_generated
|
||||
|
||||
@@ -508,7 +517,11 @@ def monitor_connector_deletion_taskset(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id_to_delete,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
if not connector:
|
||||
task_logger.info(
|
||||
"Connector deletion - Connector already deleted, skipping connector cleanup"
|
||||
)
|
||||
elif not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
@@ -523,6 +536,12 @@ def monitor_connector_deletion_taskset(
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "success", duration)
|
||||
inc_deletion_completed(tenant_id, "success")
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
@@ -541,6 +560,11 @@ def monitor_connector_deletion_taskset(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
|
||||
)
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "failure", duration)
|
||||
inc_deletion_completed(tenant_id, "failure")
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
@@ -717,5 +741,6 @@ def validate_connector_deletion_fence(
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
inc_deletion_fence_reset(tenant_id)
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
@@ -135,10 +135,13 @@ def _docfetching_task(
|
||||
# Since connector_indexing_proxy_task spawns a new process using this function as
|
||||
# the entrypoint, we init Sentry here.
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -172,6 +172,10 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
task_logger.debug(
|
||||
"Verified tenant info, migration record, and search settings."
|
||||
)
|
||||
|
||||
# 2.e. Build sanitized to original doc ID mapping to check for
|
||||
# conflicts in the event we sanitize a doc ID to an
|
||||
# already-existing doc ID.
|
||||
@@ -325,6 +329,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
task_logger.debug("Released the OpenSearch migration lock.")
|
||||
else:
|
||||
task_logger.warning(
|
||||
"The OpenSearch migration lock was not owned on completion of the migration task."
|
||||
|
||||
@@ -38,6 +38,7 @@ from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.connector import mark_ccpair_as_pruned
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
@@ -525,6 +526,14 @@ def connector_pruning_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
# Session 1: pre-enumeration — load cc_pair and instantiate the connector.
|
||||
# The session is closed before enumeration so the DB connection is not held
|
||||
# open during the 10–30+ minute connector crawl.
|
||||
connector_source: DocumentSource | None = None
|
||||
connector_type: str = ""
|
||||
is_connector_public: bool = False
|
||||
runnable_connector: BaseConnector | None = None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
@@ -550,49 +559,51 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
redis_connector.prune.set_fence(new_payload)
|
||||
|
||||
connector_source = cc_pair.connector.source
|
||||
connector_type = connector_source.value
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={cc_pair.connector.source}"
|
||||
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={connector_source}"
|
||||
)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
cc_pair.connector.source,
|
||||
connector_source,
|
||||
InputType.SLIM_RETRIEVAL,
|
||||
cc_pair.connector.connector_specific_config,
|
||||
cc_pair.credential,
|
||||
)
|
||||
# Session 1 closed here — connection released before enumeration.
|
||||
|
||||
callback = PruneCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
callback = PruneCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
|
||||
# Extract docs and hierarchy nodes from the source
|
||||
connector_type = cc_pair.connector.source.value
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback, connector_type=connector_type
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
# Extract docs and hierarchy nodes from the source (no DB session held).
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback, connector_type=connector_type
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
|
||||
# Process hierarchy nodes (same as docfetching):
|
||||
# upsert to Postgres and cache in Redis
|
||||
source = cc_pair.connector.source
|
||||
# Session 2: post-enumeration — hierarchy upserts, diff computation, task dispatch.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
source = connector_source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes: list[DBHierarchyNode] = []
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
source=source,
|
||||
commit=True,
|
||||
commit=False,
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
@@ -601,9 +612,13 @@ def connector_pruning_generator_task(
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
commit=True,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
# Single commit so the FK reference in the join table can never
|
||||
# outrun the parent hierarchy_node insert.
|
||||
db_session.commit()
|
||||
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in upserted_nodes
|
||||
@@ -658,7 +673,7 @@ def connector_pruning_generator_task(
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source} "
|
||||
f"connector_source={connector_source} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -23,6 +23,8 @@ class IndexAttemptErrorPydantic(BaseModel):
|
||||
|
||||
index_attempt_id: int
|
||||
|
||||
error_type: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
|
||||
return cls(
|
||||
@@ -37,4 +39,5 @@ class IndexAttemptErrorPydantic(BaseModel):
|
||||
is_resolved=model.is_resolved,
|
||||
time_created=model.time_created,
|
||||
index_attempt_id=model.index_attempt_id,
|
||||
error_type=model.error_type,
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Celery
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -556,6 +557,27 @@ def connector_document_extraction(
|
||||
|
||||
# save record of any failures at the connector level
|
||||
if failure is not None:
|
||||
if failure.exception is not None:
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "connector_fetch")
|
||||
scope.set_tag("connector_source", db_connector.source.value)
|
||||
scope.set_tag("cc_pair_id", str(cc_pair_id))
|
||||
scope.set_tag("index_attempt_id", str(index_attempt_id))
|
||||
scope.set_tag("tenant_id", tenant_id)
|
||||
if failure.failed_document:
|
||||
scope.set_tag(
|
||||
"doc_id", failure.failed_document.document_id
|
||||
)
|
||||
if failure.failed_entity:
|
||||
scope.set_tag(
|
||||
"entity_id", failure.failed_entity.entity_id
|
||||
)
|
||||
scope.fingerprint = [
|
||||
"connector-fetch-failure",
|
||||
db_connector.source.value,
|
||||
type(failure.exception).__name__,
|
||||
]
|
||||
sentry_sdk.capture_exception(failure.exception)
|
||||
total_failures += 1
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_index_attempt_error(
|
||||
|
||||
@@ -364,7 +364,7 @@ def _get_or_extract_plaintext(
|
||||
plaintext_io = file_store.read_file(plaintext_key, mode="b")
|
||||
return plaintext_io.read().decode("utf-8")
|
||||
except Exception:
|
||||
logger.exception(f"Error when reading file, id={file_id}")
|
||||
logger.info(f"Cache miss for file with id={file_id}")
|
||||
|
||||
# Cache miss — extract and store.
|
||||
content_text = extract_fn()
|
||||
|
||||
@@ -4,8 +4,6 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import create_tool_call_failure_messages
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
@@ -635,7 +633,6 @@ def run_llm_loop(
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -1020,20 +1017,16 @@ def run_llm_loop(
|
||||
persisted_memory_id: int | None = None
|
||||
if user_memory_context and user_memory_context.user_id:
|
||||
if tool_response.rich_response.index_to_replace is not None:
|
||||
memory = update_memory_at_index(
|
||||
persisted_memory_id = update_memory_at_index(
|
||||
user_id=user_memory_context.user_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
new_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id if memory else None
|
||||
else:
|
||||
memory = add_memory(
|
||||
persisted_memory_id = add_memory(
|
||||
user_id=user_memory_context.user_id,
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id
|
||||
operation: Literal["add", "update"] = (
|
||||
"update"
|
||||
if tool_response.rich_response.index_to_replace is not None
|
||||
|
||||
@@ -67,7 +67,6 @@ from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
@@ -1006,93 +1005,86 @@ def _run_models(
|
||||
model_llm = setup.llms[model_idx]
|
||||
|
||||
try:
|
||||
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
|
||||
# Do NOT write to the outer db_session (or any shared DB state) from here;
|
||||
# all DB writes in this thread must go through thread_db_session.
|
||||
with get_session_with_current_tenant() as thread_db_session:
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
db_session=thread_db_session,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
# Each function opens short-lived DB sessions on demand.
|
||||
# Do NOT pass a long-lived session here — it would hold a
|
||||
# connection for the entire LLM loop (minutes), and cloud
|
||||
# infrastructure may drop idle connections.
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
)
|
||||
model_tools = [
|
||||
tool for tool_list in thread_tool_dict.values() for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
model_tools = [
|
||||
tool
|
||||
for tool_list in thread_tool_dict.values()
|
||||
for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError(
|
||||
"Deep research is not supported for projects"
|
||||
)
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError("Deep research is not supported for projects")
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
model_succeeded[model_idx] = True
|
||||
|
||||
|
||||
48
backend/onyx/configs/sentry.py
Normal file
48
backend/onyx/configs/sentry.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import Any
|
||||
|
||||
from sentry_sdk.types import Event
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_instance_id_resolved = False
|
||||
|
||||
|
||||
def _add_instance_tags(
|
||||
event: Event,
|
||||
hint: dict[str, Any], # noqa: ARG001
|
||||
) -> Event | None:
|
||||
"""Sentry before_send hook that lazily attaches instance identification tags.
|
||||
|
||||
On the first event, resolves the instance UUID from the KV store (requires DB)
|
||||
and sets it as a global Sentry tag. Subsequent events pick it up automatically.
|
||||
"""
|
||||
global _instance_id_resolved
|
||||
|
||||
if _instance_id_resolved:
|
||||
return event
|
||||
|
||||
try:
|
||||
import sentry_sdk
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
if MULTI_TENANT:
|
||||
instance_id = "multi-tenant-cloud"
|
||||
else:
|
||||
from onyx.utils.telemetry import get_or_generate_uuid
|
||||
|
||||
instance_id = get_or_generate_uuid()
|
||||
|
||||
sentry_sdk.set_tag("instance_id", instance_id)
|
||||
|
||||
# Also set on this event since set_tag won't retroactively apply
|
||||
event.setdefault("tags", {})["instance_id"] = instance_id
|
||||
|
||||
# Only mark resolved after success — if DB wasn't ready, retry next event
|
||||
_instance_id_resolved = True
|
||||
except Exception:
|
||||
logger.debug("Failed to resolve instance_id for Sentry tagging")
|
||||
|
||||
return event
|
||||
@@ -171,7 +171,10 @@ class ClickupConnector(LoadConnector, PollConnector):
|
||||
document.metadata[extra_field] = task[extra_field]
|
||||
|
||||
if self.retrieve_task_comments:
|
||||
document.sections.extend(self._get_task_comments(task["id"]))
|
||||
document.sections = [
|
||||
*document.sections,
|
||||
*self._get_task_comments(task["id"]),
|
||||
]
|
||||
|
||||
doc_batch.append(document)
|
||||
|
||||
|
||||
@@ -61,6 +61,9 @@ _USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_MINIMUM_PAGINATION_LIMIT = 5
|
||||
|
||||
_SERVER_ERROR_CODES = {500, 502, 503, 504}
|
||||
|
||||
_CONFLUENCE_SPACES_API_V1 = "rest/api/space"
|
||||
_CONFLUENCE_SPACES_API_V2 = "wiki/api/v2/spaces"
|
||||
@@ -569,7 +572,8 @@ class OnyxConfluence:
|
||||
if not limit:
|
||||
limit = _DEFAULT_PAGINATION_LIMIT
|
||||
|
||||
url_suffix = update_param_in_path(url_suffix, "limit", str(limit))
|
||||
current_limit = limit
|
||||
url_suffix = update_param_in_path(url_suffix, "limit", str(current_limit))
|
||||
|
||||
while url_suffix:
|
||||
logger.debug(f"Making confluence call to {url_suffix}")
|
||||
@@ -609,40 +613,61 @@ class OnyxConfluence:
|
||||
)
|
||||
continue
|
||||
|
||||
# If we fail due to a 500, try one by one.
|
||||
# NOTE: this iterative approach only works for server, since cloud uses cursor-based
|
||||
# pagination
|
||||
if raw_response.status_code == 500 and not self._is_cloud:
|
||||
initial_start = get_start_param_from_url(url_suffix)
|
||||
if initial_start is None:
|
||||
# can't handle this if we don't have offset-based pagination
|
||||
raise
|
||||
if raw_response.status_code in _SERVER_ERROR_CODES:
|
||||
# Try reducing the page size -- Confluence often times out
|
||||
# on large result sets (especially Cloud 504s).
|
||||
if current_limit > _MINIMUM_PAGINATION_LIMIT:
|
||||
old_limit = current_limit
|
||||
current_limit = max(
|
||||
current_limit // 2, _MINIMUM_PAGINATION_LIMIT
|
||||
)
|
||||
logger.warning(
|
||||
f"Confluence returned {raw_response.status_code}. "
|
||||
f"Reducing limit from {old_limit} to {current_limit} "
|
||||
f"and retrying."
|
||||
)
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "limit", str(current_limit)
|
||||
)
|
||||
continue
|
||||
|
||||
# this will just yield the successful items from the batch
|
||||
new_url_suffix = yield from self._try_one_by_one_for_paginated_url(
|
||||
url_suffix,
|
||||
initial_start=initial_start,
|
||||
limit=limit,
|
||||
)
|
||||
# Limit reduction exhausted -- for Server, fall back to
|
||||
# one-by-one offset pagination as a last resort.
|
||||
if not self._is_cloud:
|
||||
initial_start = get_start_param_from_url(url_suffix)
|
||||
# this will just yield the successful items from the batch
|
||||
new_url_suffix = (
|
||||
yield from self._try_one_by_one_for_paginated_url(
|
||||
url_suffix,
|
||||
initial_start=initial_start,
|
||||
limit=current_limit,
|
||||
)
|
||||
)
|
||||
# this means we ran into an empty page
|
||||
if new_url_suffix is None:
|
||||
if next_page_callback:
|
||||
next_page_callback("")
|
||||
break
|
||||
|
||||
# this means we ran into an empty page
|
||||
if new_url_suffix is None:
|
||||
if next_page_callback:
|
||||
next_page_callback("")
|
||||
break
|
||||
url_suffix = new_url_suffix
|
||||
continue
|
||||
|
||||
url_suffix = new_url_suffix
|
||||
continue
|
||||
|
||||
else:
|
||||
logger.exception(
|
||||
f"Error in confluence call to {url_suffix} \n"
|
||||
f"Raw Response Text: {raw_response.text} \n"
|
||||
f"Full Response: {raw_response.__dict__} \n"
|
||||
f"Error: {e} \n"
|
||||
f"Error in confluence call to {url_suffix} "
|
||||
f"after reducing limit to {current_limit}.\n"
|
||||
f"Raw Response Text: {raw_response.text}\n"
|
||||
f"Error: {e}\n"
|
||||
)
|
||||
raise
|
||||
|
||||
logger.exception(
|
||||
f"Error in confluence call to {url_suffix} \n"
|
||||
f"Raw Response Text: {raw_response.text} \n"
|
||||
f"Full Response: {raw_response.__dict__} \n"
|
||||
f"Error: {e} \n"
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
next_response = raw_response.json()
|
||||
except Exception as e:
|
||||
@@ -680,6 +705,10 @@ class OnyxConfluence:
|
||||
old_url_suffix = url_suffix
|
||||
updated_start = get_start_param_from_url(old_url_suffix)
|
||||
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
|
||||
if url_suffix and current_limit != limit:
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "limit", str(current_limit)
|
||||
)
|
||||
for i, result in enumerate(results):
|
||||
updated_start += 1
|
||||
if url_suffix and next_page_callback and i == len(results) - 1:
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
import csv
|
||||
import io
|
||||
from typing import IO
|
||||
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.file_processing.extract_file_text import file_io_to_text
|
||||
from onyx.file_processing.extract_file_text import xlsx_sheet_extraction
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_tabular_file(file_name: str) -> bool:
|
||||
lowered = file_name.lower()
|
||||
return any(lowered.endswith(ext) for ext in OnyxFileExtensions.TABULAR_EXTENSIONS)
|
||||
|
||||
|
||||
def _tsv_to_csv(tsv_text: str) -> str:
|
||||
"""Re-serialize tab-separated text as CSV so downstream parsers that
|
||||
assume the default Excel dialect read the columns correctly."""
|
||||
out = io.StringIO()
|
||||
csv.writer(out, lineterminator="\n").writerows(
|
||||
csv.reader(io.StringIO(tsv_text), dialect="excel-tab")
|
||||
)
|
||||
return out.getvalue().rstrip("\n")
|
||||
|
||||
|
||||
def tabular_file_to_sections(
|
||||
file: IO[bytes],
|
||||
file_name: str,
|
||||
link: str = "",
|
||||
) -> list[TabularSection]:
|
||||
"""Convert a tabular file into one or more TabularSections.
|
||||
|
||||
- .xlsx → one TabularSection per non-empty sheet.
|
||||
- .csv / .tsv → a single TabularSection containing the full decoded
|
||||
file.
|
||||
|
||||
Returns an empty list when the file yields no extractable content.
|
||||
"""
|
||||
lowered = file_name.lower()
|
||||
|
||||
if lowered.endswith(".xlsx"):
|
||||
return [
|
||||
TabularSection(link=f"sheet:{sheet_title}", text=csv_text)
|
||||
for csv_text, sheet_title in xlsx_sheet_extraction(
|
||||
file, file_name=file_name
|
||||
)
|
||||
]
|
||||
|
||||
if not lowered.endswith((".csv", ".tsv")):
|
||||
raise ValueError(f"{file_name!r} is not a tabular file")
|
||||
|
||||
try:
|
||||
text = file_io_to_text(file).strip()
|
||||
except Exception:
|
||||
logger.exception(f"Failure decoding {file_name}")
|
||||
raise
|
||||
|
||||
if not text:
|
||||
return []
|
||||
if lowered.endswith(".tsv"):
|
||||
text = _tsv_to_csv(text)
|
||||
return [TabularSection(link=link or file_name, text=text)]
|
||||
@@ -35,10 +35,16 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import IndexingHeartbeatInterface
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -427,7 +433,11 @@ def make_cursor_url_callback(
|
||||
return cursor_url_callback
|
||||
|
||||
|
||||
class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint]):
|
||||
class GithubConnector(
|
||||
CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint],
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
repo_owner: str,
|
||||
@@ -803,6 +813,87 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
|
||||
def _iter_slim_docs(
|
||||
self,
|
||||
include_permissions: bool,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub")
|
||||
|
||||
repos = self.fetch_configured_repos()
|
||||
batch: list[SlimDocument | HierarchyNode] = []
|
||||
|
||||
for repo in repos:
|
||||
external_access = (
|
||||
get_external_access_permission(repo, self.github_client)
|
||||
if include_permissions
|
||||
else None
|
||||
)
|
||||
|
||||
if self.include_prs:
|
||||
try:
|
||||
for pr in repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
):
|
||||
batch.append(
|
||||
SlimDocument(
|
||||
id=pr.html_url, external_access=external_access
|
||||
)
|
||||
)
|
||||
if len(batch) >= SLIM_BATCH_SIZE:
|
||||
yield batch
|
||||
batch = []
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"github_slim_docs: Stop signal detected"
|
||||
)
|
||||
except RateLimitExceededException:
|
||||
sleep_after_rate_limit_exception(self.github_client)
|
||||
|
||||
if self.include_issues:
|
||||
try:
|
||||
for issue in repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
):
|
||||
if issue.pull_request is not None:
|
||||
continue
|
||||
batch.append(
|
||||
SlimDocument(
|
||||
id=issue.html_url, external_access=external_access
|
||||
)
|
||||
)
|
||||
if len(batch) >= SLIM_BATCH_SIZE:
|
||||
yield batch
|
||||
batch = []
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"github_slim_docs: Stop signal detected"
|
||||
)
|
||||
except RateLimitExceededException:
|
||||
sleep_after_rate_limit_exception(self.github_client)
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
@override
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
return self._iter_slim_docs(include_permissions=False, callback=callback)
|
||||
|
||||
@override
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
return self._iter_slim_docs(include_permissions=True, callback=callback)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")
|
||||
|
||||
@@ -75,6 +75,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import NormalizationResult
|
||||
from onyx.connectors.interfaces import Resolver
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -207,6 +208,7 @@ class DriveIdStatus(Enum):
|
||||
|
||||
|
||||
class GoogleDriveConnector(
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
|
||||
Resolver,
|
||||
@@ -1754,6 +1756,7 @@ class GoogleDriveConnector(
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
files_batch: list[RetrievedDriveFile] = []
|
||||
slim_batch: list[SlimDocument | HierarchyNode] = []
|
||||
@@ -1763,9 +1766,13 @@ class GoogleDriveConnector(
|
||||
nonlocal files_batch, slim_batch
|
||||
|
||||
# Get new ancestor hierarchy nodes first
|
||||
permission_sync_context = PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
permission_sync_context = (
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
)
|
||||
if include_permissions
|
||||
else None
|
||||
)
|
||||
new_ancestors = self._get_new_ancestors_for_files(
|
||||
files=files_batch,
|
||||
@@ -1779,10 +1786,7 @@ class GoogleDriveConnector(
|
||||
if doc := build_slim_document(
|
||||
self.creds,
|
||||
file.drive_file,
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
),
|
||||
permission_sync_context,
|
||||
retriever_email=file.user_email,
|
||||
):
|
||||
slim_batch.append(doc)
|
||||
@@ -1822,11 +1826,12 @@ class GoogleDriveConnector(
|
||||
if files_batch:
|
||||
yield _yield_slim_batch()
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
def _retrieve_all_slim_docs_impl(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
try:
|
||||
checkpoint = self.build_dummy_checkpoint()
|
||||
@@ -1836,13 +1841,34 @@ class GoogleDriveConnector(
|
||||
start=start,
|
||||
end=end,
|
||||
callback=callback,
|
||||
include_permissions=include_permissions,
|
||||
)
|
||||
logger.info("Drive perm sync: Slim doc retrieval complete")
|
||||
|
||||
logger.info("Drive slim doc retrieval complete")
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
raise
|
||||
|
||||
@override
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
return self._retrieve_all_slim_docs_impl(
|
||||
start=start, end=end, callback=callback, include_permissions=False
|
||||
)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
return self._retrieve_all_slim_docs_impl(
|
||||
start=start, end=end, callback=callback, include_permissions=True
|
||||
)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self._creds is None:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
@@ -53,6 +54,21 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _load_google_json(raw: object) -> dict[str, Any]:
|
||||
"""Accept both the current (dict) and legacy (JSON string) KV payload shapes.
|
||||
|
||||
Payloads written before the fix for serializing Google credentials into
|
||||
``EncryptedJson`` columns are stored as JSON strings; new writes store dicts.
|
||||
Once every install has re-uploaded their Google credentials the legacy
|
||||
``str`` branch can be removed.
|
||||
"""
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return json.loads(raw)
|
||||
raise ValueError(f"Unexpected Google credential payload type: {type(raw)!r}")
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
@@ -162,12 +178,13 @@ def build_service_account_creds(
|
||||
|
||||
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
credential_json = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
@@ -188,12 +205,12 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
|
||||
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
return GoogleAppCredentials(**creds)
|
||||
|
||||
|
||||
def upsert_google_app_cred(
|
||||
@@ -201,10 +218,14 @@ def upsert_google_app_cred(
|
||||
) -> None:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
KV_GOOGLE_DRIVE_CRED_KEY,
|
||||
app_credentials.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_CRED_KEY, app_credentials.model_dump(mode="json"), encrypt=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -220,12 +241,14 @@ def delete_google_app_cred(source: DocumentSource) -> None:
|
||||
|
||||
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
creds = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
return GoogleServiceAccountKey(**creds)
|
||||
|
||||
|
||||
def upsert_service_account_key(
|
||||
@@ -234,12 +257,14 @@ def upsert_service_account_key(
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.json(),
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -123,6 +123,9 @@ class SlimConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -33,9 +35,18 @@ class ConnectorMissingCredentialError(PermissionError):
|
||||
)
|
||||
|
||||
|
||||
class SectionType(str, Enum):
|
||||
"""Discriminator for Section subclasses."""
|
||||
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
TABULAR = "tabular"
|
||||
|
||||
|
||||
class Section(BaseModel):
|
||||
"""Base section class with common attributes"""
|
||||
|
||||
type: SectionType
|
||||
link: str | None = None
|
||||
text: str | None = None
|
||||
image_file_id: str | None = None
|
||||
@@ -44,6 +55,7 @@ class Section(BaseModel):
|
||||
class TextSection(Section):
|
||||
"""Section containing text content"""
|
||||
|
||||
type: Literal[SectionType.TEXT] = SectionType.TEXT
|
||||
text: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
@@ -53,12 +65,25 @@ class TextSection(Section):
|
||||
class ImageSection(Section):
|
||||
"""Section containing an image reference"""
|
||||
|
||||
type: Literal[SectionType.IMAGE] = SectionType.IMAGE
|
||||
image_file_id: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
return sys.getsizeof(self.image_file_id) + sys.getsizeof(self.link)
|
||||
|
||||
|
||||
class TabularSection(Section):
|
||||
"""Section containing tabular data (csv/tsv content, or one sheet of
|
||||
an xlsx workbook rendered as CSV)."""
|
||||
|
||||
type: Literal[SectionType.TABULAR] = SectionType.TABULAR
|
||||
text: str # CSV representation in a string
|
||||
link: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
return sys.getsizeof(self.text) + sys.getsizeof(self.link)
|
||||
|
||||
|
||||
class BasicExpertInfo(BaseModel):
|
||||
"""Basic Information for the owner of a document, any of the fields can be left as None
|
||||
Display fallback goes as follows:
|
||||
@@ -134,7 +159,6 @@ class BasicExpertInfo(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, model_dict: dict[str, Any]) -> "BasicExpertInfo":
|
||||
|
||||
first_name = cast(str, model_dict.get("FirstName"))
|
||||
last_name = cast(str, model_dict.get("LastName"))
|
||||
email = cast(str, model_dict.get("Email"))
|
||||
@@ -161,7 +185,7 @@ class DocumentBase(BaseModel):
|
||||
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
|
||||
|
||||
id: str | None = None
|
||||
sections: list[TextSection | ImageSection]
|
||||
sections: Sequence[TextSection | ImageSection | TabularSection]
|
||||
source: DocumentSource | None = None
|
||||
semantic_identifier: str # displayed in the UI as the main identifier for the doc
|
||||
# TODO(andrei): Ideally we could improve this to where each value is just a
|
||||
@@ -371,12 +395,9 @@ class IndexingDocument(Document):
|
||||
)
|
||||
else:
|
||||
section_len = sum(
|
||||
(
|
||||
len(section.text)
|
||||
if isinstance(section, TextSection) and section.text is not None
|
||||
else 0
|
||||
)
|
||||
len(section.text) if section.text is not None else 0
|
||||
for section in self.sections
|
||||
if isinstance(section, (TextSection, TabularSection))
|
||||
)
|
||||
|
||||
return title_len + section_len
|
||||
|
||||
@@ -335,6 +335,7 @@ def update_document_set(
|
||||
"Cannot update document set while it is syncing. Please wait for it to finish syncing, and then try again."
|
||||
)
|
||||
|
||||
document_set_row.name = document_set_update_request.name
|
||||
document_set_row.description = document_set_update_request.description
|
||||
if not DISABLE_VECTOR_DB:
|
||||
document_set_row.is_up_to_date = False
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
@@ -346,6 +347,25 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _safe_close_session(session: Session) -> None:
|
||||
"""Close a session, catching connection-closed errors during cleanup.
|
||||
|
||||
Long-running operations (e.g. multi-model LLM loops) can hold a session
|
||||
open for minutes. If the underlying connection is dropped by cloud
|
||||
infrastructure (load-balancer timeouts, PgBouncer, idle-in-transaction
|
||||
timeouts, etc.), the implicit rollback in Session.close() raises
|
||||
OperationalError or InterfaceError. Since the work is already complete,
|
||||
we log and move on — SQLAlchemy internally invalidates the connection
|
||||
for pool recycling.
|
||||
"""
|
||||
try:
|
||||
session.close()
|
||||
except DBAPIError:
|
||||
logger.warning(
|
||||
"DB connection lost during session cleanup — the connection will be invalidated and recycled by the pool."
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
"""
|
||||
@@ -358,8 +378,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
|
||||
# no need to use the schema translation map for self-hosted + default schema
|
||||
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
|
||||
with Session(bind=engine, expire_on_commit=False) as session:
|
||||
session = Session(bind=engine, expire_on_commit=False)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
return
|
||||
|
||||
# Create connection with schema translation to handle querying the right schema
|
||||
@@ -367,8 +390,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
with engine.connect().execution_options(
|
||||
schema_translate_map=schema_translate_map
|
||||
) as connection:
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
session = Session(bind=connection, expire_on_commit=False)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
@@ -899,6 +899,7 @@ def create_index_attempt_error(
|
||||
failure: ConnectorFailure,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
exc = failure.exception
|
||||
new_error = IndexAttemptError(
|
||||
index_attempt_id=index_attempt_id,
|
||||
connector_credential_pair_id=connector_credential_pair_id,
|
||||
@@ -921,6 +922,7 @@ def create_index_attempt_error(
|
||||
),
|
||||
failure_message=failure.failure_message,
|
||||
is_resolved=False,
|
||||
error_type=type(exc).__name__ if exc else None,
|
||||
)
|
||||
db_session.add(new_error)
|
||||
db_session.commit()
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -83,47 +84,51 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
def add_memory(
|
||||
user_id: UUID,
|
||||
memory_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory:
|
||||
db_session: Session | None = None,
|
||||
) -> int:
|
||||
"""Insert a new Memory row for the given user.
|
||||
|
||||
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
|
||||
one (lowest id) is deleted before inserting the new one.
|
||||
|
||||
Returns the id of the newly created Memory row.
|
||||
"""
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
|
||||
|
||||
def update_memory_at_index(
|
||||
user_id: UUID,
|
||||
index: int,
|
||||
new_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory | None:
|
||||
db_session: Session | None = None,
|
||||
) -> int | None:
|
||||
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
|
||||
|
||||
Returns the updated Memory row, or None if the index is out of range.
|
||||
Returns the id of the updated Memory row, or None if the index is out of range.
|
||||
"""
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
|
||||
@@ -2422,6 +2422,8 @@ class IndexAttemptError(Base):
|
||||
failure_message: Mapped[str] = mapped_column(Text)
|
||||
is_resolved: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
error_type: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
|
||||
@@ -7,8 +7,6 @@ import time
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
@@ -22,6 +20,7 @@ from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
|
||||
@@ -184,6 +183,14 @@ def generate_final_report(
|
||||
return has_reasoned
|
||||
|
||||
|
||||
def _get_research_agent_tool_id() -> int:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def run_deep_research_llm_loop(
|
||||
emitter: Emitter,
|
||||
@@ -193,7 +200,6 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt: str | None, # noqa: ARG001
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
skip_clarification: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -717,6 +723,7 @@ def run_deep_research_llm_loop(
|
||||
simple_chat_history.append(assistant_with_tools)
|
||||
|
||||
# Now add TOOL_CALL_RESPONSE messages and tool call info for each result
|
||||
research_agent_tool_id = _get_research_agent_tool_id()
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
@@ -737,10 +744,7 @@ def run_deep_research_llm_loop(
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id,
|
||||
tool_id=research_agent_tool_id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
|
||||
@@ -463,29 +463,13 @@ def _remove_empty_runs(
|
||||
return result
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
# TODO: switch back to this approach in a few months when markitdown
|
||||
# fixes their handling of excel files
|
||||
def xlsx_sheet_extraction(file: IO[Any], file_name: str = "") -> list[tuple[str, str]]:
|
||||
"""
|
||||
Converts each sheet in the excel file to a csv condensed string.
|
||||
Returns a string and the worksheet title for each worksheet
|
||||
|
||||
# md = get_markitdown_converter()
|
||||
# stream_info = StreamInfo(
|
||||
# mimetype=SPREADSHEET_MIME_TYPE, filename=file_name or None, extension=".xlsx"
|
||||
# )
|
||||
# try:
|
||||
# workbook = md.convert(to_bytesio(file), stream_info=stream_info)
|
||||
# except (
|
||||
# BadZipFile,
|
||||
# ValueError,
|
||||
# FileConversionException,
|
||||
# UnsupportedFormatException,
|
||||
# ) as e:
|
||||
# error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
|
||||
# if file_name.startswith("~"):
|
||||
# logger.debug(error_str + " (this is expected for files with ~)")
|
||||
# else:
|
||||
# logger.warning(error_str)
|
||||
# return ""
|
||||
# return workbook.markdown
|
||||
Returns a list of (csv_text, sheet)
|
||||
"""
|
||||
try:
|
||||
workbook = openpyxl.load_workbook(file, read_only=True)
|
||||
except BadZipFile as e:
|
||||
@@ -494,23 +478,30 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
logger.debug(error_str + " (this is expected for files with ~)")
|
||||
else:
|
||||
logger.warning(error_str)
|
||||
return ""
|
||||
return []
|
||||
except Exception as e:
|
||||
if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS):
|
||||
logger.error(
|
||||
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
|
||||
)
|
||||
return ""
|
||||
return []
|
||||
raise
|
||||
|
||||
text_content = []
|
||||
sheets: list[tuple[str, str]] = []
|
||||
for sheet in workbook.worksheets:
|
||||
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf, lineterminator="\n")
|
||||
writer.writerows(sheet_matrix)
|
||||
text_content.append(buf.getvalue().rstrip("\n"))
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
csv_text = buf.getvalue().rstrip("\n")
|
||||
if csv_text.strip():
|
||||
sheets.append((csv_text, sheet.title))
|
||||
return sheets
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
sheets = xlsx_sheet_extraction(file, file_name)
|
||||
return TEXT_SECTION_SEPARATOR.join(csv_text for csv_text, _title in sheets)
|
||||
|
||||
|
||||
def eml_to_text(file: IO[Any]) -> str:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import cast
|
||||
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS
|
||||
@@ -16,16 +14,14 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_metadata_keys_to_ignore,
|
||||
)
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking import DocumentChunker
|
||||
from onyx.indexing.chunking import extract_blurb
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.llm.utils import MAX_CONTEXT_TOKENS
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import clean_text
|
||||
from onyx.utils.text_processing import shared_precompare_cleanup
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
|
||||
|
||||
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
|
||||
# actually help quality at all
|
||||
@@ -154,9 +150,6 @@ class Chunker:
|
||||
self.tokenizer = tokenizer
|
||||
self.callback = callback
|
||||
|
||||
self.max_context = 0
|
||||
self.prompt_tokens = 0
|
||||
|
||||
# Create a token counter function that returns the count instead of the tokens
|
||||
def token_counter(text: str) -> int:
|
||||
return len(tokenizer.encode(text))
|
||||
@@ -186,234 +179,12 @@ class Chunker:
|
||||
else None
|
||||
)
|
||||
|
||||
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
|
||||
"""
|
||||
Splits the text into smaller chunks based on token count to ensure
|
||||
no chunk exceeds the content_token_limit.
|
||||
"""
|
||||
tokens = self.tokenizer.tokenize(text)
|
||||
chunks = []
|
||||
start = 0
|
||||
total_tokens = len(tokens)
|
||||
while start < total_tokens:
|
||||
end = min(start + content_token_limit, total_tokens)
|
||||
token_chunk = tokens[start:end]
|
||||
chunk_text = " ".join(token_chunk)
|
||||
chunks.append(chunk_text)
|
||||
start = end
|
||||
return chunks
|
||||
|
||||
def _extract_blurb(self, text: str) -> str:
|
||||
"""
|
||||
Extract a short blurb from the text (first chunk of size `blurb_size`).
|
||||
"""
|
||||
# chunker is in `text` mode
|
||||
texts = cast(list[str], self.blurb_splitter.chunk(text))
|
||||
if not texts:
|
||||
return ""
|
||||
return texts[0]
|
||||
|
||||
def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None:
|
||||
"""
|
||||
For "multipass" mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
|
||||
"""
|
||||
if self.mini_chunk_splitter and chunk_text.strip():
|
||||
# chunker is in `text` mode
|
||||
return cast(list[str], self.mini_chunk_splitter.chunk(chunk_text))
|
||||
return None
|
||||
|
||||
# ADDED: extra param image_url to store in the chunk
|
||||
def _create_chunk(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
chunks_list: list[DocAwareChunk],
|
||||
text: str,
|
||||
links: dict[int, str],
|
||||
is_continuation: bool = False,
|
||||
title_prefix: str = "",
|
||||
metadata_suffix_semantic: str = "",
|
||||
metadata_suffix_keyword: str = "",
|
||||
image_file_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper to create a new DocAwareChunk, append it to chunks_list.
|
||||
"""
|
||||
new_chunk = DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks_list),
|
||||
blurb=self._extract_blurb(text),
|
||||
content=text,
|
||||
source_links=links or {0: ""},
|
||||
image_file_id=image_file_id,
|
||||
section_continuation=is_continuation,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=self._get_mini_chunk_texts(text),
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0, # set per-document in _handle_single_document
|
||||
self._document_chunker = DocumentChunker(
|
||||
tokenizer=tokenizer,
|
||||
blurb_splitter=self.blurb_splitter,
|
||||
chunk_splitter=self.chunk_splitter,
|
||||
mini_chunk_splitter=self.mini_chunk_splitter,
|
||||
)
|
||||
chunks_list.append(new_chunk)
|
||||
|
||||
def _chunk_document_with_sections(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
sections: list[Section],
|
||||
title_prefix: str,
|
||||
metadata_suffix_semantic: str,
|
||||
metadata_suffix_keyword: str,
|
||||
content_token_limit: int,
|
||||
) -> list[DocAwareChunk]:
|
||||
"""
|
||||
Loops through sections of the document, converting them into one or more chunks.
|
||||
Works with processed sections that are base Section objects.
|
||||
"""
|
||||
chunks: list[DocAwareChunk] = []
|
||||
link_offsets: dict[int, str] = {}
|
||||
chunk_text = ""
|
||||
|
||||
for section_idx, section in enumerate(sections):
|
||||
# Get section text and other attributes
|
||||
section_text = clean_text(str(section.text or ""))
|
||||
section_link_text = section.link or ""
|
||||
image_url = section.image_file_id
|
||||
|
||||
# If there is no useful content, skip
|
||||
if not section_text and (not document.title or section_idx > 0):
|
||||
logger.warning(
|
||||
f"Skipping empty or irrelevant section in doc {document.semantic_identifier}, link={section_link_text}"
|
||||
)
|
||||
continue
|
||||
|
||||
# CASE 1: If this section has an image, force a separate chunk
|
||||
if image_url:
|
||||
# First, if we have any partially built text chunk, finalize it
|
||||
if chunk_text.strip():
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets,
|
||||
is_continuation=False,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
chunk_text = ""
|
||||
link_offsets = {}
|
||||
|
||||
# Create a chunk specifically for this image section
|
||||
# (Using the text summary that was generated during processing)
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
section_text,
|
||||
links={0: section_link_text} if section_link_text else {},
|
||||
image_file_id=image_url,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
# Continue to next section
|
||||
continue
|
||||
|
||||
# CASE 2: Normal text section
|
||||
section_token_count = len(self.tokenizer.encode(section_text))
|
||||
|
||||
# If the section is large on its own, split it separately
|
||||
if section_token_count > content_token_limit:
|
||||
if chunk_text.strip():
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets,
|
||||
False,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
)
|
||||
chunk_text = ""
|
||||
link_offsets = {}
|
||||
|
||||
# chunker is in `text` mode
|
||||
split_texts = cast(list[str], self.chunk_splitter.chunk(section_text))
|
||||
for i, split_text in enumerate(split_texts):
|
||||
# If even the split_text is bigger than strict limit, further split
|
||||
if (
|
||||
STRICT_CHUNK_TOKEN_LIMIT
|
||||
and len(self.tokenizer.encode(split_text)) > content_token_limit
|
||||
):
|
||||
smaller_chunks = self._split_oversized_chunk(
|
||||
split_text, content_token_limit
|
||||
)
|
||||
for j, small_chunk in enumerate(smaller_chunks):
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
small_chunk,
|
||||
{0: section_link_text},
|
||||
is_continuation=(j != 0),
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
else:
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
split_text,
|
||||
{0: section_link_text},
|
||||
is_continuation=(i != 0),
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
continue
|
||||
|
||||
# If we can still fit this section into the current chunk, do so
|
||||
current_token_count = len(self.tokenizer.encode(chunk_text))
|
||||
current_offset = len(shared_precompare_cleanup(chunk_text))
|
||||
next_section_tokens = (
|
||||
len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count
|
||||
)
|
||||
|
||||
if next_section_tokens + current_token_count <= content_token_limit:
|
||||
if chunk_text:
|
||||
chunk_text += SECTION_SEPARATOR
|
||||
chunk_text += section_text
|
||||
link_offsets[current_offset] = section_link_text
|
||||
else:
|
||||
# finalize the existing chunk
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets,
|
||||
False,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
)
|
||||
# start a new chunk
|
||||
link_offsets = {0: section_link_text}
|
||||
chunk_text = section_text
|
||||
|
||||
# finalize any leftover text chunk
|
||||
if chunk_text.strip() or not chunks:
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets or {0: ""}, # safe default
|
||||
False,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
)
|
||||
return chunks
|
||||
|
||||
def _handle_single_document(
|
||||
self, document: IndexingDocument
|
||||
@@ -423,7 +194,10 @@ class Chunker:
|
||||
logger.debug(f"Chunking {document.semantic_identifier}")
|
||||
|
||||
# Title prep
|
||||
title = self._extract_blurb(document.get_title_for_document_index() or "")
|
||||
title = extract_blurb(
|
||||
document.get_title_for_document_index() or "",
|
||||
self.blurb_splitter,
|
||||
)
|
||||
title_prefix = title + RETURN_SEPARATOR if title else ""
|
||||
title_tokens = len(self.tokenizer.encode(title_prefix))
|
||||
|
||||
@@ -491,7 +265,7 @@ class Chunker:
|
||||
# Use processed_sections if available (IndexingDocument), otherwise use original sections
|
||||
sections_to_chunk = document.processed_sections
|
||||
|
||||
normal_chunks = self._chunk_document_with_sections(
|
||||
normal_chunks = self._document_chunker.chunk(
|
||||
document,
|
||||
sections_to_chunk,
|
||||
title_prefix,
|
||||
|
||||
7
backend/onyx/indexing/chunking/__init__.py
Normal file
7
backend/onyx/indexing/chunking/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from onyx.indexing.chunking.document_chunker import DocumentChunker
|
||||
from onyx.indexing.chunking.section_chunker import extract_blurb
|
||||
|
||||
__all__ = [
|
||||
"DocumentChunker",
|
||||
"extract_blurb",
|
||||
]
|
||||
111
backend/onyx/indexing/chunking/document_chunker.py
Normal file
111
backend/onyx/indexing/chunking/document_chunker.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SectionType
|
||||
from onyx.indexing.chunking.image_section_chunker import ImageChunker
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.text_section_chunker import TextChunker
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import clean_text
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DocumentChunker:
|
||||
"""Converts a document's processed sections into DocAwareChunks.
|
||||
|
||||
Drop-in replacement for `Chunker._chunk_document_with_sections`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: BaseTokenizer,
|
||||
blurb_splitter: SentenceChunker,
|
||||
chunk_splitter: SentenceChunker,
|
||||
mini_chunk_splitter: SentenceChunker | None = None,
|
||||
) -> None:
|
||||
self.blurb_splitter = blurb_splitter
|
||||
self.mini_chunk_splitter = mini_chunk_splitter
|
||||
|
||||
self._dispatch: dict[SectionType, SectionChunker] = {
|
||||
SectionType.TEXT: TextChunker(
|
||||
tokenizer=tokenizer,
|
||||
chunk_splitter=chunk_splitter,
|
||||
),
|
||||
SectionType.IMAGE: ImageChunker(),
|
||||
}
|
||||
|
||||
def chunk(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
sections: list[Section],
|
||||
title_prefix: str,
|
||||
metadata_suffix_semantic: str,
|
||||
metadata_suffix_keyword: str,
|
||||
content_token_limit: int,
|
||||
) -> list[DocAwareChunk]:
|
||||
payloads = self._collect_section_payloads(
|
||||
document=document,
|
||||
sections=sections,
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
if not payloads:
|
||||
payloads.append(ChunkPayload(text="", links={0: ""}))
|
||||
|
||||
return [
|
||||
payload.to_doc_aware_chunk(
|
||||
document=document,
|
||||
chunk_id=idx,
|
||||
blurb_splitter=self.blurb_splitter,
|
||||
mini_chunk_splitter=self.mini_chunk_splitter,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
for idx, payload in enumerate(payloads)
|
||||
]
|
||||
|
||||
def _collect_section_payloads(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
sections: list[Section],
|
||||
content_token_limit: int,
|
||||
) -> list[ChunkPayload]:
|
||||
accumulator = AccumulatorState()
|
||||
payloads: list[ChunkPayload] = []
|
||||
|
||||
for section_idx, section in enumerate(sections):
|
||||
section_text = clean_text(str(section.text or ""))
|
||||
|
||||
if not section_text and (not document.title or section_idx > 0):
|
||||
logger.warning(
|
||||
f"Skipping empty or irrelevant section in doc "
|
||||
f"{document.semantic_identifier}, link={section.link}"
|
||||
)
|
||||
continue
|
||||
|
||||
chunker = self._select_chunker(section)
|
||||
result = chunker.chunk_section(
|
||||
section=section,
|
||||
accumulator=accumulator,
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
payloads.extend(result.payloads)
|
||||
accumulator = result.accumulator
|
||||
|
||||
# Final flush — any leftover buffered text becomes one last payload.
|
||||
payloads.extend(accumulator.flush_to_list())
|
||||
|
||||
return payloads
|
||||
|
||||
def _select_chunker(self, section: Section) -> SectionChunker:
|
||||
try:
|
||||
return self._dispatch[section.type]
|
||||
except KeyError:
|
||||
raise ValueError(f"No SectionChunker registered for type={section.type}")
|
||||
35
backend/onyx/indexing/chunking/image_section_chunker.py
Normal file
35
backend/onyx/indexing/chunking/image_section_chunker.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
|
||||
from onyx.utils.text_processing import clean_text
|
||||
|
||||
|
||||
class ImageChunker(SectionChunker):
|
||||
def chunk_section(
|
||||
self,
|
||||
section: Section,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int, # noqa: ARG002
|
||||
) -> SectionChunkerOutput:
|
||||
assert section.image_file_id is not None
|
||||
|
||||
section_text = clean_text(str(section.text or ""))
|
||||
section_link = section.link or ""
|
||||
|
||||
# Flush any partially built text chunks
|
||||
payloads = accumulator.flush_to_list()
|
||||
payloads.append(
|
||||
ChunkPayload(
|
||||
text=section_text,
|
||||
links={0: section_link} if section_link else {},
|
||||
image_file_id=section.image_file_id,
|
||||
is_continuation=False,
|
||||
)
|
||||
)
|
||||
|
||||
return SectionChunkerOutput(
|
||||
payloads=payloads,
|
||||
accumulator=AccumulatorState(),
|
||||
)
|
||||
100
backend/onyx/indexing/chunking/section_chunker.py
Normal file
100
backend/onyx/indexing/chunking/section_chunker.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from chonkie import SentenceChunker
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
|
||||
|
||||
def extract_blurb(text: str, blurb_splitter: SentenceChunker) -> str:
|
||||
texts = cast(list[str], blurb_splitter.chunk(text))
|
||||
if not texts:
|
||||
return ""
|
||||
return texts[0]
|
||||
|
||||
|
||||
def get_mini_chunk_texts(
|
||||
chunk_text: str,
|
||||
mini_chunk_splitter: SentenceChunker | None,
|
||||
) -> list[str] | None:
|
||||
if mini_chunk_splitter and chunk_text.strip():
|
||||
return list(cast(Sequence[str], mini_chunk_splitter.chunk(chunk_text)))
|
||||
return None
|
||||
|
||||
|
||||
class ChunkPayload(BaseModel):
|
||||
"""Section-local chunk content without document-scoped fields.
|
||||
|
||||
The orchestrator upgrades these to DocAwareChunks via
|
||||
`to_doc_aware_chunk` after assigning chunk_ids and attaching
|
||||
title/metadata.
|
||||
"""
|
||||
|
||||
text: str
|
||||
links: dict[int, str]
|
||||
is_continuation: bool = False
|
||||
image_file_id: str | None = None
|
||||
|
||||
def to_doc_aware_chunk(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
chunk_id: int,
|
||||
blurb_splitter: SentenceChunker,
|
||||
title_prefix: str = "",
|
||||
metadata_suffix_semantic: str = "",
|
||||
metadata_suffix_keyword: str = "",
|
||||
mini_chunk_splitter: SentenceChunker | None = None,
|
||||
) -> DocAwareChunk:
|
||||
return DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=chunk_id,
|
||||
blurb=extract_blurb(self.text, blurb_splitter),
|
||||
content=self.text,
|
||||
source_links=self.links or {0: ""},
|
||||
image_file_id=self.image_file_id,
|
||||
section_continuation=self.is_continuation,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=get_mini_chunk_texts(self.text, mini_chunk_splitter),
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
class AccumulatorState(BaseModel):
|
||||
"""Cross-section text buffer threaded through SectionChunkers."""
|
||||
|
||||
text: str = ""
|
||||
link_offsets: dict[int, str] = Field(default_factory=dict)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not self.text.strip()
|
||||
|
||||
def flush_to_list(self) -> list[ChunkPayload]:
|
||||
if self.is_empty():
|
||||
return []
|
||||
return [ChunkPayload(text=self.text, links=self.link_offsets)]
|
||||
|
||||
|
||||
class SectionChunkerOutput(BaseModel):
|
||||
payloads: list[ChunkPayload]
|
||||
accumulator: AccumulatorState
|
||||
|
||||
|
||||
class SectionChunker(ABC):
|
||||
@abstractmethod
|
||||
def chunk_section(
|
||||
self,
|
||||
section: Section,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int,
|
||||
) -> SectionChunkerOutput: ...
|
||||
129
backend/onyx/indexing/chunking/text_section_chunker.py
Normal file
129
backend/onyx/indexing/chunking/text_section_chunker.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from typing import cast
|
||||
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.configs.constants import SECTION_SEPARATOR
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.utils.text_processing import clean_text
|
||||
from onyx.utils.text_processing import shared_precompare_cleanup
|
||||
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
|
||||
|
||||
|
||||
class TextChunker(SectionChunker):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: BaseTokenizer,
|
||||
chunk_splitter: SentenceChunker,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
self.chunk_splitter = chunk_splitter
|
||||
|
||||
self.section_separator_token_count = count_tokens(
|
||||
SECTION_SEPARATOR,
|
||||
self.tokenizer,
|
||||
)
|
||||
|
||||
def chunk_section(
|
||||
self,
|
||||
section: Section,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int,
|
||||
) -> SectionChunkerOutput:
|
||||
section_text = clean_text(str(section.text or ""))
|
||||
section_link = section.link or ""
|
||||
section_token_count = len(self.tokenizer.encode(section_text))
|
||||
|
||||
# Oversized — flush buffer and split the section
|
||||
if section_token_count > content_token_limit:
|
||||
return self._handle_oversized_section(
|
||||
section_text=section_text,
|
||||
section_link=section_link,
|
||||
accumulator=accumulator,
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
current_token_count = count_tokens(accumulator.text, self.tokenizer)
|
||||
next_section_tokens = self.section_separator_token_count + section_token_count
|
||||
|
||||
# Fits — extend the accumulator
|
||||
if next_section_tokens + current_token_count <= content_token_limit:
|
||||
offset = len(shared_precompare_cleanup(accumulator.text))
|
||||
new_text = accumulator.text
|
||||
if new_text:
|
||||
new_text += SECTION_SEPARATOR
|
||||
new_text += section_text
|
||||
return SectionChunkerOutput(
|
||||
payloads=[],
|
||||
accumulator=AccumulatorState(
|
||||
text=new_text,
|
||||
link_offsets={**accumulator.link_offsets, offset: section_link},
|
||||
),
|
||||
)
|
||||
|
||||
# Doesn't fit — flush buffer and restart with this section
|
||||
return SectionChunkerOutput(
|
||||
payloads=accumulator.flush_to_list(),
|
||||
accumulator=AccumulatorState(
|
||||
text=section_text,
|
||||
link_offsets={0: section_link},
|
||||
),
|
||||
)
|
||||
|
||||
def _handle_oversized_section(
|
||||
self,
|
||||
section_text: str,
|
||||
section_link: str,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int,
|
||||
) -> SectionChunkerOutput:
|
||||
payloads = accumulator.flush_to_list()
|
||||
|
||||
split_texts = cast(list[str], self.chunk_splitter.chunk(section_text))
|
||||
for i, split_text in enumerate(split_texts):
|
||||
if (
|
||||
STRICT_CHUNK_TOKEN_LIMIT
|
||||
and count_tokens(split_text, self.tokenizer) > content_token_limit
|
||||
):
|
||||
smaller_chunks = self._split_oversized_chunk(
|
||||
split_text, content_token_limit
|
||||
)
|
||||
for j, small_chunk in enumerate(smaller_chunks):
|
||||
payloads.append(
|
||||
ChunkPayload(
|
||||
text=small_chunk,
|
||||
links={0: section_link},
|
||||
is_continuation=(j != 0),
|
||||
)
|
||||
)
|
||||
else:
|
||||
payloads.append(
|
||||
ChunkPayload(
|
||||
text=split_text,
|
||||
links={0: section_link},
|
||||
is_continuation=(i != 0),
|
||||
)
|
||||
)
|
||||
|
||||
return SectionChunkerOutput(
|
||||
payloads=payloads,
|
||||
accumulator=AccumulatorState(),
|
||||
)
|
||||
|
||||
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
|
||||
tokens = self.tokenizer.tokenize(text)
|
||||
chunks: list[str] = []
|
||||
start = 0
|
||||
total_tokens = len(tokens)
|
||||
while start < total_tokens:
|
||||
end = min(start + content_token_limit, total_tokens)
|
||||
token_chunk = tokens[start:end]
|
||||
chunk_text = " ".join(token_chunk)
|
||||
chunks.append(chunk_text)
|
||||
start = end
|
||||
return chunks
|
||||
@@ -3,6 +3,8 @@ from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
import sentry_sdk
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
@@ -291,6 +293,13 @@ def embed_chunks_with_failure_handling(
|
||||
)
|
||||
embedded_chunks.extend(doc_embedded_chunks)
|
||||
except Exception as e:
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "embedding")
|
||||
scope.set_tag("doc_id", doc_id)
|
||||
if tenant_id:
|
||||
scope.set_tag("tenant_id", tenant_id)
|
||||
scope.fingerprint = ["embedding-failure", type(e).__name__]
|
||||
sentry_sdk.capture_exception(e)
|
||||
logger.exception(f"Failed to embed chunks for document '{doc_id}'")
|
||||
failures.append(
|
||||
ConnectorFailure(
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Protocol
|
||||
|
||||
import sentry_sdk
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -332,6 +333,13 @@ def index_doc_batch_with_handler(
|
||||
except Exception as e:
|
||||
# don't log the batch directly, it's too much text
|
||||
document_ids = [doc.id for doc in document_batch]
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "indexing_pipeline")
|
||||
scope.set_tag("tenant_id", tenant_id)
|
||||
scope.set_tag("batch_size", str(len(document_batch)))
|
||||
scope.set_extra("document_ids", document_ids)
|
||||
scope.fingerprint = ["indexing-pipeline-failure", type(e).__name__]
|
||||
sentry_sdk.capture_exception(e)
|
||||
logger.exception(f"Failed to index document batch: {document_ids}")
|
||||
|
||||
index_pipeline_result = IndexingPipelineResult(
|
||||
@@ -542,6 +550,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
**document.model_dump(),
|
||||
processed_sections=[
|
||||
Section(
|
||||
type=section.type,
|
||||
text=section.text if isinstance(section, TextSection) else "",
|
||||
link=section.link,
|
||||
image_file_id=(
|
||||
@@ -566,6 +575,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
if isinstance(section, ImageSection):
|
||||
# Default section with image path preserved - ensure text is always a string
|
||||
processed_section = Section(
|
||||
type=section.type,
|
||||
link=section.link,
|
||||
image_file_id=section.image_file_id,
|
||||
text="", # Initialize with empty string
|
||||
@@ -609,6 +619,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
# For TextSection, create a base Section with text and link
|
||||
elif isinstance(section, TextSection):
|
||||
processed_section = Section(
|
||||
type=section.type,
|
||||
text=section.text or "", # Ensure text is always a string, not None
|
||||
link=section.link,
|
||||
image_file_id=None,
|
||||
|
||||
@@ -6,6 +6,7 @@ from itertools import chain
|
||||
from itertools import groupby
|
||||
|
||||
import httpx
|
||||
import sentry_sdk
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
@@ -88,6 +89,12 @@ def write_chunks_to_vector_db_with_backoff(
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
with sentry_sdk.new_scope() as scope:
|
||||
scope.set_tag("stage", "vector_db_write")
|
||||
scope.set_tag("doc_id", doc_id)
|
||||
scope.set_tag("tenant_id", index_batch_params.tenant_id)
|
||||
scope.fingerprint = ["vector-db-write-failure", type(e).__name__]
|
||||
sentry_sdk.capture_exception(e)
|
||||
logger.exception(
|
||||
f"Failed to write document chunks for '{doc_id}' to vector db"
|
||||
)
|
||||
|
||||
@@ -434,11 +434,14 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
lifespan=lifespan_override or lifespan,
|
||||
)
|
||||
if SENTRY_DSN:
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
before_send=_add_instance_tags,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"date-fns": "^4.1.0",
|
||||
"embla-carousel-react": "^8.6.0",
|
||||
"lucide-react": "^0.562.0",
|
||||
"next": "16.1.7",
|
||||
"next": "16.2.3",
|
||||
"next-themes": "^0.4.6",
|
||||
"radix-ui": "^1.4.3",
|
||||
"react": "19.2.3",
|
||||
@@ -961,9 +961,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@hono/node-server": {
|
||||
"version": "1.19.10",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.10.tgz",
|
||||
"integrity": "sha512-hZ7nOssGqRgyV3FVVQdfi+U4q02uB23bpnYpdvNXkYTRRyWx84b7yf1ans+dnJ/7h41sGL3CeQTfO+ZGxuO+Iw==",
|
||||
"version": "1.19.13",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.13.tgz",
|
||||
"integrity": "sha512-TsQLe4i2gvoTtrHje625ngThGBySOgSK3Xo2XRYOdqGN1teR8+I7vchQC46uLJi8OF62YTYA3AhSpumtkhsaKQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=18.14.1"
|
||||
@@ -1711,9 +1711,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/env": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.7.tgz",
|
||||
"integrity": "sha512-rJJbIdJB/RQr2F1nylZr/PJzamvNNhfr3brdKP6s/GW850jbtR70QlSfFselvIBbcPUOlQwBakexjFzqLzF6pg==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.2.3.tgz",
|
||||
"integrity": "sha512-ZWXyj4uNu4GCWQw9cjRxWlbD+33mcDszIo9iQxFnBX3Wmgq9ulaSJcl6VhuWx5pCWqqD+9W6Wfz7N0lM5lYPMA==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@next/eslint-plugin-next": {
|
||||
@@ -1727,9 +1727,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-arm64": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.7.tgz",
|
||||
"integrity": "sha512-b2wWIE8sABdyafc4IM8r5Y/dS6kD80JRtOGrUiKTsACFQfWWgUQ2NwoUX1yjFMXVsAwcQeNpnucF2ZrujsBBPg==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.2.3.tgz",
|
||||
"integrity": "sha512-u37KDKTKQ+OQLvY+z7SNXixwo4Q2/IAJFDzU1fYe66IbCE51aDSAzkNDkWmLN0yjTUh4BKBd+hb69jYn6qqqSg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1743,9 +1743,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-x64": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.7.tgz",
|
||||
"integrity": "sha512-zcnVaaZulS1WL0Ss38R5Q6D2gz7MtBu8GZLPfK+73D/hp4GFMrC2sudLky1QibfV7h6RJBJs/gOFvYP0X7UVlQ==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.2.3.tgz",
|
||||
"integrity": "sha512-gHjL/qy6Q6CG3176FWbAKyKh9IfntKZTB3RY/YOJdDFpHGsUDXVH38U4mMNpHVGXmeYW4wj22dMp1lTfmu/bTQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1759,9 +1759,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-gnu": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.7.tgz",
|
||||
"integrity": "sha512-2ant89Lux/Q3VyC8vNVg7uBaFVP9SwoK2jJOOR0L8TQnX8CAYnh4uctAScy2Hwj2dgjVHqHLORQZJ2wH6VxhSQ==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.2.3.tgz",
|
||||
"integrity": "sha512-U6vtblPtU/P14Y/b/n9ZY0GOxbbIhTFuaFR7F4/uMBidCi2nSdaOFhA0Go81L61Zd6527+yvuX44T4ksnf8T+Q==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1775,9 +1775,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-musl": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.7.tgz",
|
||||
"integrity": "sha512-uufcze7LYv0FQg9GnNeZ3/whYfo+1Q3HnQpm16o6Uyi0OVzLlk2ZWoY7j07KADZFY8qwDbsmFnMQP3p3+Ftprw==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.2.3.tgz",
|
||||
"integrity": "sha512-/YV0LgjHUmfhQpn9bVoGc4x4nan64pkhWR5wyEV8yCOfwwrH630KpvRg86olQHTwHIn1z59uh6JwKvHq1h4QEw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1791,9 +1791,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-gnu": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.7.tgz",
|
||||
"integrity": "sha512-KWVf2gxYvHtvuT+c4MBOGxuse5TD7DsMFYSxVxRBnOzok/xryNeQSjXgxSv9QpIVlaGzEn/pIuI6Koosx8CGWA==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.2.3.tgz",
|
||||
"integrity": "sha512-/HiWEcp+WMZ7VajuiMEFGZ6cg0+aYZPqCJD3YJEfpVWQsKYSjXQG06vJP6F1rdA03COD9Fef4aODs3YxKx+RDQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1807,9 +1807,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-musl": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.7.tgz",
|
||||
"integrity": "sha512-HguhaGwsGr1YAGs68uRKc4aGWxLET+NevJskOcCAwXbwj0fYX0RgZW2gsOCzr9S11CSQPIkxmoSbuVaBp4Z3dA==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.2.3.tgz",
|
||||
"integrity": "sha512-Kt44hGJfZSefebhk/7nIdivoDr3Ugp5+oNz9VvF3GUtfxutucUIHfIO0ZYO8QlOPDQloUVQn4NVC/9JvHRk9hw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1823,9 +1823,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-arm64-msvc": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.7.tgz",
|
||||
"integrity": "sha512-S0n3KrDJokKTeFyM/vGGGR8+pCmXYrjNTk2ZozOL1C/JFdfUIL9O1ATaJOl5r2POe56iRChbsszrjMAdWSv7kQ==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.2.3.tgz",
|
||||
"integrity": "sha512-O2NZ9ie3Tq6xj5Z5CSwBT3+aWAMW2PIZ4egUi9MaWLkwaehgtB7YZjPm+UpcNpKOme0IQuqDcor7BsW6QBiQBw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1839,9 +1839,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-x64-msvc": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.7.tgz",
|
||||
"integrity": "sha512-mwgtg8CNZGYm06LeEd+bNnOUfwOyNem/rOiP14Lsz+AnUY92Zq/LXwtebtUiaeVkhbroRCQ0c8GlR4UT1U+0yg==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.2.3.tgz",
|
||||
"integrity": "sha512-Ibm29/GgB/ab5n7XKqlStkm54qqZE8v2FnijUPBgrd67FWrac45o/RsNlaOWjme/B5UqeWt/8KM4aWBwA1D2Kw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -7427,9 +7427,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.12.7",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz",
|
||||
"integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==",
|
||||
"version": "4.12.12",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.12.tgz",
|
||||
"integrity": "sha512-p1JfQMKaceuCbpJKAPKVqyqviZdS0eUxH9v82oWo1kb9xjQ5wA6iP3FNVAPDFlz5/p7d45lO+BpSk1tuSZMF4Q==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
@@ -8637,9 +8637,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/lodash": {
|
||||
"version": "4.17.23",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz",
|
||||
"integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==",
|
||||
"version": "4.18.1",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.18.1.tgz",
|
||||
"integrity": "sha512-dMInicTPVE8d1e5otfwmmjlxkZoUpiVLwyeTdUsi/Caj/gfzzblBcCE5sRHV/AsjuCmxWrte2TNGSYuCeCq+0Q==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/lodash.merge": {
|
||||
@@ -8978,12 +8978,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/next": {
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-16.1.7.tgz",
|
||||
"integrity": "sha512-WM0L7WrSvKwoLegLYr6V+mz+RIofqQgVAfHhMp9a88ms0cFX8iX9ew+snpWlSBwpkURJOUdvCEt3uLl3NNzvWg==",
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-16.2.3.tgz",
|
||||
"integrity": "sha512-9V3zV4oZFza3PVev5/poB9g0dEafVcgNyQ8eTRop8GvxZjV2G15FC5ARuG1eFD42QgeYkzJBJzHghNP8Ad9xtA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@next/env": "16.1.7",
|
||||
"@next/env": "16.2.3",
|
||||
"@swc/helpers": "0.5.15",
|
||||
"baseline-browser-mapping": "^2.9.19",
|
||||
"caniuse-lite": "^1.0.30001579",
|
||||
@@ -8997,15 +8997,15 @@
|
||||
"node": ">=20.9.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@next/swc-darwin-arm64": "16.1.7",
|
||||
"@next/swc-darwin-x64": "16.1.7",
|
||||
"@next/swc-linux-arm64-gnu": "16.1.7",
|
||||
"@next/swc-linux-arm64-musl": "16.1.7",
|
||||
"@next/swc-linux-x64-gnu": "16.1.7",
|
||||
"@next/swc-linux-x64-musl": "16.1.7",
|
||||
"@next/swc-win32-arm64-msvc": "16.1.7",
|
||||
"@next/swc-win32-x64-msvc": "16.1.7",
|
||||
"sharp": "^0.34.4"
|
||||
"@next/swc-darwin-arm64": "16.2.3",
|
||||
"@next/swc-darwin-x64": "16.2.3",
|
||||
"@next/swc-linux-arm64-gnu": "16.2.3",
|
||||
"@next/swc-linux-arm64-musl": "16.2.3",
|
||||
"@next/swc-linux-x64-gnu": "16.2.3",
|
||||
"@next/swc-linux-x64-musl": "16.2.3",
|
||||
"@next/swc-win32-arm64-msvc": "16.2.3",
|
||||
"@next/swc-win32-x64-msvc": "16.2.3",
|
||||
"sharp": "^0.34.5"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@opentelemetry/api": "^1.1.0",
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
"date-fns": "^4.1.0",
|
||||
"embla-carousel-react": "^8.6.0",
|
||||
"lucide-react": "^0.562.0",
|
||||
"next": "16.1.7",
|
||||
"next": "16.2.3",
|
||||
"next-themes": "^0.4.6",
|
||||
"radix-ui": "^1.4.3",
|
||||
"react": "19.2.3",
|
||||
|
||||
@@ -618,6 +618,7 @@ done
|
||||
"app.kubernetes.io/managed-by": "onyx",
|
||||
"onyx.app/sandbox-id": sandbox_id,
|
||||
"onyx.app/tenant-id": tenant_id,
|
||||
"admission.datadoghq.com/enabled": "false",
|
||||
},
|
||||
),
|
||||
spec=pod_spec,
|
||||
|
||||
@@ -63,6 +63,7 @@ class DocumentSetCreationRequest(BaseModel):
|
||||
|
||||
class DocumentSetUpdateRequest(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
cc_pair_ids: list[int]
|
||||
is_public: bool
|
||||
|
||||
@@ -11,6 +11,9 @@ from onyx.db.notification import dismiss_notification
|
||||
from onyx.db.notification import get_notification_by_id
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.server.features.build.utils import ensure_build_mode_intro_notification
|
||||
from onyx.server.features.notifications.utils import (
|
||||
ensure_permissions_migration_notification,
|
||||
)
|
||||
from onyx.server.features.release_notes.utils import (
|
||||
ensure_release_notes_fresh_and_notify,
|
||||
)
|
||||
@@ -49,6 +52,13 @@ def get_notifications_api(
|
||||
except Exception:
|
||||
logger.exception("Failed to check for release notes in notifications endpoint")
|
||||
|
||||
try:
|
||||
ensure_permissions_migration_notification(user, db_session)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to create permissions_migration_v1 announcement in notifications endpoint"
|
||||
)
|
||||
|
||||
notifications = [
|
||||
NotificationModel.from_model(notif)
|
||||
for notif in get_notifications(user, db_session, include_dismissed=True)
|
||||
|
||||
21
backend/onyx/server/features/notifications/utils.py
Normal file
21
backend/onyx/server/features/notifications/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import create_notification
|
||||
|
||||
|
||||
def ensure_permissions_migration_notification(user: User, db_session: Session) -> None:
|
||||
# Feature id "permissions_migration_v1" must not change after shipping —
|
||||
# it is the dedup key on (user_id, notif_type, additional_data).
|
||||
create_notification(
|
||||
user_id=user.id,
|
||||
notif_type=NotificationType.FEATURE_ANNOUNCEMENT,
|
||||
db_session=db_session,
|
||||
title="Permissions are changing in Onyx",
|
||||
description="Roles are moving to group-based permissions. Click for details.",
|
||||
additional_data={
|
||||
"feature": "permissions_migration_v1",
|
||||
"link": "https://docs.onyx.app/admins/permissions/whats_changing",
|
||||
},
|
||||
)
|
||||
@@ -185,6 +185,10 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
for doc_set in persona.document_sets:
|
||||
for cc_pair in doc_set.connector_credential_pairs:
|
||||
sources.add(cc_pair.connector.source)
|
||||
for fed_ds in doc_set.federated_connectors:
|
||||
non_fed = fed_ds.federated_connector.source.to_non_federated_source()
|
||||
if non_fed is not None:
|
||||
sources.add(non_fed)
|
||||
|
||||
# Sources from hierarchy nodes
|
||||
for node in persona.hierarchy_nodes:
|
||||
@@ -195,6 +199,9 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
if doc.parent_hierarchy_node:
|
||||
sources.add(doc.parent_hierarchy_node.source)
|
||||
|
||||
if persona.user_files:
|
||||
sources.add(DocumentSource.USER_FILE)
|
||||
|
||||
return MinimalPersonaSnapshot(
|
||||
# Core fields actually used by ChatPage
|
||||
id=persona.id,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Generic Celery task lifecycle Prometheus metrics.
|
||||
|
||||
Provides signal handlers that track task started/completed/failed counts,
|
||||
active task gauge, task duration histograms, and retry/reject/revoke counts.
|
||||
active task gauge, task duration histograms, queue wait time histograms,
|
||||
and retry/reject/revoke counts.
|
||||
These fire for ALL tasks on the worker — no per-connector enrichment
|
||||
(see indexing_task_metrics.py for that).
|
||||
|
||||
@@ -71,6 +72,32 @@ TASK_REJECTED = Counter(
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
TASK_QUEUE_WAIT = Histogram(
|
||||
"onyx_celery_task_queue_wait_seconds",
|
||||
"Time a Celery task spent waiting in the queue before execution started",
|
||||
["task_name", "queue"],
|
||||
buckets=[
|
||||
0.1,
|
||||
0.5,
|
||||
1,
|
||||
5,
|
||||
30,
|
||||
60,
|
||||
300,
|
||||
600,
|
||||
1800,
|
||||
3600,
|
||||
7200,
|
||||
14400,
|
||||
28800,
|
||||
43200,
|
||||
86400,
|
||||
172800,
|
||||
432000,
|
||||
864000,
|
||||
],
|
||||
)
|
||||
|
||||
# task_id → (monotonic start time, metric labels)
|
||||
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
|
||||
|
||||
@@ -133,6 +160,13 @@ def on_celery_task_prerun(
|
||||
with _task_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_task_start_times[task_id] = (time.monotonic(), labels)
|
||||
|
||||
headers = getattr(task.request, "headers", None) or {}
|
||||
enqueued_at = headers.get("enqueued_at")
|
||||
if isinstance(enqueued_at, (int, float)):
|
||||
TASK_QUEUE_WAIT.labels(**labels).observe(
|
||||
max(0.0, time.time() - enqueued_at)
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
104
backend/onyx/server/metrics/deletion_metrics.py
Normal file
104
backend/onyx/server/metrics/deletion_metrics.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Connector-deletion-specific Prometheus metrics.
|
||||
|
||||
Tracks the deletion lifecycle:
|
||||
1. Deletions started (taskset generated)
|
||||
2. Deletions completed (success or failure)
|
||||
3. Taskset duration (from taskset generation to completion or failure).
|
||||
Note: this measures the most recent taskset execution, NOT wall-clock
|
||||
time since the user triggered the deletion. When deletion is blocked by
|
||||
indexing/pruning/permissions, the fence is cleared and a fresh taskset
|
||||
is generated on each retry, resetting this timer.
|
||||
4. Deletion blocked by dependencies (indexing, pruning, permissions, etc.)
|
||||
5. Fence resets (stuck deletion recovery)
|
||||
|
||||
All metrics are labeled by tenant_id. cc_pair_id is intentionally excluded
|
||||
to avoid unbounded cardinality.
|
||||
|
||||
Usage:
|
||||
from onyx.server.metrics.deletion_metrics import (
|
||||
inc_deletion_started,
|
||||
inc_deletion_completed,
|
||||
observe_deletion_taskset_duration,
|
||||
inc_deletion_blocked,
|
||||
inc_deletion_fence_reset,
|
||||
)
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DELETION_STARTED = Counter(
|
||||
"onyx_deletion_started_total",
|
||||
"Connector deletions initiated (taskset generated)",
|
||||
["tenant_id"],
|
||||
)
|
||||
|
||||
DELETION_COMPLETED = Counter(
|
||||
"onyx_deletion_completed_total",
|
||||
"Connector deletions completed",
|
||||
["tenant_id", "outcome"],
|
||||
)
|
||||
|
||||
DELETION_TASKSET_DURATION = Histogram(
|
||||
"onyx_deletion_taskset_duration_seconds",
|
||||
"Duration of a connector deletion taskset, from taskset generation "
|
||||
"to completion or failure. Does not include time spent blocked on "
|
||||
"indexing/pruning/permissions before the taskset was generated.",
|
||||
["tenant_id", "outcome"],
|
||||
buckets=[10, 30, 60, 120, 300, 600, 1800, 3600, 7200, 21600],
|
||||
)
|
||||
|
||||
DELETION_BLOCKED = Counter(
|
||||
"onyx_deletion_blocked_total",
|
||||
"Times deletion was blocked by a dependency",
|
||||
["tenant_id", "blocker"],
|
||||
)
|
||||
|
||||
DELETION_FENCE_RESET = Counter(
|
||||
"onyx_deletion_fence_reset_total",
|
||||
"Deletion fences reset due to missing celery tasks",
|
||||
["tenant_id"],
|
||||
)
|
||||
|
||||
|
||||
def inc_deletion_started(tenant_id: str) -> None:
|
||||
try:
|
||||
DELETION_STARTED.labels(tenant_id=tenant_id).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion started", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_completed(tenant_id: str, outcome: str) -> None:
|
||||
try:
|
||||
DELETION_COMPLETED.labels(tenant_id=tenant_id, outcome=outcome).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion completed", exc_info=True)
|
||||
|
||||
|
||||
def observe_deletion_taskset_duration(
|
||||
tenant_id: str, outcome: str, duration_seconds: float
|
||||
) -> None:
|
||||
try:
|
||||
DELETION_TASKSET_DURATION.labels(tenant_id=tenant_id, outcome=outcome).observe(
|
||||
duration_seconds
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion taskset duration", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_blocked(tenant_id: str, blocker: str) -> None:
|
||||
try:
|
||||
DELETION_BLOCKED.labels(tenant_id=tenant_id, blocker=blocker).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion blocked", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_fence_reset(tenant_id: str) -> None:
|
||||
try:
|
||||
DELETION_FENCE_RESET.labels(tenant_id=tenant_id).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record deletion fence reset", exc_info=True)
|
||||
@@ -27,6 +27,7 @@ _DEFAULT_PORTS: dict[str, int] = {
|
||||
"docfetching": 9092,
|
||||
"docprocessing": 9093,
|
||||
"heavy": 9094,
|
||||
"light": 9095,
|
||||
}
|
||||
|
||||
_server_started = False
|
||||
|
||||
@@ -28,14 +28,14 @@ PRUNING_ENUMERATION_DURATION = Histogram(
|
||||
"onyx_pruning_enumeration_duration_seconds",
|
||||
"Duration of document ID enumeration from the source connector during pruning",
|
||||
["connector_type"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
buckets=[5, 60, 600, 1800, 3600, 10800, 21600],
|
||||
)
|
||||
|
||||
PRUNING_DIFF_DURATION = Histogram(
|
||||
"onyx_pruning_diff_duration_seconds",
|
||||
"Duration of diff computation and subtask dispatch during pruning",
|
||||
["connector_type"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
buckets=[0.1, 0.25, 0.5, 1, 2, 5, 15, 30, 60],
|
||||
)
|
||||
|
||||
PRUNING_RATE_LIMIT_ERRORS = Counter(
|
||||
|
||||
@@ -65,7 +65,8 @@ class Settings(BaseModel):
|
||||
anonymous_user_enabled: bool | None = None
|
||||
invite_only_enabled: bool = False
|
||||
deep_research_enabled: bool | None = None
|
||||
search_ui_enabled: bool | None = None
|
||||
multi_model_chat_enabled: bool | None = True
|
||||
search_ui_enabled: bool | None = True
|
||||
|
||||
# Whether EE features are unlocked for use.
|
||||
# Depends on license status: True when the user has a valid license
|
||||
@@ -89,7 +90,8 @@ class Settings(BaseModel):
|
||||
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
|
||||
)
|
||||
file_token_count_threshold_k: int | None = Field(
|
||||
default=None, ge=0 # thousands of tokens; None = context-aware default
|
||||
default=None,
|
||||
ge=0, # thousands of tokens; None = context-aware default
|
||||
)
|
||||
|
||||
# Connector settings
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
@@ -12,6 +11,8 @@ from onyx.configs.app_configs import OPENROUTER_DEFAULT_API_KEY
|
||||
from onyx.db.usage import check_usage_limit
|
||||
from onyx.db.usage import UsageLimitExceededError
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.tenant_usage_limits import TenantUsageLimitKeys
|
||||
from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -255,11 +256,14 @@ def check_usage_and_raise(
|
||||
"Please upgrade your plan or wait for the next billing period."
|
||||
)
|
||||
elif usage_type == UsageType.API_CALLS:
|
||||
detail = (
|
||||
f"API call limit exceeded for {user_type} account. "
|
||||
f"Calls: {int(e.current)}, Limit: {int(e.limit)} per week. "
|
||||
"Please upgrade your plan or wait for the next billing period."
|
||||
)
|
||||
if is_trial and e.limit == 0:
|
||||
detail = "API access is not available on trial accounts. Please upgrade to a paid plan to use the API and chat widget."
|
||||
else:
|
||||
detail = (
|
||||
f"API call limit exceeded for {user_type} account. "
|
||||
f"Calls: {int(e.current)}, Limit: {int(e.limit)} per week. "
|
||||
"Please upgrade your plan or wait for the next billing period."
|
||||
)
|
||||
else:
|
||||
detail = (
|
||||
f"Non-streaming API call limit exceeded for {user_type} account. "
|
||||
@@ -267,4 +271,4 @@ def check_usage_and_raise(
|
||||
"Please upgrade your plan or wait for the next billing period."
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=429, detail=detail)
|
||||
raise OnyxError(OnyxErrorCode.RATE_LIMITED, detail)
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.mcp import get_all_mcp_tools_for_server
|
||||
@@ -113,10 +114,10 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
|
||||
|
||||
def construct_tools(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
emitter: Emitter,
|
||||
user: User,
|
||||
llm: LLM,
|
||||
db_session: Session | None = None,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
file_reader_tool_config: FileReaderToolConfig | None = None,
|
||||
@@ -131,6 +132,33 @@ def construct_tools(
|
||||
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
|
||||
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
|
||||
to avoid lazy SQL queries after the session may have been flushed."""
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
return _construct_tools_impl(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
custom_tool_config=custom_tool_config,
|
||||
file_reader_tool_config=file_reader_tool_config,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
search_usage_forcing_setting=search_usage_forcing_setting,
|
||||
)
|
||||
|
||||
|
||||
def _construct_tools_impl(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
emitter: Emitter,
|
||||
user: User,
|
||||
llm: LLM,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
file_reader_tool_config: FileReaderToolConfig | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
search_usage_forcing_setting: SearchToolUsage = SearchToolUsage.AUTO,
|
||||
) -> dict[int, list[Tool]]:
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Log which tools are attached to the persona for debugging
|
||||
|
||||
@@ -17,6 +17,7 @@ def documents_to_indexing_documents(
|
||||
processed_sections = []
|
||||
for section in document.sections:
|
||||
processed_section = Section(
|
||||
type=section.type,
|
||||
text=section.text or "",
|
||||
link=section.link,
|
||||
image_file_id=None,
|
||||
|
||||
@@ -26,7 +26,7 @@ aiolimiter==1.2.1
|
||||
# via voyageai
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
alembic==1.10.4
|
||||
alembic==1.18.4
|
||||
amqp==5.3.1
|
||||
# via kombu
|
||||
annotated-doc==0.0.4
|
||||
@@ -174,7 +174,7 @@ coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
courlan==1.3.2
|
||||
# via trafilatura
|
||||
cryptography==46.0.6
|
||||
cryptography==46.0.7
|
||||
# via
|
||||
# authlib
|
||||
# google-auth
|
||||
@@ -408,7 +408,7 @@ kombu==5.5.4
|
||||
# via celery
|
||||
kubernetes==31.0.0
|
||||
# via onyx
|
||||
langchain-core==1.2.22
|
||||
langchain-core==1.2.28
|
||||
langdetect==1.0.9
|
||||
# via unstructured
|
||||
langfuse==3.10.0
|
||||
@@ -583,13 +583,13 @@ pathable==0.4.4
|
||||
# via jsonschema-path
|
||||
pdfminer-six==20251107
|
||||
# via markitdown
|
||||
pillow==12.1.1
|
||||
pillow==12.2.0
|
||||
# via python-pptx
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
# fastmcp
|
||||
# zeep
|
||||
playwright==1.55.0
|
||||
playwright==1.58.0
|
||||
# via pytest-playwright
|
||||
pluggy==1.6.0
|
||||
# via pytest
|
||||
@@ -666,7 +666,9 @@ pyee==13.0.0
|
||||
# via playwright
|
||||
pygithub==2.5.0
|
||||
pygments==2.20.0
|
||||
# via rich
|
||||
# via
|
||||
# pytest
|
||||
# rich
|
||||
pyjwt==2.12.0
|
||||
# via
|
||||
# fastapi-users
|
||||
@@ -680,13 +682,13 @@ pynacl==1.6.2
|
||||
pypandoc-binary==1.16.2
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.9.2
|
||||
pypdf==6.10.0
|
||||
# via unstructured-client
|
||||
pyperclip==1.11.0
|
||||
# via fastmcp
|
||||
pyreadline3==3.5.4 ; sys_platform == 'win32'
|
||||
# via humanfriendly
|
||||
pytest==8.3.5
|
||||
pytest==9.0.3
|
||||
# via
|
||||
# pytest-base-url
|
||||
# pytest-mock
|
||||
@@ -694,7 +696,7 @@ pytest==8.3.5
|
||||
pytest-base-url==2.1.0
|
||||
# via pytest-playwright
|
||||
pytest-mock==3.12.0
|
||||
pytest-playwright==0.7.0
|
||||
pytest-playwright==0.7.2
|
||||
python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
|
||||
@@ -22,7 +22,7 @@ aiolimiter==1.2.1
|
||||
# via voyageai
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
alembic==1.10.4
|
||||
alembic==1.18.4
|
||||
# via pytest-alembic
|
||||
annotated-doc==0.0.4
|
||||
# via fastapi
|
||||
@@ -46,7 +46,7 @@ attrs==25.4.0
|
||||
# aiohttp
|
||||
# jsonschema
|
||||
# referencing
|
||||
black==25.1.0
|
||||
black==26.3.1
|
||||
boto3==1.39.11
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -95,7 +95,7 @@ comm==0.2.3
|
||||
# via ipykernel
|
||||
contourpy==1.3.3
|
||||
# via matplotlib
|
||||
cryptography==46.0.6
|
||||
cryptography==46.0.7
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
@@ -274,13 +274,13 @@ parameterized==0.9.0
|
||||
# via cohere
|
||||
parso==0.8.5
|
||||
# via jedi
|
||||
pathspec==0.12.1
|
||||
pathspec==1.0.4
|
||||
# via
|
||||
# black
|
||||
# hatchling
|
||||
pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
|
||||
# via ipython
|
||||
pillow==12.1.1
|
||||
pillow==12.2.0
|
||||
# via matplotlib
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
@@ -339,11 +339,12 @@ pygments==2.20.0
|
||||
# via
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
# pytest
|
||||
pyjwt==2.12.0
|
||||
# via mcp
|
||||
pyparsing==3.2.5
|
||||
# via matplotlib
|
||||
pytest==8.3.5
|
||||
pytest==9.0.3
|
||||
# via
|
||||
# pytest-alembic
|
||||
# pytest-asyncio
|
||||
@@ -369,6 +370,8 @@ python-dotenv==1.1.1
|
||||
# pytest-dotenv
|
||||
python-multipart==0.0.22
|
||||
# via mcp
|
||||
pytokens==0.4.1
|
||||
# via black
|
||||
pywin32==311 ; sys_platform == 'win32'
|
||||
# via mcp
|
||||
pyyaml==6.0.3
|
||||
|
||||
@@ -76,7 +76,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# via
|
||||
# click
|
||||
# tqdm
|
||||
cryptography==46.0.6
|
||||
cryptography==46.0.7
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
|
||||
@@ -91,7 +91,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# via
|
||||
# click
|
||||
# tqdm
|
||||
cryptography==46.0.6
|
||||
cryptography==46.0.7
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
@@ -264,7 +264,7 @@ packaging==24.2
|
||||
# transformers
|
||||
parameterized==0.9.0
|
||||
# via cohere
|
||||
pillow==12.1.1
|
||||
pillow==12.2.0
|
||||
# via sentence-transformers
|
||||
prometheus-client==0.23.1
|
||||
# via
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import patch
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import _pick
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
|
||||
@@ -699,3 +700,43 @@ def test_specific_user_email_shared_with_me(
|
||||
|
||||
doc_titles = set(doc.semantic_identifier for doc in output.documents)
|
||||
assert doc_titles == set(expected)
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
)
|
||||
def test_slim_retrieval_does_not_call_permissions_list(
|
||||
mock_get_api_key: MagicMock, # noqa: ARG001
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
"""retrieve_all_slim_docs() must not call permissions().list for any file.
|
||||
|
||||
Pruning only needs file IDs — fetching permissions per file causes O(N) API
|
||||
calls that time out for tenants with large numbers of externally-owned files.
|
||||
"""
|
||||
connector = google_drive_service_acct_connector_factory(
|
||||
primary_admin_email=ADMIN_EMAIL,
|
||||
include_shared_drives=True,
|
||||
include_my_drives=True,
|
||||
include_files_shared_with_me=False,
|
||||
shared_folder_urls=None,
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"onyx.connectors.google_drive.connector.execute_paginated_retrieval",
|
||||
wraps=execute_paginated_retrieval,
|
||||
) as mock_paginated:
|
||||
for batch in connector.retrieve_all_slim_docs():
|
||||
pass
|
||||
|
||||
permissions_calls = [
|
||||
c
|
||||
for c in mock_paginated.call_args_list
|
||||
if "permissions" in str(c.kwargs.get("retrieval_function", ""))
|
||||
]
|
||||
assert (
|
||||
len(permissions_calls) == 0
|
||||
), f"permissions().list was called {len(permissions_calls)} time(s) during pruning"
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.connectors.models import TextSection
|
||||
|
||||
_ITERATION_LIMIT = 100_000
|
||||
@@ -141,13 +142,15 @@ def load_all_from_connector(
|
||||
|
||||
def to_sections(
|
||||
documents: list[Document],
|
||||
) -> Iterator[TextSection | ImageSection]:
|
||||
) -> Iterator[TextSection | ImageSection | TabularSection]:
|
||||
for doc in documents:
|
||||
for section in doc.sections:
|
||||
yield section
|
||||
|
||||
|
||||
def to_text_sections(sections: Iterator[TextSection | ImageSection]) -> Iterator[str]:
|
||||
def to_text_sections(
|
||||
sections: Iterator[TextSection | ImageSection | TabularSection],
|
||||
) -> Iterator[str]:
|
||||
for section in sections:
|
||||
if isinstance(section, TextSection):
|
||||
yield section.text
|
||||
|
||||
@@ -38,38 +38,41 @@ class TestAddMemory:
|
||||
def test_add_memory_creates_row(self, db_session: Session, test_user: User) -> None:
|
||||
"""Verify that add_memory inserts a new Memory row."""
|
||||
user_id = test_user.id
|
||||
memory = add_memory(
|
||||
memory_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="User prefers dark mode",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert memory.id is not None
|
||||
assert memory.user_id == user_id
|
||||
assert memory.memory_text == "User prefers dark mode"
|
||||
assert memory_id is not None
|
||||
|
||||
# Verify it persists
|
||||
fetched = db_session.get(Memory, memory.id)
|
||||
fetched = db_session.get(Memory, memory_id)
|
||||
assert fetched is not None
|
||||
assert fetched.user_id == user_id
|
||||
assert fetched.memory_text == "User prefers dark mode"
|
||||
|
||||
def test_add_multiple_memories(self, db_session: Session, test_user: User) -> None:
|
||||
"""Verify that multiple memories can be added for the same user."""
|
||||
user_id = test_user.id
|
||||
m1 = add_memory(
|
||||
m1_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Favorite color is blue",
|
||||
db_session=db_session,
|
||||
)
|
||||
m2 = add_memory(
|
||||
m2_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Works in engineering",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert m1.id != m2.id
|
||||
assert m1.memory_text == "Favorite color is blue"
|
||||
assert m2.memory_text == "Works in engineering"
|
||||
assert m1_id != m2_id
|
||||
fetched_m1 = db_session.get(Memory, m1_id)
|
||||
fetched_m2 = db_session.get(Memory, m2_id)
|
||||
assert fetched_m1 is not None
|
||||
assert fetched_m2 is not None
|
||||
assert fetched_m1.memory_text == "Favorite color is blue"
|
||||
assert fetched_m2.memory_text == "Works in engineering"
|
||||
|
||||
|
||||
class TestUpdateMemoryAtIndex:
|
||||
@@ -82,15 +85,17 @@ class TestUpdateMemoryAtIndex:
|
||||
add_memory(user_id=user_id, memory_text="Memory 1", db_session=db_session)
|
||||
add_memory(user_id=user_id, memory_text="Memory 2", db_session=db_session)
|
||||
|
||||
updated = update_memory_at_index(
|
||||
updated_id = update_memory_at_index(
|
||||
user_id=user_id,
|
||||
index=1,
|
||||
new_text="Updated Memory 1",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert updated is not None
|
||||
assert updated.memory_text == "Updated Memory 1"
|
||||
assert updated_id is not None
|
||||
fetched = db_session.get(Memory, updated_id)
|
||||
assert fetched is not None
|
||||
assert fetched.memory_text == "Updated Memory 1"
|
||||
|
||||
def test_update_memory_at_out_of_range_index(
|
||||
self, db_session: Session, test_user: User
|
||||
@@ -167,7 +172,7 @@ class TestMemoryCap:
|
||||
assert len(rows_before) == MAX_MEMORIES_PER_USER
|
||||
|
||||
# Add one more — should evict the oldest
|
||||
new_memory = add_memory(
|
||||
new_memory_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="New memory after cap",
|
||||
db_session=db_session,
|
||||
@@ -181,7 +186,7 @@ class TestMemoryCap:
|
||||
# Oldest ("Memory 0") should be gone; "Memory 1" is now the oldest
|
||||
assert rows_after[0].memory_text == "Memory 1"
|
||||
# Newest should be the one we just added
|
||||
assert rows_after[-1].id == new_memory.id
|
||||
assert rows_after[-1].id == new_memory_id
|
||||
assert rows_after[-1].memory_text == "New memory after cap"
|
||||
|
||||
|
||||
@@ -221,22 +226,26 @@ class TestGetMemoriesWithUserId:
|
||||
user_id = test_user_no_memories.id
|
||||
|
||||
# Add a memory
|
||||
memory = add_memory(
|
||||
memory_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Memory with use_memories off",
|
||||
db_session=db_session,
|
||||
)
|
||||
assert memory.memory_text == "Memory with use_memories off"
|
||||
fetched = db_session.get(Memory, memory_id)
|
||||
assert fetched is not None
|
||||
assert fetched.memory_text == "Memory with use_memories off"
|
||||
|
||||
# Update that memory
|
||||
updated = update_memory_at_index(
|
||||
updated_id = update_memory_at_index(
|
||||
user_id=user_id,
|
||||
index=0,
|
||||
new_text="Updated memory with use_memories off",
|
||||
db_session=db_session,
|
||||
)
|
||||
assert updated is not None
|
||||
assert updated.memory_text == "Updated memory with use_memories off"
|
||||
assert updated_id is not None
|
||||
fetched_updated = db_session.get(Memory, updated_id)
|
||||
assert fetched_updated is not None
|
||||
assert fetched_updated.memory_text == "Updated memory with use_memories off"
|
||||
|
||||
# Verify get_memories returns the updated memory
|
||||
context = get_memories(test_user_no_memories, db_session)
|
||||
|
||||
@@ -62,6 +62,7 @@ class DocumentSetManager:
|
||||
) -> bool:
|
||||
doc_set_update_request = {
|
||||
"id": document_set.id,
|
||||
"name": document_set.name,
|
||||
"description": document_set.description,
|
||||
"cc_pair_ids": document_set.cc_pair_ids,
|
||||
"is_public": document_set.is_public,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
FROM python:3.11-slim-bookworm@sha256:9c6f90801e6b68e772b7c0ca74260cbf7af9f320acec894e26fccdaccfbe3b47
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
@@ -159,3 +161,58 @@ def test_removing_connector(
|
||||
doc_set_names=[],
|
||||
doc_creating_user=admin_user,
|
||||
)
|
||||
|
||||
|
||||
def test_renaming_document_set(
|
||||
reset: None, # noqa: ARG001
|
||||
vespa_client: vespa_fixture,
|
||||
) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.INGESTION_API,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
cc_pair.documents = DocumentManager.seed_dummy_docs(
|
||||
cc_pair=cc_pair,
|
||||
num_docs=NUM_DOCS,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
original_name = f"original_doc_set_{uuid4()}"
|
||||
doc_set = DocumentSetManager.create(
|
||||
name=original_name,
|
||||
cc_pair_ids=[cc_pair.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
DocumentSetManager.wait_for_sync(user_performing_action=admin_user)
|
||||
DocumentSetManager.verify(
|
||||
document_set=doc_set,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
new_name = f"renamed_doc_set_{uuid4()}"
|
||||
doc_set.name = new_name
|
||||
DocumentSetManager.edit(
|
||||
doc_set,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
DocumentSetManager.wait_for_sync(user_performing_action=admin_user)
|
||||
DocumentSetManager.verify(
|
||||
document_set=doc_set,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
DocumentManager.verify(
|
||||
vespa_client=vespa_client,
|
||||
cc_pair=cc_pair,
|
||||
doc_set_names=[new_name],
|
||||
doc_creating_user=admin_user,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Integration tests verifying the knowledge_sources field on MinimalPersonaSnapshot.
|
||||
|
||||
The GET /persona endpoint returns MinimalPersonaSnapshot, which includes a
|
||||
knowledge_sources list derived from the persona's document sets, hierarchy
|
||||
nodes, attached documents, and user files. These tests verify that the
|
||||
field is populated correctly.
|
||||
"""
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.file import FileManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.test_file_utils import create_test_text_file
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _get_minimal_persona(
|
||||
persona_id: int,
|
||||
user: DATestUser,
|
||||
) -> dict:
|
||||
"""Fetch personas from the list endpoint and find the one with the given id."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona",
|
||||
params={"persona_ids": persona_id},
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
personas = response.json()
|
||||
matches = [p for p in personas if p["id"] == persona_id]
|
||||
assert (
|
||||
len(matches) == 1
|
||||
), f"Expected 1 persona with id={persona_id}, got {len(matches)}"
|
||||
return matches[0]
|
||||
|
||||
|
||||
def test_persona_with_user_files_includes_user_file_source(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When a persona has user files attached, knowledge_sources includes 'user_file'."""
|
||||
text_file = create_test_text_file("test content for knowledge source verification")
|
||||
file_descriptors, error = FileManager.upload_files(
|
||||
files=[("test_ks.txt", text_file)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert not error, f"File upload failed: {error}"
|
||||
|
||||
user_file_id = file_descriptors[0]["user_file_id"] or ""
|
||||
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="KS User File Agent",
|
||||
description="Agent with user files for knowledge_sources test",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
user_file_ids=[user_file_id],
|
||||
)
|
||||
|
||||
minimal = _get_minimal_persona(persona.id, admin_user)
|
||||
assert (
|
||||
DocumentSource.USER_FILE.value in minimal["knowledge_sources"]
|
||||
), f"Expected 'user_file' in knowledge_sources, got: {minimal['knowledge_sources']}"
|
||||
|
||||
|
||||
def test_persona_without_user_files_excludes_user_file_source(
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When a persona has no user files, knowledge_sources should not contain 'user_file'."""
|
||||
persona = PersonaManager.create(
|
||||
user_performing_action=admin_user,
|
||||
name="KS No Files Agent",
|
||||
description="Agent without files for knowledge_sources test",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
minimal = _get_minimal_persona(persona.id, admin_user)
|
||||
assert (
|
||||
DocumentSource.USER_FILE.value not in minimal["knowledge_sources"]
|
||||
), f"Unexpected 'user_file' in knowledge_sources: {minimal['knowledge_sources']}"
|
||||
@@ -301,7 +301,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -332,7 +331,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -363,7 +361,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -391,7 +388,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -423,7 +419,6 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -456,7 +451,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -497,7 +491,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -519,7 +512,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -542,7 +534,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -596,7 +587,6 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
@@ -653,7 +643,6 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
@@ -706,7 +695,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -736,7 +724,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
|
||||
0
backend/tests/unit/onyx/configs/__init__.py
Normal file
0
backend/tests/unit/onyx/configs/__init__.py
Normal file
88
backend/tests/unit/onyx/configs/test_sentry.py
Normal file
88
backend/tests/unit/onyx/configs/test_sentry.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from sentry_sdk.types import Event
|
||||
|
||||
import onyx.configs.sentry as sentry_module
|
||||
from onyx.configs.sentry import _add_instance_tags
|
||||
|
||||
|
||||
def _event(data: dict) -> Event:
|
||||
"""Helper to create a Sentry Event from a plain dict for testing."""
|
||||
return cast(Event, data)
|
||||
|
||||
|
||||
def _reset_state() -> None:
|
||||
"""Reset the module-level resolved flag between tests."""
|
||||
sentry_module._instance_id_resolved = False
|
||||
|
||||
|
||||
class TestAddInstanceTags:
|
||||
def setup_method(self) -> None:
|
||||
_reset_state()
|
||||
|
||||
@patch("onyx.utils.telemetry.get_or_generate_uuid", return_value="test-uuid-1234")
|
||||
@patch("sentry_sdk.set_tag")
|
||||
def test_first_event_sets_instance_id(
|
||||
self, mock_set_tag: MagicMock, mock_uuid: MagicMock
|
||||
) -> None:
|
||||
result = _add_instance_tags(_event({"message": "test error"}), {})
|
||||
|
||||
assert result is not None
|
||||
assert result["tags"]["instance_id"] == "test-uuid-1234"
|
||||
mock_set_tag.assert_called_once_with("instance_id", "test-uuid-1234")
|
||||
mock_uuid.assert_called_once()
|
||||
|
||||
@patch("onyx.utils.telemetry.get_or_generate_uuid", return_value="test-uuid-1234")
|
||||
@patch("sentry_sdk.set_tag")
|
||||
def test_second_event_skips_resolution(
|
||||
self, _mock_set_tag: MagicMock, mock_uuid: MagicMock
|
||||
) -> None:
|
||||
_add_instance_tags(_event({"message": "first"}), {})
|
||||
result = _add_instance_tags(_event({"message": "second"}), {})
|
||||
|
||||
assert result is not None
|
||||
assert "tags" not in result # second event not modified
|
||||
mock_uuid.assert_called_once() # only resolved once
|
||||
|
||||
@patch(
|
||||
"onyx.utils.telemetry.get_or_generate_uuid",
|
||||
side_effect=Exception("DB unavailable"),
|
||||
)
|
||||
@patch("sentry_sdk.set_tag")
|
||||
def test_resolution_failure_still_returns_event(
|
||||
self, _mock_set_tag: MagicMock, _mock_uuid: MagicMock
|
||||
) -> None:
|
||||
result = _add_instance_tags(_event({"message": "test error"}), {})
|
||||
|
||||
assert result is not None
|
||||
assert result["message"] == "test error"
|
||||
assert "tags" not in result or "instance_id" not in result.get("tags", {})
|
||||
|
||||
@patch(
|
||||
"onyx.utils.telemetry.get_or_generate_uuid",
|
||||
side_effect=Exception("DB unavailable"),
|
||||
)
|
||||
@patch("sentry_sdk.set_tag")
|
||||
def test_resolution_failure_retries_on_next_event(
|
||||
self, _mock_set_tag: MagicMock, mock_uuid: MagicMock
|
||||
) -> None:
|
||||
"""If resolution fails (e.g. DB not ready), retry on the next event."""
|
||||
_add_instance_tags(_event({"message": "first"}), {})
|
||||
_add_instance_tags(_event({"message": "second"}), {})
|
||||
|
||||
assert mock_uuid.call_count == 2 # retried on second event
|
||||
|
||||
@patch("onyx.utils.telemetry.get_or_generate_uuid", return_value="test-uuid-1234")
|
||||
@patch("sentry_sdk.set_tag")
|
||||
def test_preserves_existing_tags(
|
||||
self, _mock_set_tag: MagicMock, _mock_uuid: MagicMock
|
||||
) -> None:
|
||||
result = _add_instance_tags(
|
||||
_event({"message": "test", "tags": {"existing": "tag"}}), {}
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["tags"]["existing"] == "tag"
|
||||
assert result["tags"]["instance_id"] == "test-uuid-1234"
|
||||
@@ -6,9 +6,8 @@ import pytest
|
||||
import requests
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
_DEFAULT_PAGINATION_LIMIT,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import _DEFAULT_PAGINATION_LIMIT
|
||||
from onyx.connectors.confluence.onyx_confluence import _MINIMUM_PAGINATION_LIMIT
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
|
||||
@@ -96,8 +95,7 @@ def test_cql_paginate_all_expansions_handles_internal_pagination_error(
|
||||
"""
|
||||
caplog.set_level("WARNING") # To check logging messages
|
||||
|
||||
# Use constants from the client instance, but note the test logic goes below MINIMUM
|
||||
_TEST_MINIMUM_LIMIT = 1 # The limit this test expects the retry to reach
|
||||
_TEST_MINIMUM_LIMIT = 1 # The one-by-one fallback always uses limit=1
|
||||
|
||||
top_level_cql = "test_cql"
|
||||
top_level_expand = "child_items"
|
||||
@@ -170,15 +168,21 @@ def test_cql_paginate_all_expansions_handles_internal_pagination_error(
|
||||
url=exp1_page2_path,
|
||||
)
|
||||
|
||||
# Problematic Expansion 2 URLs and Errors during limit reduction
|
||||
# Problematic Expansion 2 URLs and Errors during limit reduction.
|
||||
# Limit halves from _DEFAULT_PAGINATION_LIMIT down to _MINIMUM_PAGINATION_LIMIT,
|
||||
# then the one-by-one fallback kicks in at limit=1.
|
||||
exp2_base_path = "rest/api/content/2/child"
|
||||
exp2_reduction_errors = {}
|
||||
limit = _DEFAULT_PAGINATION_LIMIT
|
||||
while limit > _TEST_MINIMUM_LIMIT: # Reduce all the way to 1 for the test
|
||||
while limit >= _MINIMUM_PAGINATION_LIMIT:
|
||||
path = f"{exp2_base_path}?limit={limit}"
|
||||
exp2_reduction_errors[path] = _create_http_error(500, url=path)
|
||||
new_limit = limit // 2
|
||||
limit = max(new_limit, _TEST_MINIMUM_LIMIT) # Ensure it hits 1
|
||||
limit = max(new_limit, _MINIMUM_PAGINATION_LIMIT)
|
||||
if limit == _MINIMUM_PAGINATION_LIMIT and path.endswith(
|
||||
f"limit={_MINIMUM_PAGINATION_LIMIT}"
|
||||
):
|
||||
break
|
||||
|
||||
# Expansion 2 - Pagination at Limit = 1 (2 successes, 2 failures)
|
||||
exp2_limit1_page1_path = f"{exp2_base_path}?limit={_TEST_MINIMUM_LIMIT}&start=0"
|
||||
@@ -320,10 +324,14 @@ def test_cql_paginate_all_expansions_handles_internal_pagination_error(
|
||||
mock_get_call_paths[3] == f"{exp2_base_path}?limit={_DEFAULT_PAGINATION_LIMIT}"
|
||||
)
|
||||
|
||||
# 5+. Expansion 2 (retries due to 500s, down to limit=1)
|
||||
call_index = 4
|
||||
# 5+. Expansion 2 (limit reduction retries due to 500s, down to _MINIMUM_PAGINATION_LIMIT)
|
||||
# Then one-by-one fallback at limit=1
|
||||
num_reduction_steps = (
|
||||
len(exp2_reduction_errors) - 1
|
||||
) # first was already counted at index 3
|
||||
call_index = 4 + num_reduction_steps
|
||||
|
||||
# 5+N. Expansion 2 (limit=1, page 1 success)
|
||||
# Next: one-by-one fallback (limit=1, page 1 success)
|
||||
assert mock_get_call_paths[call_index] == exp2_limit1_page1_path
|
||||
call_index += 1
|
||||
# 5+N+1. Expansion 2 (limit=1, page 2 success)
|
||||
@@ -680,16 +688,16 @@ def test_paginated_cql_retrieval_skips_completely_failing_page(
|
||||
assert mock_get_call_paths == expected_calls
|
||||
|
||||
|
||||
def test_paginated_cql_retrieval_cloud_no_retry_on_error(
|
||||
def test_paginated_cql_retrieval_cloud_reduces_limit_on_error(
|
||||
mock_credentials_provider: mock.Mock,
|
||||
) -> None:
|
||||
"""
|
||||
Tests that for Confluence Cloud (is_cloud=True), paginated_cql_retrieval
|
||||
does NOT retry on pagination errors and raises HTTPError immediately.
|
||||
progressively halves the limit on server errors (500/504) and eventually
|
||||
raises once the limit floor is reached.
|
||||
"""
|
||||
# Setup Confluence Cloud Client
|
||||
confluence_cloud_client = OnyxConfluence(
|
||||
is_cloud=True, # Key difference: Cloud instance
|
||||
is_cloud=True,
|
||||
url="https://fake-cloud.atlassian.net",
|
||||
credentials_provider=mock_credentials_provider,
|
||||
timeout=10,
|
||||
@@ -701,28 +709,23 @@ def test_paginated_cql_retrieval_cloud_no_retry_on_error(
|
||||
|
||||
test_cql = "type=page"
|
||||
encoded_cql = "type%3Dpage"
|
||||
test_limit = 50 # Use a standard limit
|
||||
# Start with a small limit so the halving chain is short:
|
||||
# 10 -> 5 (== _MINIMUM_PAGINATION_LIMIT) -> raise
|
||||
test_limit = 10
|
||||
|
||||
base_path = f"rest/api/content/search?cql={encoded_cql}"
|
||||
page1_path = f"{base_path}&limit={test_limit}"
|
||||
page2_path = f"{base_path}&limit={test_limit}&start={test_limit}"
|
||||
|
||||
# --- Mock Responses ---
|
||||
# Page 1: Success
|
||||
page1_response = _create_mock_response(
|
||||
200,
|
||||
{
|
||||
"results": [{"id": i} for i in range(test_limit)],
|
||||
"_links": {"next": f"/{page2_path}"},
|
||||
"_links": {"next": f"/{base_path}&limit={test_limit}&start={test_limit}"},
|
||||
"size": test_limit,
|
||||
},
|
||||
url=page1_path,
|
||||
)
|
||||
|
||||
# Page 2: Failure (500)
|
||||
page2_error = _create_http_error(500, url=page2_path)
|
||||
|
||||
# --- Side Effect Logic ---
|
||||
mock_get_call_paths: list[str] = []
|
||||
|
||||
def get_side_effect(
|
||||
@@ -732,24 +735,14 @@ def test_paginated_cql_retrieval_cloud_no_retry_on_error(
|
||||
) -> requests.Response:
|
||||
path = path.strip("/")
|
||||
mock_get_call_paths.append(path)
|
||||
print(f"Mock GET received path: {path}")
|
||||
|
||||
if path == page1_path:
|
||||
print(f"-> Returning page 1 success for {path}")
|
||||
if "limit=10" in path and "start=" not in path:
|
||||
return page1_response
|
||||
elif path == page2_path:
|
||||
print(f"-> Returning page 2 500 error for {path}")
|
||||
return page2_error
|
||||
else:
|
||||
# No other paths (like limit=1 retries) should be called
|
||||
print(f"!!! Unexpected GET path in mock for Cloud test: {path}")
|
||||
raise RuntimeError(f"Unexpected GET path in mock for Cloud test: {path}")
|
||||
# Every subsequent call (including reduced-limit retries) returns 500
|
||||
return _create_http_error(500, url=path)
|
||||
|
||||
confluence_cloud_client._confluence.get.side_effect = get_side_effect
|
||||
|
||||
# --- Execute & Assert ---
|
||||
with pytest.raises(HTTPError) as excinfo:
|
||||
# Consume the iterator to trigger calls
|
||||
with pytest.raises(HTTPError):
|
||||
list(
|
||||
confluence_cloud_client.paginated_cql_retrieval(
|
||||
cql=test_cql,
|
||||
@@ -757,11 +750,240 @@ def test_paginated_cql_retrieval_cloud_no_retry_on_error(
|
||||
)
|
||||
)
|
||||
|
||||
# Verify the error is the one we simulated for page 2
|
||||
assert excinfo.value.response == page2_error
|
||||
assert excinfo.value.response.status_code == 500
|
||||
assert page2_path in excinfo.value.response.url
|
||||
# First call succeeds (limit=10), then page 2 at limit=10 fails,
|
||||
# retry at limit=5 fails, and since 5 == _MINIMUM_PAGINATION_LIMIT it raises.
|
||||
assert len(mock_get_call_paths) == 3
|
||||
assert f"limit={test_limit}" in mock_get_call_paths[0]
|
||||
assert f"limit={test_limit}" in mock_get_call_paths[1]
|
||||
assert f"limit={_MINIMUM_PAGINATION_LIMIT}" in mock_get_call_paths[2]
|
||||
|
||||
# Verify only two calls were made (page 1 success, page 2 fail)
|
||||
# Crucially, no retry attempts with different limits should exist.
|
||||
assert mock_get_call_paths == [page1_path, page2_path]
|
||||
|
||||
def test_paginate_url_reduces_limit_on_504_cloud(
|
||||
mock_credentials_provider: mock.Mock,
|
||||
) -> None:
|
||||
"""
|
||||
On Cloud, a 504 on the first page triggers limit halving. Once the request
|
||||
succeeds at the reduced limit, pagination continues at that limit and
|
||||
yields all results.
|
||||
"""
|
||||
client = OnyxConfluence(
|
||||
is_cloud=True,
|
||||
url="https://fake-cloud.atlassian.net",
|
||||
credentials_provider=mock_credentials_provider,
|
||||
timeout=10,
|
||||
)
|
||||
mock_internal = mock.Mock()
|
||||
mock_internal.url = client._url
|
||||
client._confluence = mock_internal
|
||||
client._kwargs = client.shared_base_kwargs
|
||||
|
||||
test_limit = 20
|
||||
|
||||
mock_get_call_paths: list[str] = []
|
||||
|
||||
def get_side_effect(
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None, # noqa: ARG001
|
||||
advanced_mode: bool = False, # noqa: ARG001
|
||||
) -> requests.Response:
|
||||
path = path.strip("/")
|
||||
mock_get_call_paths.append(path)
|
||||
|
||||
if f"limit={test_limit}" in path:
|
||||
return _create_http_error(504, url=path)
|
||||
|
||||
reduced_limit = test_limit // 2
|
||||
if f"limit={reduced_limit}" in path and "start=" not in path:
|
||||
return _create_mock_response(
|
||||
200,
|
||||
{
|
||||
"results": [{"id": 1}, {"id": 2}],
|
||||
"_links": {
|
||||
"next": f"/rest/api/content/search?cql=type%3Dpage&limit={test_limit}&start=2"
|
||||
},
|
||||
"size": 2,
|
||||
},
|
||||
url=path,
|
||||
)
|
||||
|
||||
if f"limit={reduced_limit}" in path and "start=2" in path:
|
||||
return _create_mock_response(
|
||||
200,
|
||||
{"results": [{"id": 3}], "_links": {}, "size": 1},
|
||||
url=path,
|
||||
)
|
||||
|
||||
raise RuntimeError(f"Unexpected path: {path}")
|
||||
|
||||
client._confluence.get.side_effect = get_side_effect
|
||||
|
||||
results = list(client.paginated_cql_retrieval(cql="type=page", limit=test_limit))
|
||||
|
||||
assert [r["id"] for r in results] == [1, 2, 3]
|
||||
assert len(mock_get_call_paths) == 3
|
||||
assert f"limit={test_limit}" in mock_get_call_paths[0]
|
||||
assert f"limit={test_limit // 2}" in mock_get_call_paths[1]
|
||||
# The next-page URL had the old limit but should be rewritten to reduced
|
||||
assert f"limit={test_limit // 2}" in mock_get_call_paths[2]
|
||||
|
||||
|
||||
def test_paginate_url_reduces_limit_on_500_server(
|
||||
confluence_server_client: OnyxConfluence,
|
||||
) -> None:
|
||||
"""
|
||||
On Server, a 500 triggers limit halving first. If the reduced limit
|
||||
succeeds, results are yielded normally.
|
||||
"""
|
||||
test_limit = 20
|
||||
|
||||
mock_get_call_paths: list[str] = []
|
||||
|
||||
def get_side_effect(
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None, # noqa: ARG001
|
||||
advanced_mode: bool = False, # noqa: ARG001
|
||||
) -> requests.Response:
|
||||
path = path.strip("/")
|
||||
mock_get_call_paths.append(path)
|
||||
|
||||
if f"limit={test_limit}" in path:
|
||||
return _create_http_error(500, url=path)
|
||||
|
||||
reduced_limit = test_limit // 2
|
||||
if f"limit={reduced_limit}" in path:
|
||||
return _create_mock_response(
|
||||
200,
|
||||
{"results": [{"id": 1}, {"id": 2}], "_links": {}, "size": 2},
|
||||
url=path,
|
||||
)
|
||||
|
||||
raise RuntimeError(f"Unexpected path: {path}")
|
||||
|
||||
confluence_server_client._confluence.get.side_effect = get_side_effect
|
||||
|
||||
results = list(
|
||||
confluence_server_client.paginated_cql_retrieval(
|
||||
cql="type=page", limit=test_limit
|
||||
)
|
||||
)
|
||||
|
||||
assert [r["id"] for r in results] == [1, 2]
|
||||
assert f"limit={test_limit}" in mock_get_call_paths[0]
|
||||
assert f"limit={test_limit // 2}" in mock_get_call_paths[1]
|
||||
|
||||
|
||||
def test_paginate_url_server_falls_back_to_one_by_one_after_limit_floor(
|
||||
confluence_server_client: OnyxConfluence,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""
|
||||
On Server, when limit reduction is exhausted (reaches the floor) and the
|
||||
request still fails, the one-by-one fallback kicks in.
|
||||
"""
|
||||
caplog.set_level("WARNING")
|
||||
# Start at the minimum so limit reduction is immediately exhausted
|
||||
test_limit = _MINIMUM_PAGINATION_LIMIT
|
||||
|
||||
mock_get_call_paths: list[str] = []
|
||||
|
||||
def get_side_effect(
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None, # noqa: ARG001
|
||||
advanced_mode: bool = False, # noqa: ARG001
|
||||
) -> requests.Response:
|
||||
path = path.strip("/")
|
||||
mock_get_call_paths.append(path)
|
||||
|
||||
if f"limit={test_limit}" in path and "start=" not in path:
|
||||
return _create_http_error(500, url=path)
|
||||
|
||||
# One-by-one fallback calls (limit=1)
|
||||
if "limit=1" in path:
|
||||
if "start=0" in path:
|
||||
return _create_mock_response(
|
||||
200,
|
||||
{"results": [{"id": 1}], "_links": {}, "size": 1},
|
||||
url=path,
|
||||
)
|
||||
if "start=1" in path:
|
||||
return _create_mock_response(
|
||||
200,
|
||||
{"results": [{"id": 2}], "_links": {}, "size": 1},
|
||||
url=path,
|
||||
)
|
||||
# start=2 onward: empty -> signals end
|
||||
return _create_mock_response(
|
||||
200,
|
||||
{"results": [], "_links": {}, "size": 0},
|
||||
url=path,
|
||||
)
|
||||
|
||||
raise RuntimeError(f"Unexpected path: {path}")
|
||||
|
||||
confluence_server_client._confluence.get.side_effect = get_side_effect
|
||||
|
||||
results = list(
|
||||
confluence_server_client.paginated_cql_retrieval(
|
||||
cql="type=page", limit=test_limit
|
||||
)
|
||||
)
|
||||
|
||||
assert [r["id"] for r in results] == [1, 2]
|
||||
# First call at test_limit fails, then one-by-one at start=0,1,2
|
||||
one_by_one_calls = [p for p in mock_get_call_paths if "limit=1" in p]
|
||||
assert len(one_by_one_calls) >= 2
|
||||
|
||||
|
||||
def test_paginate_url_504_halves_multiple_times(
|
||||
mock_credentials_provider: mock.Mock,
|
||||
) -> None:
|
||||
"""
|
||||
Verifies that the limit is halved repeatedly on consecutive 504s until
|
||||
the request finally succeeds at a smaller limit.
|
||||
"""
|
||||
client = OnyxConfluence(
|
||||
is_cloud=True,
|
||||
url="https://fake-cloud.atlassian.net",
|
||||
credentials_provider=mock_credentials_provider,
|
||||
timeout=10,
|
||||
)
|
||||
mock_internal = mock.Mock()
|
||||
mock_internal.url = client._url
|
||||
client._confluence = mock_internal
|
||||
client._kwargs = client.shared_base_kwargs
|
||||
|
||||
test_limit = 40
|
||||
# 40 -> 20 (504) -> 10 (504) -> 5 (success)
|
||||
|
||||
mock_get_call_paths: list[str] = []
|
||||
|
||||
def get_side_effect(
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None, # noqa: ARG001
|
||||
advanced_mode: bool = False, # noqa: ARG001
|
||||
) -> requests.Response:
|
||||
path = path.strip("/")
|
||||
mock_get_call_paths.append(path)
|
||||
|
||||
if "limit=40" in path or "limit=20" in path or "limit=10" in path:
|
||||
return _create_http_error(504, url=path)
|
||||
|
||||
if "limit=5" in path:
|
||||
return _create_mock_response(
|
||||
200,
|
||||
{"results": [{"id": 99}], "_links": {}, "size": 1},
|
||||
url=path,
|
||||
)
|
||||
|
||||
raise RuntimeError(f"Unexpected path: {path}")
|
||||
|
||||
client._confluence.get.side_effect = get_side_effect
|
||||
|
||||
results = list(client.paginated_cql_retrieval(cql="type=page", limit=test_limit))
|
||||
|
||||
assert [r["id"] for r in results] == [99]
|
||||
assert len(mock_get_call_paths) == 4
|
||||
assert "limit=40" in mock_get_call_paths[0]
|
||||
assert "limit=20" in mock_get_call_paths[1]
|
||||
assert "limit=10" in mock_get_call_paths[2]
|
||||
assert "limit=5" in mock_get_call_paths[3]
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Tests verifying that GithubConnector implements SlimConnector and SlimConnectorWithPermSync
|
||||
correctly, and that pruning uses the cheap slim path (no lazy loading).
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import SlimDocument
|
||||
|
||||
|
||||
def _make_pr(html_url: str) -> MagicMock:
|
||||
pr = MagicMock()
|
||||
pr.html_url = html_url
|
||||
pr.pull_request = None
|
||||
# commits and changed_files should never be accessed during slim retrieval
|
||||
pr.commits = property(
|
||||
lambda _: (_ for _ in ()).throw(AssertionError("lazy load triggered"))
|
||||
)
|
||||
pr.changed_files = property(
|
||||
lambda _: (_ for _ in ()).throw(AssertionError("lazy load triggered"))
|
||||
)
|
||||
return pr
|
||||
|
||||
|
||||
def _make_issue(html_url: str) -> MagicMock:
|
||||
issue = MagicMock()
|
||||
issue.html_url = html_url
|
||||
issue.pull_request = None
|
||||
return issue
|
||||
|
||||
|
||||
def _make_connector(include_issues: bool = False) -> GithubConnector:
|
||||
connector = GithubConnector(
|
||||
repo_owner="test-org",
|
||||
repositories="test-repo",
|
||||
include_prs=True,
|
||||
include_issues=include_issues,
|
||||
)
|
||||
connector.github_client = MagicMock()
|
||||
return connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_repo() -> MagicMock:
|
||||
repo = MagicMock()
|
||||
repo.name = "test-repo"
|
||||
prs = [
|
||||
_make_pr(f"https://github.com/test-org/test-repo/pull/{i}") for i in range(1, 4)
|
||||
]
|
||||
repo.get_pulls.return_value = prs
|
||||
return repo
|
||||
|
||||
|
||||
def test_github_connector_implements_slim_connector() -> None:
|
||||
connector = _make_connector()
|
||||
assert isinstance(connector, SlimConnector)
|
||||
|
||||
|
||||
def test_github_connector_implements_slim_connector_with_perm_sync() -> None:
|
||||
connector = _make_connector()
|
||||
assert isinstance(connector, SlimConnectorWithPermSync)
|
||||
|
||||
|
||||
def test_retrieve_all_slim_docs_returns_pr_urls(mock_repo: MagicMock) -> None:
|
||||
connector = _make_connector()
|
||||
with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
|
||||
batches = list(connector.retrieve_all_slim_docs())
|
||||
|
||||
all_docs = [doc for batch in batches for doc in batch]
|
||||
assert len(all_docs) == 3
|
||||
assert all(isinstance(doc, SlimDocument) for doc in all_docs)
|
||||
assert {doc.id for doc in all_docs} == {
|
||||
"https://github.com/test-org/test-repo/pull/1",
|
||||
"https://github.com/test-org/test-repo/pull/2",
|
||||
"https://github.com/test-org/test-repo/pull/3",
|
||||
}
|
||||
|
||||
|
||||
def test_retrieve_all_slim_docs_has_no_external_access(mock_repo: MagicMock) -> None:
|
||||
"""Pruning does not need permissions — external_access should be None."""
|
||||
connector = _make_connector()
|
||||
with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
|
||||
batches = list(connector.retrieve_all_slim_docs())
|
||||
|
||||
all_docs = [doc for batch in batches for doc in batch]
|
||||
assert all(doc.external_access is None for doc in all_docs)
|
||||
|
||||
|
||||
def test_retrieve_all_slim_docs_perm_sync_populates_external_access(
|
||||
mock_repo: MagicMock,
|
||||
) -> None:
|
||||
connector = _make_connector()
|
||||
mock_access = MagicMock(spec=ExternalAccess)
|
||||
|
||||
with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
|
||||
with patch(
|
||||
"onyx.connectors.github.connector.get_external_access_permission",
|
||||
return_value=mock_access,
|
||||
) as mock_perm:
|
||||
batches = list(connector.retrieve_all_slim_docs_perm_sync())
|
||||
|
||||
# permission fetched exactly once per repo
|
||||
mock_perm.assert_called_once_with(mock_repo, connector.github_client)
|
||||
|
||||
all_docs = [doc for batch in batches for doc in batch]
|
||||
assert all(doc.external_access is mock_access for doc in all_docs)
|
||||
|
||||
|
||||
def test_retrieve_all_slim_docs_skips_pr_issues(mock_repo: MagicMock) -> None:
|
||||
"""Issues that are actually PRs should be skipped when include_issues=True."""
|
||||
connector = _make_connector(include_issues=True)
|
||||
|
||||
pr_issue = MagicMock()
|
||||
pr_issue.html_url = "https://github.com/test-org/test-repo/pull/99"
|
||||
pr_issue.pull_request = MagicMock() # non-None means it's a PR
|
||||
|
||||
real_issue = _make_issue("https://github.com/test-org/test-repo/issues/1")
|
||||
mock_repo.get_issues.return_value = [pr_issue, real_issue]
|
||||
|
||||
with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
|
||||
batches = list(connector.retrieve_all_slim_docs())
|
||||
|
||||
issue_ids = {doc.id for batch in batches for doc in batch if "issues" in doc.id}
|
||||
assert issue_ids == {"https://github.com/test-org/test-repo/issues/1"}
|
||||
|
||||
|
||||
def test_pruning_routes_to_slim_connector_path(mock_repo: MagicMock) -> None:
|
||||
"""extract_ids_from_runnable_connector must use SlimConnector, not CheckpointedConnector."""
|
||||
connector = _make_connector()
|
||||
|
||||
with patch.object(connector, "fetch_configured_repos", return_value=[mock_repo]):
|
||||
# If the CheckpointedConnector fallback were used instead, it would call
|
||||
# load_from_checkpoint which hits _convert_pr_to_document and lazy loads.
|
||||
# We verify the slim path is taken by checking load_from_checkpoint is NOT called.
|
||||
with patch.object(connector, "load_from_checkpoint") as mock_load:
|
||||
result = extract_ids_from_runnable_connector(connector)
|
||||
mock_load.assert_not_called()
|
||||
|
||||
assert len(result.raw_id_to_parent) == 3
|
||||
assert "https://github.com/test-org/test-repo/pull/1" in result.raw_id_to_parent
|
||||
@@ -0,0 +1,200 @@
|
||||
"""Unit tests for GoogleDriveConnector slim retrieval routing.
|
||||
|
||||
Verifies that:
|
||||
- GoogleDriveConnector implements SlimConnector so pruning takes the ID-only path
|
||||
- retrieve_all_slim_docs() calls _extract_slim_docs_from_google_drive with include_permissions=False
|
||||
- retrieve_all_slim_docs_perm_sync() calls _extract_slim_docs_from_google_drive with include_permissions=True
|
||||
- celery_utils routing picks retrieve_all_slim_docs() for GoogleDriveConnector
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_drive.models import DriveRetrievalStage
|
||||
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.utils.threadpool_concurrency import ThreadSafeDict
|
||||
|
||||
|
||||
def _make_done_checkpoint() -> GoogleDriveCheckpoint:
|
||||
return GoogleDriveCheckpoint(
|
||||
retrieved_folder_and_drive_ids=set(),
|
||||
completion_stage=DriveRetrievalStage.DONE,
|
||||
completion_map=ThreadSafeDict(),
|
||||
all_retrieved_file_ids=set(),
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
|
||||
def _make_connector() -> GoogleDriveConnector:
|
||||
connector = GoogleDriveConnector(include_my_drives=True)
|
||||
connector._creds = MagicMock()
|
||||
connector._primary_admin_email = "admin@example.com"
|
||||
return connector
|
||||
|
||||
|
||||
class TestGoogleDriveSlimConnectorInterface:
|
||||
def test_implements_slim_connector(self) -> None:
|
||||
connector = _make_connector()
|
||||
assert isinstance(connector, SlimConnector)
|
||||
|
||||
def test_implements_slim_connector_with_perm_sync(self) -> None:
|
||||
connector = _make_connector()
|
||||
assert isinstance(connector, SlimConnectorWithPermSync)
|
||||
|
||||
def test_slim_connector_checked_before_perm_sync(self) -> None:
|
||||
"""SlimConnector must appear before SlimConnectorWithPermSync in MRO
|
||||
so celery_utils isinstance check routes to retrieve_all_slim_docs()."""
|
||||
mro = GoogleDriveConnector.__mro__
|
||||
slim_idx = mro.index(SlimConnector)
|
||||
perm_sync_idx = mro.index(SlimConnectorWithPermSync)
|
||||
assert slim_idx < perm_sync_idx
|
||||
|
||||
|
||||
class TestRetrieveAllSlimDocs:
|
||||
def test_does_not_call_extract_when_checkpoint_is_done(self) -> None:
|
||||
connector = _make_connector()
|
||||
slim_doc = MagicMock(
|
||||
spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
connector, "build_dummy_checkpoint", return_value=_make_done_checkpoint()
|
||||
):
|
||||
with patch.object(
|
||||
connector,
|
||||
"_extract_slim_docs_from_google_drive",
|
||||
return_value=iter([[slim_doc]]),
|
||||
) as mock_extract:
|
||||
list(connector.retrieve_all_slim_docs())
|
||||
|
||||
mock_extract.assert_not_called() # loop exits immediately since checkpoint is DONE
|
||||
|
||||
def test_calls_extract_with_include_permissions_false_non_done_checkpoint(
|
||||
self,
|
||||
) -> None:
|
||||
connector = _make_connector()
|
||||
slim_doc = MagicMock(
|
||||
spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
|
||||
)
|
||||
# Checkpoint starts at START, _extract advances it to DONE
|
||||
with patch.object(connector, "build_dummy_checkpoint") as mock_build:
|
||||
start_checkpoint = GoogleDriveCheckpoint(
|
||||
retrieved_folder_and_drive_ids=set(),
|
||||
completion_stage=DriveRetrievalStage.START,
|
||||
completion_map=ThreadSafeDict(),
|
||||
all_retrieved_file_ids=set(),
|
||||
has_more=False,
|
||||
)
|
||||
mock_build.return_value = start_checkpoint
|
||||
|
||||
def _advance_checkpoint(**_kwargs: object) -> object:
|
||||
start_checkpoint.completion_stage = DriveRetrievalStage.DONE
|
||||
yield [slim_doc]
|
||||
|
||||
with patch.object(
|
||||
connector,
|
||||
"_extract_slim_docs_from_google_drive",
|
||||
side_effect=_advance_checkpoint,
|
||||
) as mock_extract:
|
||||
list(connector.retrieve_all_slim_docs())
|
||||
|
||||
mock_extract.assert_called_once()
|
||||
_, kwargs = mock_extract.call_args
|
||||
assert kwargs.get("include_permissions") is False
|
||||
|
||||
def test_yields_slim_documents(self) -> None:
|
||||
connector = _make_connector()
|
||||
slim_doc = MagicMock(
|
||||
spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
|
||||
)
|
||||
start_checkpoint = GoogleDriveCheckpoint(
|
||||
retrieved_folder_and_drive_ids=set(),
|
||||
completion_stage=DriveRetrievalStage.START,
|
||||
completion_map=ThreadSafeDict(),
|
||||
all_retrieved_file_ids=set(),
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
connector, "build_dummy_checkpoint", return_value=start_checkpoint
|
||||
):
|
||||
|
||||
def _advance_and_yield(**_kwargs: object) -> object:
|
||||
start_checkpoint.completion_stage = DriveRetrievalStage.DONE
|
||||
yield [slim_doc]
|
||||
|
||||
with patch.object(
|
||||
connector,
|
||||
"_extract_slim_docs_from_google_drive",
|
||||
side_effect=_advance_and_yield,
|
||||
):
|
||||
batches = list(connector.retrieve_all_slim_docs())
|
||||
|
||||
assert len(batches) == 1
|
||||
assert batches[0][0] is slim_doc
|
||||
|
||||
|
||||
class TestRetrieveAllSlimDocsPermSync:
|
||||
def test_calls_extract_with_include_permissions_true(self) -> None:
|
||||
connector = _make_connector()
|
||||
slim_doc = MagicMock(
|
||||
spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
|
||||
)
|
||||
start_checkpoint = GoogleDriveCheckpoint(
|
||||
retrieved_folder_and_drive_ids=set(),
|
||||
completion_stage=DriveRetrievalStage.START,
|
||||
completion_map=ThreadSafeDict(),
|
||||
all_retrieved_file_ids=set(),
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
connector, "build_dummy_checkpoint", return_value=start_checkpoint
|
||||
):
|
||||
|
||||
def _advance_and_yield(**_kwargs: object) -> object:
|
||||
start_checkpoint.completion_stage = DriveRetrievalStage.DONE
|
||||
yield [slim_doc]
|
||||
|
||||
with patch.object(
|
||||
connector,
|
||||
"_extract_slim_docs_from_google_drive",
|
||||
side_effect=_advance_and_yield,
|
||||
) as mock_extract:
|
||||
list(connector.retrieve_all_slim_docs_perm_sync())
|
||||
|
||||
mock_extract.assert_called_once()
|
||||
_, kwargs = mock_extract.call_args
|
||||
assert (
|
||||
kwargs.get("include_permissions") is None
|
||||
or kwargs.get("include_permissions") is True
|
||||
)
|
||||
|
||||
|
||||
class TestCeleryUtilsRouting:
|
||||
def test_pruning_uses_retrieve_all_slim_docs(self) -> None:
|
||||
"""extract_ids_from_runnable_connector must call retrieve_all_slim_docs,
|
||||
not retrieve_all_slim_docs_perm_sync, for GoogleDriveConnector."""
|
||||
connector = _make_connector()
|
||||
slim_doc = MagicMock(
|
||||
spec=SlimDocument, id="doc1", parent_hierarchy_raw_node_id=None
|
||||
)
|
||||
with (
|
||||
patch.object(
|
||||
connector, "retrieve_all_slim_docs", return_value=iter([[slim_doc]])
|
||||
) as mock_slim,
|
||||
patch.object(
|
||||
connector, "retrieve_all_slim_docs_perm_sync"
|
||||
) as mock_perm_sync,
|
||||
):
|
||||
extract_ids_from_runnable_connector(
|
||||
connector, connector_type="google_drive"
|
||||
)
|
||||
|
||||
mock_slim.assert_called_once()
|
||||
mock_perm_sync.assert_not_called()
|
||||
@@ -0,0 +1,182 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from onyx.connectors.google_utils.google_kv import get_auth_url
|
||||
from onyx.connectors.google_utils.google_kv import get_google_app_cred
|
||||
from onyx.connectors.google_utils.google_kv import get_service_account_key
|
||||
from onyx.connectors.google_utils.google_kv import upsert_google_app_cred
|
||||
from onyx.connectors.google_utils.google_kv import upsert_service_account_key
|
||||
from onyx.server.documents.models import GoogleAppCredentials
|
||||
from onyx.server.documents.models import GoogleAppWebCredentials
|
||||
from onyx.server.documents.models import GoogleServiceAccountKey
|
||||
|
||||
|
||||
def _make_app_creds() -> GoogleAppCredentials:
|
||||
return GoogleAppCredentials(
|
||||
web=GoogleAppWebCredentials(
|
||||
client_id="client-id.apps.googleusercontent.com",
|
||||
project_id="test-project",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_secret="secret",
|
||||
redirect_uris=["https://example.com/callback"],
|
||||
javascript_origins=["https://example.com"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _make_service_account_key() -> GoogleServiceAccountKey:
|
||||
return GoogleServiceAccountKey(
|
||||
type="service_account",
|
||||
project_id="test-project",
|
||||
private_key_id="private-key-id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
|
||||
client_email="test@test-project.iam.gserviceaccount.com",
|
||||
client_id="123",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test",
|
||||
universe_domain="googleapis.com",
|
||||
)
|
||||
|
||||
|
||||
def test_upsert_google_app_cred_stores_dict(monkeypatch: Any) -> None:
|
||||
stored: dict[str, Any] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored["key"] = key
|
||||
stored["value"] = value
|
||||
stored["encrypt"] = encrypt
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
upsert_google_app_cred(_make_app_creds(), DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert stored["key"] == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
assert stored["encrypt"] is True
|
||||
assert isinstance(stored["value"], dict)
|
||||
assert stored["value"]["web"]["client_id"] == "client-id.apps.googleusercontent.com"
|
||||
|
||||
|
||||
def test_upsert_service_account_key_stores_dict(monkeypatch: Any) -> None:
|
||||
stored: dict[str, Any] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored["key"] = key
|
||||
stored["value"] = value
|
||||
stored["encrypt"] = encrypt
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
upsert_service_account_key(_make_service_account_key(), DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert stored["key"] == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
assert stored["encrypt"] is True
|
||||
assert isinstance(stored["value"], dict)
|
||||
assert stored["value"]["project_id"] == "test-project"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_google_app_cred_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
payload: dict[str, Any] = _make_app_creds().model_dump(mode="json")
|
||||
stored_value: object = (
|
||||
payload if not legacy_string else _make_app_creds().model_dump_json()
|
||||
)
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
return stored_value
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
creds = get_google_app_cred(DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert creds.web.client_id == "client-id.apps.googleusercontent.com"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_service_account_key_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
stored_value: object = (
|
||||
_make_service_account_key().model_dump(mode="json")
|
||||
if not legacy_string
|
||||
else _make_service_account_key().model_dump_json()
|
||||
)
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
return stored_value
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
key = get_service_account_key(DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert key.client_email == "test@test-project.iam.gserviceaccount.com"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_auth_url_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
payload = _make_app_creds().model_dump(mode="json")
|
||||
stored_value: object = (
|
||||
payload if not legacy_string else _make_app_creds().model_dump_json()
|
||||
)
|
||||
stored_state: dict[str, object] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
return stored_value
|
||||
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored_state["key"] = key
|
||||
stored_state["value"] = value
|
||||
stored_state["encrypt"] = encrypt
|
||||
|
||||
class _StubFlow:
|
||||
def authorization_url(self, prompt: str) -> tuple[str, None]:
|
||||
assert prompt == "consent"
|
||||
return "https://accounts.google.com/o/oauth2/auth?state=test-state", None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
def _from_client_config(
|
||||
_app_config: object, *, scopes: object, redirect_uri: object
|
||||
) -> _StubFlow:
|
||||
del scopes, redirect_uri
|
||||
return _StubFlow()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.InstalledAppFlow.from_client_config",
|
||||
_from_client_config,
|
||||
)
|
||||
|
||||
auth_url = get_auth_url(42, DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert auth_url.startswith("https://accounts.google.com")
|
||||
assert stored_state["value"] == {"value": "test-state"}
|
||||
assert stored_state["encrypt"] is True
|
||||
@@ -4,6 +4,7 @@ from typing import cast
|
||||
import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
|
||||
from onyx.file_processing.extract_file_text import xlsx_sheet_extraction
|
||||
from onyx.file_processing.extract_file_text import xlsx_to_text
|
||||
|
||||
|
||||
@@ -196,3 +197,136 @@ class TestXlsxToText:
|
||||
assert "r1c1" in lines[0] and "r1c2" in lines[0]
|
||||
assert "r2c1" in lines[1] and "r2c2" in lines[1]
|
||||
assert "r3c1" in lines[2] and "r3c2" in lines[2]
|
||||
|
||||
|
||||
class TestXlsxSheetExtraction:
|
||||
def test_one_tuple_per_sheet(self) -> None:
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Revenue": [["Month", "Amount"], ["Jan", "100"]],
|
||||
"Expenses": [["Category", "Cost"], ["Rent", "500"]],
|
||||
}
|
||||
)
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 2
|
||||
# Order preserved from workbook sheet order
|
||||
titles = [title for _csv, title in sheets]
|
||||
assert titles == ["Revenue", "Expenses"]
|
||||
# Content present in the right tuple
|
||||
revenue_csv, _ = sheets[0]
|
||||
expenses_csv, _ = sheets[1]
|
||||
assert "Month" in revenue_csv
|
||||
assert "Jan" in revenue_csv
|
||||
assert "Category" in expenses_csv
|
||||
assert "Rent" in expenses_csv
|
||||
|
||||
def test_tuple_structure_is_csv_text_then_title(self) -> None:
|
||||
"""The tuple order is (csv_text, sheet_title) — pin it so callers
|
||||
that unpack positionally don't silently break."""
|
||||
xlsx = _make_xlsx({"MySheet": [["a", "b"]]})
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 1
|
||||
csv_text, title = sheets[0]
|
||||
assert title == "MySheet"
|
||||
assert "a" in csv_text
|
||||
assert "b" in csv_text
|
||||
|
||||
def test_empty_sheet_is_skipped(self) -> None:
|
||||
"""A sheet whose CSV output is empty/whitespace-only should NOT
|
||||
appear in the result — the `if csv_text.strip():` guard filters
|
||||
it out."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Data": [["a", "b"]],
|
||||
"Empty": [],
|
||||
}
|
||||
)
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 1
|
||||
assert sheets[0][1] == "Data"
|
||||
|
||||
def test_empty_workbook_returns_empty_list(self) -> None:
|
||||
"""All sheets empty → empty list (not a list of empty tuples)."""
|
||||
xlsx = _make_xlsx({"Sheet1": [], "Sheet2": []})
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert sheets == []
|
||||
|
||||
def test_single_sheet(self) -> None:
|
||||
xlsx = _make_xlsx({"Only": [["x", "y"], ["1", "2"]]})
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 1
|
||||
csv_text, title = sheets[0]
|
||||
assert title == "Only"
|
||||
assert "x" in csv_text
|
||||
assert "1" in csv_text
|
||||
|
||||
def test_bad_zip_returns_empty_list(self) -> None:
|
||||
bad_file = io.BytesIO(b"not a zip file")
|
||||
sheets = xlsx_sheet_extraction(bad_file, file_name="test.xlsx")
|
||||
assert sheets == []
|
||||
|
||||
def test_bad_zip_tilde_file_returns_empty_list(self) -> None:
|
||||
"""`~$`-prefixed files are Excel lock files; failure should log
|
||||
at debug (not warning) and still return []."""
|
||||
bad_file = io.BytesIO(b"not a zip file")
|
||||
sheets = xlsx_sheet_extraction(bad_file, file_name="~$temp.xlsx")
|
||||
assert sheets == []
|
||||
|
||||
def test_csv_content_matches_xlsx_to_text_per_sheet(self) -> None:
|
||||
"""For a single-sheet workbook, xlsx_to_text output should equal
|
||||
the csv_text from xlsx_sheet_extraction — they share the same
|
||||
per-sheet CSV-ification logic."""
|
||||
single_sheet_data = [["Name", "Age"], ["Alice", "30"]]
|
||||
expected_text = xlsx_to_text(_make_xlsx({"People": single_sheet_data}))
|
||||
|
||||
sheets = xlsx_sheet_extraction(_make_xlsx({"People": single_sheet_data}))
|
||||
assert len(sheets) == 1
|
||||
csv_text, title = sheets[0]
|
||||
assert title == "People"
|
||||
assert csv_text.strip() == expected_text.strip()
|
||||
|
||||
def test_commas_in_cells_are_quoted(self) -> None:
|
||||
xlsx = _make_xlsx({"S1": [["hello, world", "normal"]]})
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 1
|
||||
csv_text, _ = sheets[0]
|
||||
assert '"hello, world"' in csv_text
|
||||
|
||||
def test_long_empty_row_run_capped_within_sheet(self) -> None:
|
||||
"""The matrix cleanup applies per-sheet: >2 empty rows collapse
|
||||
to 2, which keeps the sheet non-empty and it still appears in
|
||||
the result."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"S1": [
|
||||
["header"],
|
||||
[""],
|
||||
[""],
|
||||
[""],
|
||||
[""],
|
||||
["data"],
|
||||
]
|
||||
}
|
||||
)
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
assert len(sheets) == 1
|
||||
csv_text, _ = sheets[0]
|
||||
lines = csv_text.strip().split("\n")
|
||||
# header + 2 empty (capped) + data = 4 lines
|
||||
assert len(lines) == 4
|
||||
assert "header" in lines[0]
|
||||
assert "data" in lines[-1]
|
||||
|
||||
def test_sheet_title_with_special_chars_preserved(self) -> None:
|
||||
"""Spaces, punctuation, unicode in sheet titles are preserved
|
||||
verbatim — the title is used as a link anchor downstream."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Q1 Revenue (USD)": [["a", "b"]],
|
||||
"Données": [["c", "d"]],
|
||||
}
|
||||
)
|
||||
sheets = xlsx_sheet_extraction(xlsx)
|
||||
titles = [title for _csv, title in sheets]
|
||||
assert "Q1 Revenue (USD)" in titles
|
||||
assert "Données" in titles
|
||||
|
||||
787
backend/tests/unit/onyx/indexing/test_document_chunker.py
Normal file
787
backend/tests/unit/onyx/indexing/test_document_chunker.py
Normal file
@@ -0,0 +1,787 @@
|
||||
import pytest
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import SECTION_SEPARATOR
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SectionType
|
||||
from onyx.indexing.chunking import DocumentChunker
|
||||
from onyx.indexing.chunking import text_section_chunker as text_chunker_module
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
|
||||
|
||||
class CharTokenizer(BaseTokenizer):
|
||||
"""1 character == 1 token. Deterministic & trivial to reason about."""
|
||||
|
||||
def encode(self, string: str) -> list[int]:
|
||||
return [ord(c) for c in string]
|
||||
|
||||
def tokenize(self, string: str) -> list[str]:
|
||||
return list(string)
|
||||
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
return "".join(chr(t) for t in tokens)
|
||||
|
||||
|
||||
# With a char-level tokenizer, each char is a token. 200 is comfortably
|
||||
# above BLURB_SIZE (128) so the blurb splitter won't get weird on small text.
|
||||
CHUNK_LIMIT = 200
|
||||
|
||||
|
||||
def _make_document_chunker(
|
||||
chunk_token_limit: int = CHUNK_LIMIT,
|
||||
) -> DocumentChunker:
|
||||
def token_counter(text: str) -> int:
|
||||
return len(text)
|
||||
|
||||
return DocumentChunker(
|
||||
tokenizer=CharTokenizer(),
|
||||
blurb_splitter=SentenceChunker(
|
||||
tokenizer_or_token_counter=token_counter,
|
||||
chunk_size=128,
|
||||
chunk_overlap=0,
|
||||
return_type="texts",
|
||||
),
|
||||
chunk_splitter=SentenceChunker(
|
||||
tokenizer_or_token_counter=token_counter,
|
||||
chunk_size=chunk_token_limit,
|
||||
chunk_overlap=0,
|
||||
return_type="texts",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _make_doc(
|
||||
sections: list[Section],
|
||||
title: str | None = "Test Doc",
|
||||
doc_id: str = "doc1",
|
||||
) -> IndexingDocument:
|
||||
return IndexingDocument(
|
||||
id=doc_id,
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=doc_id,
|
||||
title=title,
|
||||
metadata={},
|
||||
sections=[], # real sections unused — method reads processed_sections
|
||||
processed_sections=sections,
|
||||
)
|
||||
|
||||
|
||||
# --- Empty / degenerate input -------------------------------------------------
|
||||
|
||||
|
||||
def test_empty_processed_sections_returns_single_empty_safety_chunk() -> None:
|
||||
"""No sections at all should still yield one empty chunk (the
|
||||
`or not chunks` safety branch at the end)."""
|
||||
dc = _make_document_chunker()
|
||||
doc = _make_doc(sections=[])
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=[],
|
||||
title_prefix="TITLE\n",
|
||||
metadata_suffix_semantic="meta_sem",
|
||||
metadata_suffix_keyword="meta_kw",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].content == ""
|
||||
assert chunks[0].chunk_id == 0
|
||||
assert chunks[0].title_prefix == "TITLE\n"
|
||||
assert chunks[0].metadata_suffix_semantic == "meta_sem"
|
||||
assert chunks[0].metadata_suffix_keyword == "meta_kw"
|
||||
# safe default link offsets
|
||||
assert chunks[0].source_links == {0: ""}
|
||||
|
||||
|
||||
def test_empty_section_on_first_position_without_title_is_skipped() -> None:
|
||||
"""Doc has no title, first section has empty text — the guard
|
||||
`(not document.title or section_idx > 0)` means it IS skipped."""
|
||||
dc = _make_document_chunker()
|
||||
doc = _make_doc(
|
||||
sections=[Section(type=SectionType.TEXT, text="", link="l0")],
|
||||
title=None,
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
# skipped → no real content, but safety branch still yields 1 empty chunk
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].content == ""
|
||||
|
||||
|
||||
def test_empty_section_on_later_position_is_skipped_even_with_title() -> None:
|
||||
"""Index > 0 empty sections are skipped regardless of title."""
|
||||
dc = _make_document_chunker()
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(type=SectionType.TEXT, text="Alpha.", link="l0"),
|
||||
Section(type=SectionType.TEXT, text="", link="l1"), # should be skipped
|
||||
Section(type=SectionType.TEXT, text="Beta.", link="l2"),
|
||||
],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert "Alpha." in chunks[0].content
|
||||
assert "Beta." in chunks[0].content
|
||||
# link offsets should only contain l0 and l2 (no l1)
|
||||
assert "l1" not in (chunks[0].source_links or {}).values()
|
||||
|
||||
|
||||
# --- Single text section ------------------------------------------------------
|
||||
|
||||
|
||||
def test_single_small_text_section_becomes_one_chunk() -> None:
|
||||
dc = _make_document_chunker()
|
||||
doc = _make_doc(
|
||||
sections=[Section(type=SectionType.TEXT, text="Hello world.", link="https://a")]
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="TITLE\n",
|
||||
metadata_suffix_semantic="ms",
|
||||
metadata_suffix_keyword="mk",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
chunk = chunks[0]
|
||||
assert chunk.content == "Hello world."
|
||||
assert chunk.source_links == {0: "https://a"}
|
||||
assert chunk.title_prefix == "TITLE\n"
|
||||
assert chunk.metadata_suffix_semantic == "ms"
|
||||
assert chunk.metadata_suffix_keyword == "mk"
|
||||
assert chunk.section_continuation is False
|
||||
assert chunk.image_file_id is None
|
||||
|
||||
|
||||
# --- Multiple text sections combined -----------------------------------------
|
||||
|
||||
|
||||
def test_multiple_small_sections_combine_into_one_chunk() -> None:
|
||||
dc = _make_document_chunker()
|
||||
sections = [
|
||||
Section(type=SectionType.TEXT, text="Part one.", link="l1"),
|
||||
Section(type=SectionType.TEXT, text="Part two.", link="l2"),
|
||||
Section(type=SectionType.TEXT, text="Part three.", link="l3"),
|
||||
]
|
||||
doc = _make_doc(sections=sections)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
expected = SECTION_SEPARATOR.join(["Part one.", "Part two.", "Part three."])
|
||||
assert chunks[0].content == expected
|
||||
|
||||
# link_offsets: indexed by shared_precompare_cleanup length of the
|
||||
# chunk_text *before* each section was appended.
|
||||
# "" -> "", len 0
|
||||
# "Part one." -> "partone", len 7
|
||||
# "Part one.\n\nPart two." -> "partoneparttwo", len 14
|
||||
assert chunks[0].source_links == {0: "l1", 7: "l2", 14: "l3"}
|
||||
|
||||
|
||||
def test_sections_overflow_into_second_chunk() -> None:
|
||||
"""Two sections that together exceed content_token_limit should
|
||||
finalize the first as one chunk and start a new one."""
|
||||
dc = _make_document_chunker()
|
||||
# char-level: 120 char section → 120 tokens. 2 of these plus separator
|
||||
# exceed a 200-token limit, forcing a flush.
|
||||
a = "A" * 120
|
||||
b = "B" * 120
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(type=SectionType.TEXT, text=a, link="la"),
|
||||
Section(type=SectionType.TEXT, text=b, link="lb"),
|
||||
],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0].content == a
|
||||
assert chunks[1].content == b
|
||||
# first chunk is not a continuation; second starts a new section → not either
|
||||
assert chunks[0].section_continuation is False
|
||||
assert chunks[1].section_continuation is False
|
||||
# chunk_ids should be sequential starting at 0
|
||||
assert chunks[0].chunk_id == 0
|
||||
assert chunks[1].chunk_id == 1
|
||||
# links routed appropriately
|
||||
assert chunks[0].source_links == {0: "la"}
|
||||
assert chunks[1].source_links == {0: "lb"}
|
||||
|
||||
|
||||
# --- Image section handling --------------------------------------------------
|
||||
|
||||
|
||||
def test_image_only_section_produces_single_chunk_with_image_id() -> None:
|
||||
dc = _make_document_chunker()
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(
|
||||
type=SectionType.IMAGE,
|
||||
text="summary of image",
|
||||
link="https://img",
|
||||
image_file_id="img-abc",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].image_file_id == "img-abc"
|
||||
assert chunks[0].content == "summary of image"
|
||||
assert chunks[0].source_links == {0: "https://img"}
|
||||
|
||||
|
||||
def test_image_section_flushes_pending_text_and_creates_its_own_chunk() -> None:
|
||||
"""A buffered text section followed by an image section:
|
||||
the pending text should be flushed first, then the image chunk."""
|
||||
dc = _make_document_chunker()
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(type=SectionType.TEXT, text="Pending text.", link="ltext"),
|
||||
Section(
|
||||
type=SectionType.IMAGE,
|
||||
text="image summary",
|
||||
link="limage",
|
||||
image_file_id="img-1",
|
||||
),
|
||||
Section(type=SectionType.TEXT, text="Trailing text.", link="ltail"),
|
||||
],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 3
|
||||
|
||||
# 0: flushed pending text
|
||||
assert chunks[0].content == "Pending text."
|
||||
assert chunks[0].image_file_id is None
|
||||
assert chunks[0].source_links == {0: "ltext"}
|
||||
|
||||
# 1: image chunk
|
||||
assert chunks[1].content == "image summary"
|
||||
assert chunks[1].image_file_id == "img-1"
|
||||
assert chunks[1].source_links == {0: "limage"}
|
||||
|
||||
# 2: trailing text, started fresh after image
|
||||
assert chunks[2].content == "Trailing text."
|
||||
assert chunks[2].image_file_id is None
|
||||
assert chunks[2].source_links == {0: "ltail"}
|
||||
|
||||
|
||||
def test_image_section_without_link_gets_empty_links_dict() -> None:
|
||||
"""If an image section has no link, links param is {} (not {0: ""})."""
|
||||
dc = _make_document_chunker()
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(
|
||||
type=SectionType.IMAGE,
|
||||
text="img",
|
||||
link=None,
|
||||
image_file_id="img-xyz",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].image_file_id == "img-xyz"
|
||||
# to_doc_aware_chunk falls back to {0: ""} when given an empty dict
|
||||
assert chunks[0].source_links == {0: ""}
|
||||
|
||||
|
||||
# --- Oversized section splitting ---------------------------------------------
|
||||
|
||||
|
||||
def test_oversized_section_is_split_across_multiple_chunks() -> None:
|
||||
"""A section whose text exceeds content_token_limit should be passed
|
||||
through chunk_splitter and yield >1 chunks; only the first is not a
|
||||
continuation."""
|
||||
dc = _make_document_chunker()
|
||||
# Build a section whose char-count is well over CHUNK_LIMIT (200), made
|
||||
# of many short sentences so chonkie's SentenceChunker can split cleanly.
|
||||
section_text = (
|
||||
"Alpha beta gamma. Delta epsilon zeta. Eta theta iota. "
|
||||
"Kappa lambda mu. Nu xi omicron. Pi rho sigma. Tau upsilon phi. "
|
||||
"Chi psi omega. One two three. Four five six. Seven eight nine. "
|
||||
"Ten eleven twelve. Thirteen fourteen fifteen. "
|
||||
"Sixteen seventeen eighteen. Nineteen twenty."
|
||||
)
|
||||
assert len(section_text) > CHUNK_LIMIT
|
||||
|
||||
doc = _make_doc(
|
||||
sections=[Section(type=SectionType.TEXT, text=section_text, link="big-link")],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) >= 2
|
||||
# First chunk is fresh, rest are continuations
|
||||
assert chunks[0].section_continuation is False
|
||||
for c in chunks[1:]:
|
||||
assert c.section_continuation is True
|
||||
# Every produced chunk should carry the section's link
|
||||
for c in chunks:
|
||||
assert c.source_links == {0: "big-link"}
|
||||
# Concatenated content should roughly cover the original (allowing
|
||||
# for chunker boundary whitespace differences).
|
||||
joined = "".join(c.content for c in chunks)
|
||||
for word in ("Alpha", "omega", "twenty"):
|
||||
assert word in joined
|
||||
|
||||
|
||||
def test_oversized_section_flushes_pending_text_first() -> None:
|
||||
"""A buffered text section followed by an oversized section should
|
||||
flush the pending chunk first, then emit the split chunks."""
|
||||
dc = _make_document_chunker()
|
||||
pending = "Pending buffered text."
|
||||
big = (
|
||||
"Alpha beta gamma. Delta epsilon zeta. Eta theta iota. "
|
||||
"Kappa lambda mu. Nu xi omicron. Pi rho sigma. Tau upsilon phi. "
|
||||
"Chi psi omega. One two three. Four five six. Seven eight nine. "
|
||||
"Ten eleven twelve. Thirteen fourteen fifteen. Sixteen seventeen."
|
||||
)
|
||||
assert len(big) > CHUNK_LIMIT
|
||||
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(type=SectionType.TEXT, text=pending, link="l-pending"),
|
||||
Section(type=SectionType.TEXT, text=big, link="l-big"),
|
||||
],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
# First chunk is the flushed pending text
|
||||
assert chunks[0].content == pending
|
||||
assert chunks[0].source_links == {0: "l-pending"}
|
||||
assert chunks[0].section_continuation is False
|
||||
|
||||
# Remaining chunks correspond to the oversized section
|
||||
assert len(chunks) >= 2
|
||||
for c in chunks[1:]:
|
||||
assert c.source_links == {0: "l-big"}
|
||||
# Within the oversized section, the first is fresh and the rest are
|
||||
# continuations.
|
||||
assert chunks[1].section_continuation is False
|
||||
for c in chunks[2:]:
|
||||
assert c.section_continuation is True
|
||||
|
||||
|
||||
# --- Title prefix / metadata propagation -------------------------------------
|
||||
|
||||
|
||||
def test_title_prefix_and_metadata_propagate_to_all_chunks() -> None:
|
||||
dc = _make_document_chunker()
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(type=SectionType.TEXT, text="A" * 120, link="la"),
|
||||
Section(type=SectionType.TEXT, text="B" * 120, link="lb"),
|
||||
],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="MY_TITLE\n",
|
||||
metadata_suffix_semantic="MS",
|
||||
metadata_suffix_keyword="MK",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 2
|
||||
for chunk in chunks:
|
||||
assert chunk.title_prefix == "MY_TITLE\n"
|
||||
assert chunk.metadata_suffix_semantic == "MS"
|
||||
assert chunk.metadata_suffix_keyword == "MK"
|
||||
|
||||
|
||||
# --- chunk_id monotonicity ---------------------------------------------------
|
||||
|
||||
|
||||
def test_chunk_ids_are_sequential_starting_at_zero() -> None:
|
||||
dc = _make_document_chunker()
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(type=SectionType.TEXT, text="A" * 120, link="la"),
|
||||
Section(type=SectionType.TEXT, text="B" * 120, link="lb"),
|
||||
Section(type=SectionType.TEXT, text="C" * 120, link="lc"),
|
||||
],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert [c.chunk_id for c in chunks] == list(range(len(chunks)))
|
||||
|
||||
|
||||
# --- Overflow accumulation behavior ------------------------------------------
|
||||
|
||||
|
||||
def test_overflow_flush_then_subsequent_section_joins_new_chunk() -> None:
|
||||
"""After an overflow flush starts a new chunk, the next fitting section
|
||||
should combine into that same new chunk (not spawn a third)."""
|
||||
dc = _make_document_chunker()
|
||||
# 120 + 120 > 200 → first two sections produce two chunks.
|
||||
# Third section is small (20 chars) → should fit with second.
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(type=SectionType.TEXT, text="A" * 120, link="la"),
|
||||
Section(type=SectionType.TEXT, text="B" * 120, link="lb"),
|
||||
Section(type=SectionType.TEXT, text="C" * 20, link="lc"),
|
||||
],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0].content == "A" * 120
|
||||
assert chunks[1].content == ("B" * 120) + SECTION_SEPARATOR + ("C" * 20)
|
||||
# link_offsets on second chunk: lb at 0, lc at precompare-len("BBBB...")=120
|
||||
assert chunks[1].source_links == {0: "lb", 120: "lc"}
|
||||
|
||||
|
||||
def test_small_section_after_oversized_starts_a_fresh_chunk() -> None:
|
||||
"""After an oversized section is emitted as its own chunks, the internal
|
||||
accumulator should be empty so a following small section starts a new
|
||||
chunk instead of being swallowed."""
|
||||
dc = _make_document_chunker()
|
||||
big = (
|
||||
"Alpha beta gamma. Delta epsilon zeta. Eta theta iota. "
|
||||
"Kappa lambda mu. Nu xi omicron. Pi rho sigma. Tau upsilon phi. "
|
||||
"Chi psi omega. One two three. Four five six. Seven eight nine. "
|
||||
"Ten eleven twelve. Thirteen fourteen fifteen. Sixteen seventeen."
|
||||
)
|
||||
assert len(big) > CHUNK_LIMIT
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(type=SectionType.TEXT, text=big, link="l-big"),
|
||||
Section(type=SectionType.TEXT, text="Tail text.", link="l-tail"),
|
||||
],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
# All-but-last chunks belong to the oversized section; the very last is
|
||||
# the tail text starting fresh (not a continuation).
|
||||
assert len(chunks) >= 2
|
||||
assert chunks[-1].content == "Tail text."
|
||||
assert chunks[-1].source_links == {0: "l-tail"}
|
||||
assert chunks[-1].section_continuation is False
|
||||
# And earlier oversized chunks never leaked the tail link
|
||||
for c in chunks[:-1]:
|
||||
assert c.source_links == {0: "l-big"}
|
||||
|
||||
|
||||
# --- STRICT_CHUNK_TOKEN_LIMIT fallback path ----------------------------------
|
||||
|
||||
|
||||
def test_strict_chunk_token_limit_subdivides_oversized_split(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When STRICT_CHUNK_TOKEN_LIMIT is enabled and chonkie's chunk_splitter
|
||||
still produces a piece larger than content_token_limit (e.g. a single
|
||||
no-period run), the code must fall back to _split_oversized_chunk."""
|
||||
monkeypatch.setattr(text_chunker_module, "STRICT_CHUNK_TOKEN_LIMIT", True)
|
||||
dc = _make_document_chunker()
|
||||
# 500 non-whitespace chars with no sentence boundaries — chonkie will
|
||||
# return it as one oversized piece (>200) which triggers the fallback.
|
||||
run = "a" * 500
|
||||
doc = _make_doc(sections=[Section(type=SectionType.TEXT, text=run, link="l-run")])
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
# With CHUNK_LIMIT=200 and a 500-char run we expect ceil(500/200)=3 sub-chunks.
|
||||
assert len(chunks) == 3
|
||||
# First is fresh, rest are continuations (is_continuation=(j != 0))
|
||||
assert chunks[0].section_continuation is False
|
||||
assert chunks[1].section_continuation is True
|
||||
assert chunks[2].section_continuation is True
|
||||
# All carry the section link
|
||||
for c in chunks:
|
||||
assert c.source_links == {0: "l-run"}
|
||||
# NOTE: we do NOT assert the chunks are at or below content_token_limit.
|
||||
# _split_oversized_chunk joins tokens with " ", which means the resulting
|
||||
# chunk contents can exceed the limit when tokens are short. That's a
|
||||
# quirk of the current implementation and this test pins the window
|
||||
# slicing, not the post-join length.
|
||||
|
||||
|
||||
def test_strict_chunk_token_limit_disabled_allows_oversized_split(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Same pathological input, but with STRICT disabled: the oversized
|
||||
split is emitted verbatim as a single chunk (current behavior)."""
|
||||
monkeypatch.setattr(text_chunker_module, "STRICT_CHUNK_TOKEN_LIMIT", False)
|
||||
dc = _make_document_chunker()
|
||||
run = "a" * 500
|
||||
doc = _make_doc(sections=[Section(type=SectionType.TEXT, text=run, link="l-run")])
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].content == run
|
||||
assert chunks[0].section_continuation is False
|
||||
|
||||
|
||||
# --- First-section-with-empty-text-but-document-has-title edge case ----------
|
||||
|
||||
|
||||
def test_first_empty_section_with_title_is_processed_not_skipped() -> None:
|
||||
"""The guard `(not document.title or section_idx > 0)` means: when
|
||||
the doc has a title AND it's the first section, an empty text section
|
||||
is NOT skipped. This pins current behavior so a refactor can't silently
|
||||
change it."""
|
||||
dc = _make_document_chunker()
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(
|
||||
type=SectionType.TEXT, text="", link="l0"
|
||||
), # empty first section, kept
|
||||
Section(type=SectionType.TEXT, text="Real content.", link="l1"),
|
||||
],
|
||||
title="Has A Title",
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].content == "Real content."
|
||||
# First (empty) section did register a link_offset at 0 before being
|
||||
# overwritten; that offset is then reused when "Real content." is added,
|
||||
# because shared_precompare_cleanup("") is still "". End state: {0: "l1"}
|
||||
assert chunks[0].source_links == {0: "l1"}
|
||||
|
||||
|
||||
# --- clean_text is applied to section text -----------------------------------
|
||||
|
||||
|
||||
def test_clean_text_strips_control_chars_from_section_content() -> None:
|
||||
"""clean_text() should remove control chars before the text enters the
|
||||
accumulator — verifies the call isn't dropped by a refactor."""
|
||||
dc = _make_document_chunker()
|
||||
# NUL + BEL are control chars below 0x20 and not \n or \t → should be
|
||||
# stripped by clean_text.
|
||||
dirty = "Hello\x00 World\x07!"
|
||||
doc = _make_doc(sections=[Section(type=SectionType.TEXT, text=dirty, link="l1")])
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].content == "Hello World!"
|
||||
|
||||
|
||||
# --- None-valued fields ------------------------------------------------------
|
||||
|
||||
|
||||
def test_section_with_none_text_behaves_like_empty_string() -> None:
|
||||
"""`section.text` may be None — the method coerces via
|
||||
`str(section.text or "")`, so a None-text section behaves identically
|
||||
to an empty one (skipped unless it's the first section of a titled doc)."""
|
||||
dc = _make_document_chunker()
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(type=SectionType.TEXT, text="Alpha.", link="la"),
|
||||
Section(type=SectionType.TEXT, text=None, link="lnone"), # idx 1 → skipped
|
||||
Section(type=SectionType.TEXT, text="Beta.", link="lb"),
|
||||
],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert "Alpha." in chunks[0].content
|
||||
assert "Beta." in chunks[0].content
|
||||
assert "lnone" not in (chunks[0].source_links or {}).values()
|
||||
|
||||
|
||||
# --- Trailing empty chunk suppression ----------------------------------------
|
||||
|
||||
|
||||
def test_no_trailing_empty_chunk_when_last_section_was_image() -> None:
|
||||
"""If the final section was an image (which emits its own chunk and
|
||||
resets chunk_text), the safety `or not chunks` branch should NOT fire
|
||||
because chunks is non-empty. Pin this explicitly."""
|
||||
dc = _make_document_chunker()
|
||||
doc = _make_doc(
|
||||
sections=[
|
||||
Section(type=SectionType.TEXT, text="Leading text.", link="ltext"),
|
||||
Section(
|
||||
type=SectionType.IMAGE,
|
||||
text="img summary",
|
||||
link="limg",
|
||||
image_file_id="img-final",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0].content == "Leading text."
|
||||
assert chunks[0].image_file_id is None
|
||||
assert chunks[1].content == "img summary"
|
||||
assert chunks[1].image_file_id == "img-final"
|
||||
# Crucially: no third empty chunk got appended at the end.
|
||||
|
||||
|
||||
def test_no_trailing_empty_chunk_when_last_section_was_oversized() -> None:
|
||||
"""Same guarantee for oversized sections: their splits fully clear the
|
||||
accumulator, and the trailing safety branch should be a no-op."""
|
||||
dc = _make_document_chunker()
|
||||
big = (
|
||||
"Alpha beta gamma. Delta epsilon zeta. Eta theta iota. "
|
||||
"Kappa lambda mu. Nu xi omicron. Pi rho sigma. Tau upsilon phi. "
|
||||
"Chi psi omega. One two three. Four five six. Seven eight nine. "
|
||||
"Ten eleven twelve. Thirteen fourteen fifteen. Sixteen seventeen."
|
||||
)
|
||||
assert len(big) > CHUNK_LIMIT
|
||||
doc = _make_doc(sections=[Section(type=SectionType.TEXT, text=big, link="l-big")])
|
||||
|
||||
chunks = dc.chunk(
|
||||
document=doc,
|
||||
sections=doc.processed_sections,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
content_token_limit=CHUNK_LIMIT,
|
||||
)
|
||||
|
||||
# Every chunk should be non-empty — no dangling "" chunk at the tail.
|
||||
assert all(c.content.strip() for c in chunks)
|
||||
@@ -0,0 +1,188 @@
|
||||
"""Unit tests for MinimalPersonaSnapshot.from_model knowledge_sources aggregation."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.server.features.document_set.models import DocumentSetSummary
|
||||
from onyx.server.features.persona.models import MinimalPersonaSnapshot
|
||||
|
||||
|
||||
_STUB_DS_SUMMARY = DocumentSetSummary(
|
||||
id=1,
|
||||
name="stub",
|
||||
description=None,
|
||||
cc_pair_summaries=[],
|
||||
is_up_to_date=True,
|
||||
is_public=True,
|
||||
users=[],
|
||||
groups=[],
|
||||
)
|
||||
|
||||
|
||||
def _make_persona(**overrides: object) -> MagicMock:
|
||||
"""Build a mock Persona with sensible defaults.
|
||||
|
||||
Every relationship defaults to empty so tests only need to set the
|
||||
fields they care about.
|
||||
"""
|
||||
p = MagicMock()
|
||||
p.id = 1
|
||||
p.name = "test"
|
||||
p.description = ""
|
||||
p.tools = []
|
||||
p.starter_messages = None
|
||||
p.document_sets = []
|
||||
p.hierarchy_nodes = []
|
||||
p.attached_documents = []
|
||||
p.user_files = []
|
||||
p.llm_model_version_override = None
|
||||
p.llm_model_provider_override = None
|
||||
p.uploaded_image_id = None
|
||||
p.icon_name = None
|
||||
p.is_public = True
|
||||
p.is_listed = True
|
||||
p.display_priority = None
|
||||
p.is_featured = False
|
||||
p.builtin_persona = False
|
||||
p.labels = []
|
||||
p.user = None
|
||||
|
||||
for k, v in overrides.items():
|
||||
setattr(p, k, v)
|
||||
return p
|
||||
|
||||
|
||||
def _make_cc_pair(source: DocumentSource) -> MagicMock:
|
||||
cc = MagicMock()
|
||||
cc.connector.source = source
|
||||
cc.name = source.value
|
||||
cc.id = 1
|
||||
cc.access_type = "PUBLIC"
|
||||
return cc
|
||||
|
||||
|
||||
def _make_doc_set(
|
||||
cc_pairs: list[MagicMock] | None = None,
|
||||
fed_connectors: list[MagicMock] | None = None,
|
||||
) -> MagicMock:
|
||||
ds = MagicMock()
|
||||
ds.id = 1
|
||||
ds.name = "ds"
|
||||
ds.description = None
|
||||
ds.is_up_to_date = True
|
||||
ds.is_public = True
|
||||
ds.users = []
|
||||
ds.groups = []
|
||||
ds.connector_credential_pairs = cc_pairs or []
|
||||
ds.federated_connectors = fed_connectors or []
|
||||
return ds
|
||||
|
||||
|
||||
def _make_federated_ds_mapping(
|
||||
source: FederatedConnectorSource,
|
||||
) -> MagicMock:
|
||||
mapping = MagicMock()
|
||||
mapping.federated_connector.source = source
|
||||
mapping.federated_connector_id = 1
|
||||
mapping.entities = {}
|
||||
return mapping
|
||||
|
||||
|
||||
def _make_hierarchy_node(source: DocumentSource) -> MagicMock:
|
||||
node = MagicMock()
|
||||
node.source = source
|
||||
return node
|
||||
|
||||
|
||||
def _make_attached_document(source: DocumentSource) -> MagicMock:
|
||||
doc = MagicMock()
|
||||
doc.parent_hierarchy_node = MagicMock()
|
||||
doc.parent_hierarchy_node.source = source
|
||||
return doc
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
|
||||
return_value=_STUB_DS_SUMMARY,
|
||||
)
|
||||
def test_empty_persona_has_no_knowledge_sources(_mock_ds: MagicMock) -> None:
|
||||
persona = _make_persona()
|
||||
snapshot = MinimalPersonaSnapshot.from_model(persona)
|
||||
assert snapshot.knowledge_sources == []
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
|
||||
return_value=_STUB_DS_SUMMARY,
|
||||
)
|
||||
def test_user_files_adds_user_file_source(_mock_ds: MagicMock) -> None:
|
||||
persona = _make_persona(user_files=[MagicMock()])
|
||||
snapshot = MinimalPersonaSnapshot.from_model(persona)
|
||||
assert DocumentSource.USER_FILE in snapshot.knowledge_sources
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
|
||||
return_value=_STUB_DS_SUMMARY,
|
||||
)
|
||||
def test_no_user_files_excludes_user_file_source(_mock_ds: MagicMock) -> None:
|
||||
cc = _make_cc_pair(DocumentSource.CONFLUENCE)
|
||||
ds = _make_doc_set(cc_pairs=[cc])
|
||||
persona = _make_persona(document_sets=[ds])
|
||||
snapshot = MinimalPersonaSnapshot.from_model(persona)
|
||||
assert DocumentSource.USER_FILE not in snapshot.knowledge_sources
|
||||
assert DocumentSource.CONFLUENCE in snapshot.knowledge_sources
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
|
||||
return_value=_STUB_DS_SUMMARY,
|
||||
)
|
||||
def test_federated_connector_in_doc_set(_mock_ds: MagicMock) -> None:
|
||||
fed = _make_federated_ds_mapping(FederatedConnectorSource.FEDERATED_SLACK)
|
||||
ds = _make_doc_set(fed_connectors=[fed])
|
||||
persona = _make_persona(document_sets=[ds])
|
||||
snapshot = MinimalPersonaSnapshot.from_model(persona)
|
||||
assert DocumentSource.SLACK in snapshot.knowledge_sources
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
|
||||
return_value=_STUB_DS_SUMMARY,
|
||||
)
|
||||
def test_hierarchy_nodes_and_attached_documents(_mock_ds: MagicMock) -> None:
|
||||
node = _make_hierarchy_node(DocumentSource.GOOGLE_DRIVE)
|
||||
doc = _make_attached_document(DocumentSource.SHAREPOINT)
|
||||
persona = _make_persona(hierarchy_nodes=[node], attached_documents=[doc])
|
||||
snapshot = MinimalPersonaSnapshot.from_model(persona)
|
||||
assert DocumentSource.GOOGLE_DRIVE in snapshot.knowledge_sources
|
||||
assert DocumentSource.SHAREPOINT in snapshot.knowledge_sources
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.models.DocumentSetSummary.from_model",
|
||||
return_value=_STUB_DS_SUMMARY,
|
||||
)
|
||||
def test_all_source_types_combined(_mock_ds: MagicMock) -> None:
|
||||
cc = _make_cc_pair(DocumentSource.CONFLUENCE)
|
||||
fed = _make_federated_ds_mapping(FederatedConnectorSource.FEDERATED_SLACK)
|
||||
ds = _make_doc_set(cc_pairs=[cc], fed_connectors=[fed])
|
||||
node = _make_hierarchy_node(DocumentSource.GOOGLE_DRIVE)
|
||||
doc = _make_attached_document(DocumentSource.SHAREPOINT)
|
||||
persona = _make_persona(
|
||||
document_sets=[ds],
|
||||
hierarchy_nodes=[node],
|
||||
attached_documents=[doc],
|
||||
user_files=[MagicMock()],
|
||||
)
|
||||
snapshot = MinimalPersonaSnapshot.from_model(persona)
|
||||
sources = set(snapshot.knowledge_sources)
|
||||
assert sources == {
|
||||
DocumentSource.CONFLUENCE,
|
||||
DocumentSource.SLACK,
|
||||
DocumentSource.GOOGLE_DRIVE,
|
||||
DocumentSource.SHAREPOINT,
|
||||
DocumentSource.USER_FILE,
|
||||
}
|
||||
@@ -95,9 +95,9 @@ class TestForceAddSearchToolGuard:
|
||||
without a vector DB."""
|
||||
import inspect
|
||||
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import _construct_tools_impl
|
||||
|
||||
source = inspect.getsource(construct_tools)
|
||||
source = inspect.getsource(_construct_tools_impl)
|
||||
assert (
|
||||
"DISABLE_VECTOR_DB" in source
|
||||
), "construct_tools should reference DISABLE_VECTOR_DB to suppress force-adding SearchTool"
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
"""Tests for generic Celery task lifecycle Prometheus metrics."""
|
||||
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.celery.apps.app_base import on_before_task_publish
|
||||
from onyx.server.metrics.celery_task_metrics import _task_start_times
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_COMPLETED
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_DURATION
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_QUEUE_WAIT
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_STARTED
|
||||
from onyx.server.metrics.celery_task_metrics import TASKS_ACTIVE
|
||||
|
||||
@@ -22,11 +25,18 @@ def reset_metrics() -> Iterator[None]:
|
||||
_task_start_times.clear()
|
||||
|
||||
|
||||
def _make_task(name: str = "test_task", queue: str = "test_queue") -> MagicMock:
|
||||
def _make_task(
|
||||
name: str = "test_task",
|
||||
queue: str = "test_queue",
|
||||
enqueued_at: float | None = None,
|
||||
) -> MagicMock:
|
||||
task = MagicMock()
|
||||
task.name = name
|
||||
task.request = MagicMock()
|
||||
task.request.delivery_info = {"routing_key": queue}
|
||||
task.request.headers = (
|
||||
{"enqueued_at": enqueued_at} if enqueued_at is not None else {}
|
||||
)
|
||||
return task
|
||||
|
||||
|
||||
@@ -72,6 +82,35 @@ class TestCeleryTaskPrerun:
|
||||
on_celery_task_prerun("task-1", task)
|
||||
assert "task-1" in _task_start_times
|
||||
|
||||
def test_observes_queue_wait_when_enqueued_at_present(self) -> None:
|
||||
enqueued_at = time.time() - 30 # simulates 30s wait
|
||||
task = _make_task(enqueued_at=enqueued_at)
|
||||
|
||||
before = TASK_QUEUE_WAIT.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
after = TASK_QUEUE_WAIT.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
assert after >= before + 30
|
||||
|
||||
def test_skips_queue_wait_when_enqueued_at_missing(self) -> None:
|
||||
task = _make_task() # no enqueued_at in headers
|
||||
|
||||
before = TASK_QUEUE_WAIT.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
|
||||
on_celery_task_prerun("task-2", task)
|
||||
|
||||
after = TASK_QUEUE_WAIT.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
assert after == before
|
||||
|
||||
|
||||
class TestCeleryTaskPostrun:
|
||||
def test_increments_completed_success(self) -> None:
|
||||
@@ -151,3 +190,15 @@ class TestCeleryTaskPostrun:
|
||||
task = _make_task()
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
# Should not raise
|
||||
|
||||
|
||||
class TestBeforeTaskPublish:
|
||||
def test_stamps_enqueued_at_into_headers(self) -> None:
|
||||
before = time.time()
|
||||
headers: dict = {}
|
||||
on_before_task_publish(headers=headers)
|
||||
assert "enqueued_at" in headers
|
||||
assert headers["enqueued_at"] >= before
|
||||
|
||||
def test_noop_when_headers_is_none(self) -> None:
|
||||
on_before_task_publish(headers=None) # should not raise
|
||||
|
||||
204
backend/tests/unit/server/metrics/test_deletion_metrics.py
Normal file
204
backend/tests/unit/server/metrics/test_deletion_metrics.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Tests for deletion-specific Prometheus metrics."""
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.deletion_metrics import DELETION_BLOCKED
|
||||
from onyx.server.metrics.deletion_metrics import DELETION_COMPLETED
|
||||
from onyx.server.metrics.deletion_metrics import DELETION_FENCE_RESET
|
||||
from onyx.server.metrics.deletion_metrics import DELETION_STARTED
|
||||
from onyx.server.metrics.deletion_metrics import DELETION_TASKSET_DURATION
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_started
|
||||
from onyx.server.metrics.deletion_metrics import observe_deletion_taskset_duration
|
||||
|
||||
|
||||
class TestIncDeletionStarted:
|
||||
def test_increments_counter(self) -> None:
|
||||
before = DELETION_STARTED.labels(tenant_id="t1")._value.get()
|
||||
|
||||
inc_deletion_started("t1")
|
||||
|
||||
after = DELETION_STARTED.labels(tenant_id="t1")._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_labels_by_tenant(self) -> None:
|
||||
before_t1 = DELETION_STARTED.labels(tenant_id="t1")._value.get()
|
||||
before_t2 = DELETION_STARTED.labels(tenant_id="t2")._value.get()
|
||||
|
||||
inc_deletion_started("t1")
|
||||
|
||||
assert DELETION_STARTED.labels(tenant_id="t1")._value.get() == before_t1 + 1
|
||||
assert DELETION_STARTED.labels(tenant_id="t2")._value.get() == before_t2
|
||||
|
||||
def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
DELETION_STARTED,
|
||||
"labels",
|
||||
lambda **_: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
inc_deletion_started("t1")
|
||||
|
||||
|
||||
class TestIncDeletionCompleted:
|
||||
def test_increments_counter(self) -> None:
|
||||
before = DELETION_COMPLETED.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._value.get()
|
||||
|
||||
inc_deletion_completed("t1", "success")
|
||||
|
||||
after = DELETION_COMPLETED.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_labels_by_outcome(self) -> None:
|
||||
before_success = DELETION_COMPLETED.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._value.get()
|
||||
before_failure = DELETION_COMPLETED.labels(
|
||||
tenant_id="t1", outcome="failure"
|
||||
)._value.get()
|
||||
|
||||
inc_deletion_completed("t1", "success")
|
||||
|
||||
assert (
|
||||
DELETION_COMPLETED.labels(tenant_id="t1", outcome="success")._value.get()
|
||||
== before_success + 1
|
||||
)
|
||||
assert (
|
||||
DELETION_COMPLETED.labels(tenant_id="t1", outcome="failure")._value.get()
|
||||
== before_failure
|
||||
)
|
||||
|
||||
def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
DELETION_COMPLETED,
|
||||
"labels",
|
||||
lambda **_: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
inc_deletion_completed("t1", "success")
|
||||
|
||||
|
||||
class TestObserveDeletionTasksetDuration:
|
||||
def test_observes_duration(self) -> None:
|
||||
before = DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._sum.get()
|
||||
|
||||
observe_deletion_taskset_duration("t1", "success", 120.0)
|
||||
|
||||
after = DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._sum.get()
|
||||
assert after == pytest.approx(before + 120.0)
|
||||
|
||||
def test_labels_by_tenant(self) -> None:
|
||||
before_t1 = DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._sum.get()
|
||||
before_t2 = DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t2", outcome="success"
|
||||
)._sum.get()
|
||||
|
||||
observe_deletion_taskset_duration("t1", "success", 60.0)
|
||||
|
||||
assert DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._sum.get() == pytest.approx(before_t1 + 60.0)
|
||||
assert DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t2", outcome="success"
|
||||
)._sum.get() == pytest.approx(before_t2)
|
||||
|
||||
def test_labels_by_outcome(self) -> None:
|
||||
before_success = DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._sum.get()
|
||||
before_failure = DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="failure"
|
||||
)._sum.get()
|
||||
|
||||
observe_deletion_taskset_duration("t1", "failure", 45.0)
|
||||
|
||||
assert DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="success"
|
||||
)._sum.get() == pytest.approx(before_success)
|
||||
assert DELETION_TASKSET_DURATION.labels(
|
||||
tenant_id="t1", outcome="failure"
|
||||
)._sum.get() == pytest.approx(before_failure + 45.0)
|
||||
|
||||
def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
DELETION_TASKSET_DURATION,
|
||||
"labels",
|
||||
lambda **_: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
observe_deletion_taskset_duration("t1", "success", 10.0)
|
||||
|
||||
|
||||
class TestIncDeletionBlocked:
|
||||
def test_increments_counter(self) -> None:
|
||||
before = DELETION_BLOCKED.labels(
|
||||
tenant_id="t1", blocker="indexing"
|
||||
)._value.get()
|
||||
|
||||
inc_deletion_blocked("t1", "indexing")
|
||||
|
||||
after = DELETION_BLOCKED.labels(tenant_id="t1", blocker="indexing")._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_labels_by_blocker(self) -> None:
|
||||
before_idx = DELETION_BLOCKED.labels(
|
||||
tenant_id="t1", blocker="indexing"
|
||||
)._value.get()
|
||||
before_prune = DELETION_BLOCKED.labels(
|
||||
tenant_id="t1", blocker="pruning"
|
||||
)._value.get()
|
||||
|
||||
inc_deletion_blocked("t1", "indexing")
|
||||
|
||||
assert (
|
||||
DELETION_BLOCKED.labels(tenant_id="t1", blocker="indexing")._value.get()
|
||||
== before_idx + 1
|
||||
)
|
||||
assert (
|
||||
DELETION_BLOCKED.labels(tenant_id="t1", blocker="pruning")._value.get()
|
||||
== before_prune
|
||||
)
|
||||
|
||||
def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
DELETION_BLOCKED,
|
||||
"labels",
|
||||
lambda **_: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
inc_deletion_blocked("t1", "indexing")
|
||||
|
||||
|
||||
class TestIncDeletionFenceReset:
|
||||
def test_increments_counter(self) -> None:
|
||||
before = DELETION_FENCE_RESET.labels(tenant_id="t1")._value.get()
|
||||
|
||||
inc_deletion_fence_reset("t1")
|
||||
|
||||
after = DELETION_FENCE_RESET.labels(tenant_id="t1")._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_labels_by_tenant(self) -> None:
|
||||
before_t1 = DELETION_FENCE_RESET.labels(tenant_id="t1")._value.get()
|
||||
before_t2 = DELETION_FENCE_RESET.labels(tenant_id="t2")._value.get()
|
||||
|
||||
inc_deletion_fence_reset("t1")
|
||||
|
||||
assert DELETION_FENCE_RESET.labels(tenant_id="t1")._value.get() == before_t1 + 1
|
||||
assert DELETION_FENCE_RESET.labels(tenant_id="t2")._value.get() == before_t2
|
||||
|
||||
def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
DELETION_FENCE_RESET,
|
||||
"labels",
|
||||
lambda **_: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
inc_deletion_fence_reset("t1")
|
||||
@@ -0,0 +1,26 @@
|
||||
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
|
||||
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
|
||||
{{- if gt (int .Values.celery_worker_light.replicaCount) 0 }}
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-light-metrics
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_light.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_light.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
metrics: "true"
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: 9095
|
||||
targetPort: metrics
|
||||
protocol: TCP
|
||||
name: metrics
|
||||
selector:
|
||||
{{- include "onyx.selectorLabels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_light.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_light.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -70,6 +70,10 @@ spec:
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,opensearch_migration",
|
||||
]
|
||||
ports:
|
||||
- name: metrics
|
||||
containerPort: 9095
|
||||
protocol: TCP
|
||||
resources:
|
||||
{{- toYaml .Values.celery_worker_light.resources | nindent 12 }}
|
||||
envFrom:
|
||||
|
||||
@@ -99,4 +99,29 @@ spec:
|
||||
interval: 30s
|
||||
scrapeTimeout: 10s
|
||||
{{- end }}
|
||||
{{- if gt (int .Values.celery_worker_light.replicaCount) 0 }}
|
||||
---
|
||||
apiVersion: monitoring.coreos.com/v1
|
||||
kind: ServiceMonitor
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-light
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.monitoring.serviceMonitors.labels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
namespaceSelector:
|
||||
matchNames:
|
||||
- {{ .Release.Namespace }}
|
||||
selector:
|
||||
matchLabels:
|
||||
app: {{ .Values.celery_worker_light.deploymentLabels.app }}
|
||||
metrics: "true"
|
||||
endpoints:
|
||||
- port: metrics
|
||||
path: /metrics
|
||||
interval: 30s
|
||||
scrapeTimeout: 10s
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user