mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-29 19:42:41 +00:00
Compare commits
77 Commits
jamison/wo
...
cli/v0.1.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d725df62e7 | ||
|
|
d1460972b6 | ||
|
|
706872f0b7 | ||
|
|
ed3856be2b | ||
|
|
6326c7f0b9 | ||
|
|
40420fc4e6 | ||
|
|
1a2b6a66cc | ||
|
|
d1b1529ccf | ||
|
|
fedd9c76e5 | ||
|
|
0b34b40b79 | ||
|
|
fe82ddb1b9 | ||
|
|
32d3d70525 | ||
|
|
40b9e10890 | ||
|
|
e21b204b8a | ||
|
|
2f672b3a4f | ||
|
|
cf19d0df4f | ||
|
|
86a6a4c134 | ||
|
|
146b5449d2 | ||
|
|
b66991b5c5 | ||
|
|
9cb76dc027 | ||
|
|
f66891d19e | ||
|
|
c07c952ad5 | ||
|
|
be7f40a28a | ||
|
|
26f941b5da | ||
|
|
b9e84c42a8 | ||
|
|
0a1df52c2f | ||
|
|
306b0d452f | ||
|
|
5fdb34ba8e | ||
|
|
2d066631e3 | ||
|
|
5c84f6c61b | ||
|
|
899179d4b6 | ||
|
|
80d6bafc74 | ||
|
|
2cc325cb0e | ||
|
|
849385b756 | ||
|
|
417b9c12e4 | ||
|
|
30b37d0a77 | ||
|
|
b48be0cd3a | ||
|
|
127fd90424 | ||
|
|
f9c9e55f32 | ||
|
|
5afcf1acea | ||
|
|
eb1244a9d7 | ||
|
|
2433a9a4c5 | ||
|
|
60bc8fcac6 | ||
|
|
1ddc958a51 | ||
|
|
de37acbe07 | ||
|
|
08cd2f2c3e | ||
|
|
fc29f20914 | ||
|
|
c43cb80a7a | ||
|
|
56f0be2ec8 | ||
|
|
42f9ddf247 | ||
|
|
a10a85c73c | ||
|
|
31d8ae9718 | ||
|
|
00a0a99842 | ||
|
|
90040f8973 | ||
|
|
4f5d081f26 | ||
|
|
c51a6dbd0d | ||
|
|
8b90ecc189 | ||
|
|
865c893a09 | ||
|
|
ef5628bfa7 | ||
|
|
6ffee0021e | ||
|
|
28dc84b831 | ||
|
|
230f035500 | ||
|
|
55b24d72b4 | ||
|
|
3321a84c7d | ||
|
|
54bf32a5f8 | ||
|
|
4bb6b76be6 | ||
|
|
db94562474 | ||
|
|
582d4642c1 | ||
|
|
3caaecdb0e | ||
|
|
039b69806b | ||
|
|
63971d4958 | ||
|
|
ffd897f380 | ||
|
|
4745069232 | ||
|
|
386782f188 | ||
|
|
ff009c4129 | ||
|
|
b20a5ebf69 | ||
|
|
8645adb807 |
4
.github/workflows/deployment.yml
vendored
4
.github/workflows/deployment.yml
vendored
@@ -615,6 +615,7 @@ jobs:
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
|
||||
|
||||
@@ -1263,8 +1264,6 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=craft-latest
|
||||
# TODO: Consider aligning craft-latest tags with regular backend builds (e.g., latest, edge, beta)
|
||||
# to keep tagging strategy consistent across all backend images
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
@@ -1488,6 +1487,7 @@ jobs:
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
|
||||
3
.github/workflows/helm-chart-releases.yml
vendored
3
.github/workflows/helm-chart-releases.yml
vendored
@@ -47,7 +47,8 @@ jobs:
|
||||
done
|
||||
|
||||
- name: Publish Helm charts to gh-pages
|
||||
uses: stefanprodan/helm-gh-pages@0ad2bb377311d61ac04ad9eb6f252fb68e207260 # ratchet:stefanprodan/helm-gh-pages@v1.7.0
|
||||
# NOTE: HEAD of https://github.com/stefanprodan/helm-gh-pages/pull/43
|
||||
uses: stefanprodan/helm-gh-pages@ad32ad3b8720abfeaac83532fd1e9bdfca5bbe27 # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
charts_dir: deployment/helm/charts
|
||||
|
||||
17
.github/workflows/release-cli.yml
vendored
17
.github/workflows/release-cli.yml
vendored
@@ -13,15 +13,6 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
os-arch:
|
||||
- { goos: "linux", goarch: "amd64" }
|
||||
- { goos: "linux", goarch: "arm64" }
|
||||
- { goos: "windows", goarch: "amd64" }
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
@@ -31,9 +22,11 @@ jobs:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
- run: |
|
||||
GOOS="${{ matrix.os-arch.goos }}" \
|
||||
GOARCH="${{ matrix.os-arch.goarch }}" \
|
||||
uv build --wheel
|
||||
for goos in linux windows darwin; do
|
||||
for goarch in amd64 arm64; do
|
||||
GOOS="$goos" GOARCH="$goarch" uv build --wheel
|
||||
done
|
||||
done
|
||||
working-directory: cli
|
||||
- run: uv publish
|
||||
working-directory: cli
|
||||
|
||||
64
.greptile/config.json
Normal file
64
.greptile/config.json
Normal file
@@ -0,0 +1,64 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
}
|
||||
57
.greptile/files.json
Normal file
57
.greptile/files.json
Normal file
@@ -0,0 +1,57 @@
|
||||
[
|
||||
{
|
||||
"scope": [],
|
||||
"path": "contributing_guides/best_practices.md",
|
||||
"description": "Best practices for contributing to the codebase"
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"path": "web/AGENTS.md",
|
||||
"description": "Frontend coding standards for the web directory"
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"path": "web/tests/README.md",
|
||||
"description": "Frontend testing guide and conventions"
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"path": "web/CLAUDE.md",
|
||||
"description": "Single source of truth for frontend coding standards"
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"path": "web/lib/opal/README.md",
|
||||
"description": "Opal component library usage guide"
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**"],
|
||||
"path": "backend/tests/README.md",
|
||||
"description": "Backend testing guide covering all 4 test types, fixtures, and conventions"
|
||||
},
|
||||
{
|
||||
"scope": ["backend/onyx/connectors/**"],
|
||||
"path": "backend/onyx/connectors/README.md",
|
||||
"description": "Connector development guide covering design, interfaces, and required changes"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "CLAUDE.md",
|
||||
"description": "Project instructions and coding standards"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "backend/alembic/README.md",
|
||||
"description": "Migration guidance, including multi-tenant migration behavior"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "deployment/helm/charts/onyx/values-lite.yaml",
|
||||
"description": "Lite deployment Helm values and service assumptions"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "deployment/docker_compose/docker-compose.onyx-lite.yml",
|
||||
"description": "Lite deployment Docker Compose overlay and disabled service behavior"
|
||||
}
|
||||
]
|
||||
39
.greptile/rules.md
Normal file
39
.greptile/rules.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# Greptile Review Rules
|
||||
|
||||
## Type Annotations
|
||||
|
||||
Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code.
|
||||
|
||||
## Best Practices
|
||||
|
||||
Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly.
|
||||
|
||||
## TODOs
|
||||
|
||||
Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of `TODO(name): ...` or `TODO(1234): ...`
|
||||
|
||||
## Debugging Code
|
||||
|
||||
Remove temporary debugging code before merging to production, especially tenant-specific debugging logs.
|
||||
|
||||
## Hardcoded Booleans
|
||||
|
||||
When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant.
|
||||
|
||||
## Multi-tenant vs Single-tenant
|
||||
|
||||
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
|
||||
|
||||
## Nginx Routing — New Backend Routes
|
||||
|
||||
Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
|
||||
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
|
||||
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
|
||||
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
|
||||
- `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt)
|
||||
|
||||
Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`.
|
||||
|
||||
## Full vs Lite Deployments
|
||||
|
||||
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.
|
||||
@@ -122,7 +122,7 @@ repos:
|
||||
rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
|
||||
hooks:
|
||||
- id: golangci-lint
|
||||
language_version: "1.26.0"
|
||||
language_version: "1.26.1"
|
||||
entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
|
||||
12
.vscode/launch.json
vendored
12
.vscode/launch.json
vendored
@@ -117,7 +117,8 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
"consoleTitle": "API Server Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
@@ -268,7 +269,8 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
"consoleTitle": "Celery heavy Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery kg_processing",
|
||||
@@ -355,7 +357,8 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user_file_processing Console"
|
||||
"consoleTitle": "Celery user_file_processing Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery docfetching",
|
||||
@@ -413,7 +416,8 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery docprocessing Console"
|
||||
"consoleTitle": "Celery docprocessing Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
|
||||
@@ -35,7 +35,7 @@ Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep
|
||||
> [!TIP]
|
||||
> Run Onyx with one command (or see deployment section below):
|
||||
> ```
|
||||
> curl -fsSL https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.sh > install.sh && chmod +x install.sh && ./install.sh
|
||||
> curl -fsSL https://onyx.app/install_onyx.sh | bash
|
||||
> ```
|
||||
|
||||
****
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""remove voice_provider deleted column
|
||||
|
||||
Revision ID: 1d78c0ca7853
|
||||
Revises: a3f8b2c1d4e5
|
||||
Create Date: 2026-03-26 11:30:53.883127
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1d78c0ca7853"
|
||||
down_revision = "a3f8b2c1d4e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Hard-delete any soft-deleted rows before dropping the column
|
||||
op.execute("DELETE FROM voice_provider WHERE deleted = true")
|
||||
op.drop_column("voice_provider", "deleted")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"voice_provider",
|
||||
sa.Column(
|
||||
"deleted",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,109 @@
|
||||
"""group_permissions_phase1
|
||||
|
||||
Revision ID: 25a5501dc766
|
||||
Revises: b728689f45b1
|
||||
Create Date: 2026-03-23 11:41:25.557442
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "25a5501dc766"
|
||||
down_revision = "b728689f45b1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1. Add account_type column to user table (nullable for now).
|
||||
# TODO(subash): backfill account_type for existing rows and add NOT NULL.
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"account_type",
|
||||
sa.Enum(AccountType, native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# 2. Add is_default column to user_group table
|
||||
op.add_column(
|
||||
"user_group",
|
||||
sa.Column(
|
||||
"is_default",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# 3. Create permission_grant table
|
||||
op.create_table(
|
||||
"permission_grant",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("group_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"permission",
|
||||
sa.Enum(Permission, native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"grant_source",
|
||||
sa.Enum(GrantSource, native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"granted_by",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"granted_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_deleted",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["group_id"],
|
||||
["user_group.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["granted_by"],
|
||||
["user.id"],
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"group_id", "permission", name="uq_permission_grant_group_permission"
|
||||
),
|
||||
)
|
||||
|
||||
# 4. Index on user__user_group(user_id) — existing composite PK
|
||||
# has user_group_id as leading column; user-filtered queries need this
|
||||
op.create_index(
|
||||
"ix_user__user_group_user_id",
|
||||
"user__user_group",
|
||||
["user_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_user__user_group_user_id", table_name="user__user_group")
|
||||
op.drop_table("permission_grant")
|
||||
op.drop_column("user_group", "is_default")
|
||||
op.drop_column("user", "account_type")
|
||||
@@ -0,0 +1,36 @@
|
||||
"""add preferred_response_id and model_display_name to chat_message
|
||||
|
||||
Revision ID: a3f8b2c1d4e5
|
||||
Create Date: 2026-03-22
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3f8b2c1d4e5"
|
||||
down_revision = "25a5501dc766"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"preferred_response_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("chat_message.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("model_display_name", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_message", "model_display_name")
|
||||
op.drop_column("chat_message", "preferred_response_id")
|
||||
@@ -28,6 +28,7 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -187,7 +188,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
@@ -227,6 +227,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_permission_sync_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
@@ -473,6 +474,8 @@ def connector_permission_sync_generator_task(
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
eager_load_connector=True,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -29,6 +29,7 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
@@ -162,7 +163,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
@@ -221,6 +221,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_external_group_sync_fences(
|
||||
tenant_id, self.app, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -115,8 +115,14 @@ def fetch_user_group_token_rate_limits_for_user(
|
||||
ordered: bool = True,
|
||||
get_editable: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
stmt = select(TokenRateLimit)
|
||||
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
|
||||
stmt = (
|
||||
select(TokenRateLimit)
|
||||
.join(
|
||||
TokenRateLimit__UserGroup,
|
||||
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
|
||||
)
|
||||
.where(TokenRateLimit__UserGroup.user_group_id == group_id)
|
||||
)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
if enabled_only:
|
||||
|
||||
@@ -250,20 +250,24 @@ def _get_sharepoint_list_item_id(drive_item: DriveItem) -> str | None:
|
||||
raise e
|
||||
|
||||
|
||||
def _is_public_item(drive_item: DriveItem) -> bool:
|
||||
is_public = False
|
||||
def _is_public_item(
|
||||
drive_item: DriveItem,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> bool:
|
||||
if not treat_sharing_link_as_public:
|
||||
return False
|
||||
|
||||
try:
|
||||
permissions = sleep_and_retry(
|
||||
drive_item.permissions.get_all(page_loaded=lambda _: None), "is_public_item"
|
||||
)
|
||||
for permission in permissions:
|
||||
if permission.link and (
|
||||
permission.link.scope == "anonymous"
|
||||
or permission.link.scope == "organization"
|
||||
if permission.link and permission.link.scope in (
|
||||
"anonymous",
|
||||
"organization",
|
||||
):
|
||||
is_public = True
|
||||
break
|
||||
return is_public
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if item {drive_item.id} is public: {e}")
|
||||
return False
|
||||
@@ -504,6 +508,7 @@ def get_external_access_from_sharepoint(
|
||||
drive_item: DriveItem | None,
|
||||
site_page: dict[str, Any] | None,
|
||||
add_prefix: bool = False,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
Get external access information from SharePoint.
|
||||
@@ -563,8 +568,7 @@ def get_external_access_from_sharepoint(
|
||||
)
|
||||
|
||||
if drive_item and drive_name:
|
||||
# Here we check if the item have have any public links, if so we return early
|
||||
is_public = _is_public_item(drive_item)
|
||||
is_public = _is_public_item(drive_item, treat_sharing_link_as_public)
|
||||
if is_public:
|
||||
logger.info(f"Item {drive_item.id} is public")
|
||||
return ExternalAccess(
|
||||
|
||||
@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call
|
||||
@@ -105,9 +106,11 @@ def _get_slack_document_access(
|
||||
slack_connector: SlackConnector,
|
||||
channel_permissions: dict[str, ExternalAccess], # noqa: ARG001
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
indexing_start: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
callback=callback
|
||||
callback=callback,
|
||||
start=indexing_start,
|
||||
)
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
@@ -180,9 +183,15 @@ def slack_doc_sync(
|
||||
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.set_credentials_provider(provider)
|
||||
indexing_start_ts: SecondsSinceUnixEpoch | None = (
|
||||
cc_pair.connector.indexing_start.timestamp()
|
||||
if cc_pair.connector.indexing_start is not None
|
||||
else None
|
||||
)
|
||||
|
||||
yield from _get_slack_document_access(
|
||||
slack_connector,
|
||||
slack_connector=slack_connector,
|
||||
channel_permissions=channel_permissions,
|
||||
callback=callback,
|
||||
indexing_start=indexing_start_ts,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from onyx.access.models import ElementExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.models import NodeExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -40,10 +41,19 @@ def generic_doc_sync(
|
||||
|
||||
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
|
||||
|
||||
indexing_start: SecondsSinceUnixEpoch | None = (
|
||||
cc_pair.connector.indexing_start.timestamp()
|
||||
if cc_pair.connector.indexing_start is not None
|
||||
else None
|
||||
)
|
||||
|
||||
newly_fetched_doc_ids: set[str] = set()
|
||||
|
||||
logger.info(f"Fetching all slim documents from {doc_source}")
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=indexing_start,
|
||||
callback=callback,
|
||||
):
|
||||
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
|
||||
|
||||
if callback:
|
||||
|
||||
@@ -44,19 +44,21 @@ def _run_single_search(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
num_hits: int | None = None,
|
||||
hybrid_alpha: float | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Execute a single search query and return chunks."""
|
||||
chunk_search_request = ChunkSearchRequest(
|
||||
query=query,
|
||||
user_selected_filters=filters,
|
||||
limit=num_hits,
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
)
|
||||
|
||||
return search_pipeline(
|
||||
chunk_search_request=chunk_search_request,
|
||||
document_index=document_index,
|
||||
user=user,
|
||||
persona=None, # No persona for direct search
|
||||
persona_search_info=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -74,7 +76,7 @@ def stream_search_query(
|
||||
Core search function that yields streaming packets.
|
||||
Used by both streaming and non-streaming endpoints.
|
||||
"""
|
||||
# Get document index
|
||||
# Get document index.
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None, db_session)
|
||||
@@ -119,6 +121,7 @@ def stream_search_query(
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
num_hits=request.num_hits,
|
||||
hybrid_alpha=request.hybrid_alpha,
|
||||
)
|
||||
else:
|
||||
# Multiple queries - run in parallel and merge with RRF
|
||||
@@ -133,6 +136,7 @@ def stream_search_query(
|
||||
user,
|
||||
db_session,
|
||||
request.num_hits,
|
||||
request.hybrid_alpha,
|
||||
),
|
||||
)
|
||||
for query in all_executed_queries
|
||||
|
||||
@@ -27,15 +27,17 @@ class SearchFlowClassificationResponse(BaseModel):
|
||||
is_search_flow: bool
|
||||
|
||||
|
||||
# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
# NOTE: This model is used for the core flow of the Onyx application, any
|
||||
# changes to it should be reviewed and approved by an experienced team member.
|
||||
# It is very important to 1. avoid bloat and 2. that this remains backwards
|
||||
# compatible across versions.
|
||||
class SendSearchQueryRequest(BaseModel):
|
||||
search_query: str
|
||||
filters: BaseFilters | None = None
|
||||
num_docs_fed_to_llm_selection: int | None = None
|
||||
run_query_expansion: bool = False
|
||||
num_hits: int = 30
|
||||
|
||||
hybrid_alpha: float | None = None
|
||||
include_content: bool = False
|
||||
stream: bool = False
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from ee.onyx.server.query_and_chat.models import SearchQueryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
@@ -67,8 +68,10 @@ def search_flow_classification(
|
||||
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
|
||||
|
||||
|
||||
# NOTE: This endpoint is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
# NOTE: This endpoint is used for the core flow of the Onyx application, any
|
||||
# changes to it should be reviewed and approved by an experienced team member.
|
||||
# It is very important to 1. avoid bloat and 2. that this remains backwards
|
||||
# compatible across versions.
|
||||
@router.post(
|
||||
"/send-search-message",
|
||||
response_model=None,
|
||||
@@ -80,13 +83,19 @@ def handle_send_search_message(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | SearchFullResponse:
|
||||
"""
|
||||
Execute a search query with optional streaming.
|
||||
Executes a search query with optional streaming.
|
||||
|
||||
When stream=True: Returns StreamingResponse with SSE
|
||||
When stream=False: Returns SearchFullResponse
|
||||
If hybrid_alpha is unset and ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
|
||||
is True, executes pure keyword search.
|
||||
|
||||
Returns:
|
||||
StreamingResponse with SSE if stream=True, otherwise SearchFullResponse.
|
||||
"""
|
||||
logger.debug(f"Received search query: {request.search_query}")
|
||||
|
||||
if request.hybrid_alpha is None and ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH:
|
||||
request.hybrid_alpha = 0.0
|
||||
|
||||
# Non-streaming path
|
||||
if not request.stream:
|
||||
try:
|
||||
|
||||
@@ -13,6 +13,14 @@ from celery.signals import worker_shutdown
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -34,6 +42,8 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
on_indexing_task_prerun(task_id, task, kwargs)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -48,6 +58,36 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
on_indexing_task_postrun(task_id, task, kwargs, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_retry signal doesn't pass task_id in kwargs; get it from
|
||||
# the sender (the task instance) via sender.request.id.
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_rejected sends the Consumer as sender, not the task instance.
|
||||
# The task name must be extracted from the Celery message headers.
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -76,6 +116,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("docfetching")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,14 @@ from celery.signals import worker_shutdown
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -35,6 +43,8 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
on_indexing_task_prerun(task_id, task, kwargs)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -49,6 +59,36 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
on_indexing_task_postrun(task_id, task, kwargs, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_retry signal doesn't pass task_id in kwargs; get it from
|
||||
# the sender (the task instance) via sender.request.id.
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_rejected sends the Consumer as sender, not the task instance.
|
||||
# The task name must be extracted from the Celery message headers.
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -82,6 +122,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("docprocessing")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@@ -90,6 +131,12 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
# Note: worker_process_init only fires in prefork pool mode. Docprocessing uses
|
||||
# worker_pool="threads" (see configs/docprocessing.py), so this handler is
|
||||
# effectively a no-op in normal operation. It remains as a safety net in case
|
||||
# the pool type is ever changed to prefork. Prometheus metrics are safe in
|
||||
# thread-pool mode since all threads share the same process memory and can
|
||||
# update the same Counter/Gauge/Histogram objects directly.
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
@@ -54,8 +54,14 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
# Set by on_worker_init so on_worker_ready knows whether to start the server.
|
||||
_prometheus_collectors_ok: bool = False
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
global _prometheus_collectors_ok
|
||||
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
@@ -65,6 +71,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
_prometheus_collectors_ok = _setup_prometheus_collectors(sender)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
@@ -72,8 +80,37 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
def _setup_prometheus_collectors(sender: Any) -> bool:
|
||||
"""Register Prometheus collectors that need Redis/DB access.
|
||||
|
||||
Passes the Celery app so the queue depth collector can obtain a fresh
|
||||
broker Redis client on each scrape (rather than holding a stale reference).
|
||||
|
||||
Returns True if registration succeeded, False otherwise.
|
||||
"""
|
||||
try:
|
||||
from onyx.server.metrics.indexing_pipeline_setup import (
|
||||
setup_indexing_pipeline_metrics,
|
||||
)
|
||||
|
||||
setup_indexing_pipeline_metrics(sender.app)
|
||||
logger.info("Prometheus indexing pipeline collectors registered")
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to register Prometheus indexing pipeline collectors")
|
||||
return False
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
if _prometheus_collectors_ok:
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
|
||||
start_metrics_server("monitoring")
|
||||
else:
|
||||
logger.warning(
|
||||
"Skipping Prometheus metrics server — collector registration failed"
|
||||
)
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -7,7 +8,59 @@ from celery import Celery
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
|
||||
|
||||
_broker_client: Redis | None = None
|
||||
_broker_url: str | None = None
|
||||
_broker_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def celery_get_broker_client(app: Celery) -> Redis:
|
||||
"""Return a shared Redis client connected to the Celery broker DB.
|
||||
|
||||
Uses a module-level singleton so all tasks on a worker share one
|
||||
connection instead of creating a new one per call. The client
|
||||
connects directly to the broker Redis DB (parsed from the broker URL).
|
||||
|
||||
Thread-safe via lock — safe for use in Celery thread-pool workers.
|
||||
|
||||
Usage:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
length = celery_get_queue_length(queue, r_celery)
|
||||
"""
|
||||
global _broker_client, _broker_url
|
||||
with _broker_client_lock:
|
||||
url = app.conf.broker_url
|
||||
if _broker_client is not None and _broker_url == url:
|
||||
try:
|
||||
_broker_client.ping()
|
||||
return _broker_client
|
||||
except Exception:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
elif _broker_client is not None:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
|
||||
_broker_url = url
|
||||
_broker_client = Redis.from_url(
|
||||
url,
|
||||
decode_responses=False,
|
||||
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
|
||||
socket_keepalive=True,
|
||||
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
return _broker_client
|
||||
|
||||
|
||||
def celery_get_unacked_length(r: Redis) -> int:
|
||||
|
||||
@@ -14,6 +14,7 @@ from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
@@ -132,7 +133,6 @@ def revoke_tasks_blocking_deletion(
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -149,6 +149,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
|
||||
# clear fences that don't have associated celery tasks in progress
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_connector_deletion_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -449,7 +450,7 @@ def check_indexing_completion(
|
||||
):
|
||||
# Check if the task exists in the celery queue
|
||||
# This handles the case where Redis dies after task creation but before task execution
|
||||
redis_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
redis_celery = celery_get_broker_client(task.app)
|
||||
task_exists = celery_find_task(
|
||||
attempt.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
@@ -19,6 +18,7 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -698,31 +698,27 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Get Redis client for Celery broker
|
||||
redis_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_std = get_redis_client()
|
||||
|
||||
# Define metric collection functions and their dependencies
|
||||
metric_functions: list[Callable[[], list[Metric]]] = [
|
||||
lambda: _collect_queue_metrics(redis_celery),
|
||||
lambda: _collect_connector_metrics(db_session, redis_std),
|
||||
lambda: _collect_sync_metrics(db_session, redis_std),
|
||||
]
|
||||
# Collect queue metrics with broker connection
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_metrics = _collect_queue_metrics(r_celery)
|
||||
|
||||
# Collect and log each metric
|
||||
# Collect remaining metrics (no broker connection needed)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
# double check to make sure we aren't double-emitting metrics
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
all_metrics: list[Metric] = queue_metrics
|
||||
all_metrics.extend(_collect_connector_metrics(db_session, redis_std))
|
||||
all_metrics.extend(_collect_sync_metrics(db_session, redis_std))
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
for metric in all_metrics:
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
|
||||
task_logger.info("Successfully collected background metrics")
|
||||
except SoftTimeLimitExceeded:
|
||||
@@ -890,7 +886,7 @@ def monitor_celery_queues_helper(
|
||||
) -> None:
|
||||
"""A task to monitor all celery queue lengths."""
|
||||
|
||||
r_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(task.app)
|
||||
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
|
||||
n_docfetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
@@ -1080,7 +1076,7 @@ def cloud_monitor_celery_pidbox(
|
||||
num_deleted = 0
|
||||
|
||||
MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
for key in r_celery.scan_iter("*.reply.celery.pidbox"):
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -203,7 +204,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
@@ -261,6 +261,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating pruning fences")
|
||||
|
||||
@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import build_access_for_user_files
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
@@ -105,7 +106,7 @@ def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
|
||||
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
|
||||
redis_celery = celery_get_broker_client(celery_app)
|
||||
return celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
|
||||
)
|
||||
@@ -238,7 +239,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
@@ -591,7 +592,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
# NOTE: must use the broker's Redis client (not redis_client) because
|
||||
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
|
||||
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
|
||||
@@ -8,6 +8,7 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
@@ -35,7 +36,13 @@ class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
|
||||
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
|
||||
AnswerStreamPart = (
|
||||
Packet
|
||||
| MessageResponseIDInfo
|
||||
| MultiModelMessageResponseIDInfo
|
||||
| StreamingError
|
||||
| CreateChatSessionID
|
||||
)
|
||||
|
||||
AnswerStream = Iterator[AnswerStreamPart]
|
||||
|
||||
|
||||
@@ -44,6 +44,31 @@ SEND_USER_METADATA_TO_LLM_PROVIDER = (
|
||||
# User Facing Features Configs
|
||||
#####
|
||||
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
|
||||
|
||||
# Hard ceiling for the admin-configurable file upload size (in MB).
|
||||
# Self-hosted customers can raise or lower this via the environment variable.
|
||||
_raw_max_upload_size_mb = int(os.environ.get("MAX_ALLOWED_UPLOAD_SIZE_MB", "250"))
|
||||
if _raw_max_upload_size_mb < 0:
|
||||
logger.warning(
|
||||
"MAX_ALLOWED_UPLOAD_SIZE_MB=%d is negative; falling back to 250",
|
||||
_raw_max_upload_size_mb,
|
||||
)
|
||||
_raw_max_upload_size_mb = 250
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB = _raw_max_upload_size_mb
|
||||
|
||||
# Default fallback for the per-user file upload size limit (in MB) when no
|
||||
# admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at
|
||||
# runtime so this never silently exceeds the hard ceiling.
|
||||
_raw_default_upload_size_mb = int(
|
||||
os.environ.get("DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", "100")
|
||||
)
|
||||
if _raw_default_upload_size_mb < 0:
|
||||
logger.warning(
|
||||
"DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=%d is negative; falling back to 100",
|
||||
_raw_default_upload_size_mb,
|
||||
)
|
||||
_raw_default_upload_size_mb = 100
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _raw_default_upload_size_mb
|
||||
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
|
||||
) # 1 day
|
||||
@@ -61,17 +86,6 @@ CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# Maximum upload size for a single user file (chat/projects) in MB.
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
|
||||
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
@@ -332,6 +346,10 @@ OPENSEARCH_INDEX_NUM_REPLICAS: int | None = (
|
||||
if os.environ.get("OPENSEARCH_INDEX_NUM_REPLICAS", None) is not None
|
||||
else None
|
||||
)
|
||||
ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH = (
|
||||
os.environ.get("ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH", "").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
|
||||
@@ -24,11 +24,11 @@ CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
|
||||
LLM_SOCKET_READ_TIMEOUT = int(
|
||||
os.environ.get("LLM_SOCKET_READ_TIMEOUT") or "60"
|
||||
) # 60 seconds
|
||||
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
|
||||
# Weighting factor between vector and keyword Search; 1 for completely vector
|
||||
# search, 0 for keyword. Enforces a valid range of [0, 1]. A supplied value from
|
||||
# the env outside of this range will be clipped to the respective end of the
|
||||
# range. Defaults to 0.5.
|
||||
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
|
||||
HYBRID_ALPHA_KEYWORD = max(
|
||||
0, min(1, float(os.environ.get("HYBRID_ALPHA_KEYWORD") or 0.4))
|
||||
)
|
||||
# Weighting factor between Title and Content of documents during search, 1 for completely
|
||||
# Title based. Default heavily favors Content because Title is also included at the top of
|
||||
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
|
||||
|
||||
0
backend/onyx/connectors/canvas/__init__.py
Normal file
0
backend/onyx/connectors/canvas/__init__.py
Normal file
192
backend/onyx/connectors/canvas/client.py
Normal file
192
backend/onyx/connectors/canvas/client.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rl_requests,
|
||||
)
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Requests timeout in seconds.
|
||||
_CANVAS_CALL_TIMEOUT: int = 30
|
||||
_CANVAS_API_VERSION: str = "/api/v1"
|
||||
# Matches the "next" URL in a Canvas Link header, e.g.:
|
||||
# <https://canvas.example.com/api/v1/courses?page=2>; rel="next"
|
||||
# Captures the URL inside the angle brackets.
|
||||
_NEXT_LINK_PATTERN: re.Pattern[str] = re.compile(r'<([^>]+)>;\s*rel="next"')
|
||||
|
||||
|
||||
_STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
|
||||
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
|
||||
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
404: OnyxErrorCode.BAD_GATEWAY,
|
||||
429: OnyxErrorCode.RATE_LIMITED,
|
||||
}
|
||||
|
||||
|
||||
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
|
||||
"""Map an HTTP status code to the appropriate OnyxErrorCode.
|
||||
|
||||
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
|
||||
mapped to specific error codes; all other codes (unrecognised 4xx
|
||||
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
|
||||
"""
|
||||
if status_code in _STATUS_TO_ERROR_CODE:
|
||||
return _STATUS_TO_ERROR_CODE[status_code]
|
||||
return OnyxErrorCode.BAD_GATEWAY
|
||||
|
||||
|
||||
class CanvasApiClient:
|
||||
def __init__(
|
||||
self,
|
||||
bearer_token: str,
|
||||
canvas_base_url: str,
|
||||
) -> None:
|
||||
parsed_base = urlparse(canvas_base_url)
|
||||
if not parsed_base.hostname:
|
||||
raise ValueError("canvas_base_url must include a valid host")
|
||||
if parsed_base.scheme != "https":
|
||||
raise ValueError("canvas_base_url must use https")
|
||||
|
||||
self._bearer_token = bearer_token
|
||||
self.base_url = (
|
||||
canvas_base_url.rstrip("/").removesuffix(_CANVAS_API_VERSION)
|
||||
+ _CANVAS_API_VERSION
|
||||
)
|
||||
# Hostname is already validated above; reuse parsed_base instead
|
||||
# of re-parsing. Used by _parse_next_link to validate pagination URLs.
|
||||
self._expected_host: str = parsed_base.hostname
|
||||
|
||||
def get(
|
||||
self,
|
||||
endpoint: str = "",
|
||||
params: dict[str, Any] | None = None,
|
||||
full_url: str | None = None,
|
||||
) -> tuple[Any, str | None]:
|
||||
"""Make a GET request to the Canvas API.
|
||||
|
||||
Returns a tuple of (json_body, next_url).
|
||||
next_url is parsed from the Link header and is None if there are no more pages.
|
||||
If full_url is provided, it is used directly (for following pagination links).
|
||||
|
||||
Security note: full_url must only be set to values returned by
|
||||
``_parse_next_link``, which validates the host against the configured
|
||||
Canvas base URL. Passing an arbitrary URL would leak the bearer token.
|
||||
"""
|
||||
# full_url is used when following pagination (Canvas returns the
|
||||
# next-page URL in the Link header). For the first request we build
|
||||
# the URL from the endpoint name instead.
|
||||
url = full_url if full_url else self._build_url(endpoint)
|
||||
headers = self._build_headers()
|
||||
|
||||
response = rl_requests.get(
|
||||
url,
|
||||
headers=headers,
|
||||
params=params if not full_url else None,
|
||||
timeout=_CANVAS_CALL_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
response_json = response.json()
|
||||
except ValueError as e:
|
||||
if response.status_code < 300:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail=f"Invalid JSON in Canvas response: {e}",
|
||||
)
|
||||
logger.warning(
|
||||
"Failed to parse JSON from Canvas error response (status=%d): %s",
|
||||
response.status_code,
|
||||
e,
|
||||
)
|
||||
response_json = {}
|
||||
|
||||
if response.status_code >= 400:
|
||||
# Try to extract the most specific error message from the
|
||||
# Canvas response body. Canvas uses three different shapes
|
||||
# depending on the endpoint and error type:
|
||||
default_error: str = response.reason or f"HTTP {response.status_code}"
|
||||
error = default_error
|
||||
if isinstance(response_json, dict):
|
||||
# Shape 1: {"error": {"message": "Not authorized"}}
|
||||
error_field = response_json.get("error")
|
||||
if isinstance(error_field, dict):
|
||||
response_error = error_field.get("message", "")
|
||||
if response_error:
|
||||
error = response_error
|
||||
# Shape 2: {"error": "Invalid access token"}
|
||||
elif isinstance(error_field, str):
|
||||
error = error_field
|
||||
# Shape 3: {"errors": [{"message": "..."}]}
|
||||
# Used for validation errors. Only use as fallback if
|
||||
# we didn't already find a more specific message above.
|
||||
if error == default_error:
|
||||
errors_list = response_json.get("errors")
|
||||
if isinstance(errors_list, list) and errors_list:
|
||||
first_error = errors_list[0]
|
||||
if isinstance(first_error, dict):
|
||||
msg = first_error.get("message", "")
|
||||
if msg:
|
||||
error = msg
|
||||
raise OnyxError(
|
||||
_error_code_for_status(response.status_code),
|
||||
detail=error,
|
||||
status_code_override=response.status_code,
|
||||
)
|
||||
|
||||
next_url = self._parse_next_link(response.headers.get("Link", ""))
|
||||
return response_json, next_url
|
||||
|
||||
def _parse_next_link(self, link_header: str) -> str | None:
|
||||
"""Extract the 'next' URL from a Canvas Link header.
|
||||
|
||||
Only returns URLs whose host matches the configured Canvas base URL
|
||||
to prevent leaking the bearer token to arbitrary hosts.
|
||||
"""
|
||||
expected_host = self._expected_host
|
||||
for match in _NEXT_LINK_PATTERN.finditer(link_header):
|
||||
url = match.group(1)
|
||||
parsed_url = urlparse(url)
|
||||
if parsed_url.hostname != expected_host:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail=(
|
||||
"Canvas pagination returned an unexpected host "
|
||||
f"({parsed_url.hostname}); expected {expected_host}"
|
||||
),
|
||||
)
|
||||
if parsed_url.scheme != "https":
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail=(
|
||||
"Canvas pagination link must use https, "
|
||||
f"got {parsed_url.scheme!r}"
|
||||
),
|
||||
)
|
||||
return url
|
||||
return None
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
"""Return the Authorization header with the bearer token."""
|
||||
return {"Authorization": f"Bearer {self._bearer_token}"}
|
||||
|
||||
def _build_url(self, endpoint: str) -> str:
|
||||
"""Build a full Canvas API URL from an endpoint path.
|
||||
|
||||
Assumes endpoint is non-empty (e.g. ``"courses"``, ``"announcements"``).
|
||||
Only called on a first request, endpoint must be set for first request.
|
||||
Verify endpoint exists in case of future changes where endpoint might be optional.
|
||||
Leading slashes are stripped to avoid double-slash in the result.
|
||||
self.base_url is already normalized with no trailing slash.
|
||||
"""
|
||||
final_url = self.base_url
|
||||
clean_endpoint = endpoint.lstrip("/")
|
||||
if clean_endpoint:
|
||||
final_url += "/" + clean_endpoint
|
||||
return final_url
|
||||
74
backend/onyx/connectors/canvas/connector.py
Normal file
74
backend/onyx/connectors/canvas/connector.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Literal
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
|
||||
|
||||
class CanvasCourse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
course_code: str
|
||||
created_at: str
|
||||
workflow_state: str
|
||||
|
||||
|
||||
class CanvasPage(BaseModel):
|
||||
page_id: int
|
||||
url: str
|
||||
title: str
|
||||
body: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
course_id: int
|
||||
|
||||
|
||||
class CanvasAssignment(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
html_url: str
|
||||
course_id: int
|
||||
created_at: str
|
||||
updated_at: str
|
||||
due_at: str | None = None
|
||||
|
||||
|
||||
class CanvasAnnouncement(BaseModel):
|
||||
id: int
|
||||
title: str
|
||||
message: str | None = None
|
||||
html_url: str
|
||||
posted_at: str | None = None
|
||||
course_id: int
|
||||
|
||||
|
||||
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
|
||||
|
||||
|
||||
class CanvasConnectorCheckpoint(ConnectorCheckpoint):
|
||||
"""Checkpoint state for resumable Canvas indexing.
|
||||
|
||||
Fields:
|
||||
course_ids: Materialized list of course IDs to process.
|
||||
current_course_index: Index into course_ids for current course.
|
||||
stage: Which item type we're processing for the current course.
|
||||
next_url: Pagination cursor within the current stage. None means
|
||||
start from the first page; a URL means resume from that page.
|
||||
|
||||
Invariant:
|
||||
If current_course_index is incremented, stage must be reset to
|
||||
"pages" and next_url must be reset to None.
|
||||
"""
|
||||
|
||||
course_ids: list[int] = []
|
||||
current_course_index: int = 0
|
||||
stage: CanvasStage = "pages"
|
||||
next_url: str | None = None
|
||||
|
||||
def advance_course(self) -> None:
|
||||
"""Move to the next course and reset within-course state."""
|
||||
self.current_course_index += 1
|
||||
self.stage = "pages"
|
||||
self.next_url = None
|
||||
@@ -890,8 +890,8 @@ class ConfluenceConnector(
|
||||
|
||||
def _retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
@@ -915,8 +915,8 @@ class ConfluenceConnector(
|
||||
self.confluence_client, doc_id, restrictions, ancestors
|
||||
) or space_level_access_info.get(page_space_key)
|
||||
|
||||
# Query pages
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
# Query pages (with optional time filtering for indexing_start)
|
||||
page_query = self._construct_page_cql_query(start, end)
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
@@ -950,7 +950,9 @@ class ConfluenceConnector(
|
||||
|
||||
# Query attachments for each page
|
||||
page_hierarchy_node_yielded = False
|
||||
attachment_query = self._construct_attachment_query(_get_page_id(page))
|
||||
attachment_query = self._construct_attachment_query(
|
||||
_get_page_id(page), start, end
|
||||
)
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_query,
|
||||
expand=restrictions_expand,
|
||||
|
||||
@@ -123,7 +123,7 @@ class OnyxConfluence:
|
||||
|
||||
self.shared_base_kwargs: dict[str, str | int | bool] = {
|
||||
"api_version": "cloud" if is_cloud else "latest",
|
||||
"backoff_and_retry": True,
|
||||
"backoff_and_retry": False,
|
||||
"cloud": is_cloud,
|
||||
}
|
||||
if timeout:
|
||||
@@ -456,7 +456,7 @@ class OnyxConfluence:
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
|
||||
)
|
||||
|
||||
@@ -363,7 +363,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except requests.HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
|
||||
)
|
||||
@@ -384,7 +384,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int, max_retries: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
@@ -408,6 +408,17 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
|
||||
raise e
|
||||
|
||||
if e.response.status_code >= 500:
|
||||
if attempt >= max_retries - 1:
|
||||
raise e
|
||||
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
logger.warning(
|
||||
f"Server error {e.response.status_code}. "
|
||||
f"Retrying in {delay} seconds (attempt {attempt + 1})..."
|
||||
)
|
||||
return math.ceil(time.monotonic() + delay)
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
|
||||
@@ -10,6 +10,7 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
from jira.resources import Issue
|
||||
@@ -239,29 +240,53 @@ def enhanced_search_ids(
|
||||
)
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO: move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
def _bulk_fetch_request(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
|
||||
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
|
||||
|
||||
# Prepare the payload according to Jira API v3 specification
|
||||
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
|
||||
|
||||
# Only restrict fields if specified, might want to explicitly do this in the future
|
||||
# to avoid reading unnecessary data
|
||||
payload["fields"] = fields.split(",") if fields else ["*all"]
|
||||
|
||||
resp = jira_client._session.post(bulk_fetch_path, json=payload)
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
try:
|
||||
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
|
||||
f"be decoded as JSON (response too large or truncated)."
|
||||
)
|
||||
raise
|
||||
|
||||
mid = len(issue_ids) // 2
|
||||
logger.warning(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise e
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
for issue in response["issues"]
|
||||
for issue in raw_issues
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ class NotionPage(BaseModel):
|
||||
id: str
|
||||
created_time: str
|
||||
last_edited_time: str
|
||||
archived: bool
|
||||
in_trash: bool
|
||||
properties: dict[str, Any]
|
||||
url: str
|
||||
|
||||
@@ -63,6 +63,13 @@ class NotionPage(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class NotionDataSource(BaseModel):
|
||||
"""Represents a Notion Data Source within a database."""
|
||||
|
||||
id: str
|
||||
name: str = ""
|
||||
|
||||
|
||||
class NotionBlock(BaseModel):
|
||||
"""Represents a Notion Block object"""
|
||||
|
||||
@@ -107,7 +114,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
"Notion-Version": "2026-03-11",
|
||||
}
|
||||
self.indexed_pages: set[str] = set()
|
||||
self.root_page_id = root_page_id
|
||||
@@ -127,6 +134,9 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
# Maps child page IDs to their containing page ID (discovered in _read_blocks).
|
||||
# Used to resolve block_id parent types to the actual containing page.
|
||||
self._child_page_parent_map: dict[str, str] = {}
|
||||
# Maps data_source_id -> database_id (populated in _read_pages_from_database).
|
||||
# Used to resolve data_source_id parent types back to the database.
|
||||
self._data_source_to_database_map: dict[str, str] = {}
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
@@ -227,7 +237,11 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
|
||||
"""Attempt to fetch a database as a page."""
|
||||
"""Attempt to fetch a database as a page.
|
||||
|
||||
Note: As of API 2025-09-03, database objects no longer include
|
||||
`properties` (schema moved to individual data sources).
|
||||
"""
|
||||
logger.debug(f"Fetching database for ID '{database_id}' as a page")
|
||||
database_url = f"https://api.notion.com/v1/databases/{database_id}"
|
||||
res = rl_requests.get(
|
||||
@@ -246,18 +260,52 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
database_name[0].get("text", {}).get("content") if database_name else None
|
||||
)
|
||||
|
||||
db_data.setdefault("properties", {})
|
||||
|
||||
return NotionPage(**db_data, database_name=database_name)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database(
|
||||
self, database_id: str, cursor: str | None = None
|
||||
def _fetch_data_sources_for_database(
|
||||
self, database_id: str
|
||||
) -> list[NotionDataSource]:
|
||||
"""Fetch the list of data sources for a database."""
|
||||
logger.debug(f"Fetching data sources for database '{database_id}'")
|
||||
res = rl_requests.get(
|
||||
f"https://api.notion.com/v1/databases/{database_id}",
|
||||
headers=self.headers,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
)
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except Exception as e:
|
||||
if res.status_code in (403, 404):
|
||||
logger.error(
|
||||
f"Unable to access database with ID '{database_id}'. "
|
||||
f"This is likely due to the database not being shared "
|
||||
f"with the Onyx integration. Exact exception:\n{e}"
|
||||
)
|
||||
return []
|
||||
logger.exception(f"Error fetching database - {res.json()}")
|
||||
raise e
|
||||
|
||||
db_data = res.json()
|
||||
data_sources = db_data.get("data_sources", [])
|
||||
return [
|
||||
NotionDataSource(id=ds["id"], name=ds.get("name", ""))
|
||||
for ds in data_sources
|
||||
if ds.get("id")
|
||||
]
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_data_source(
|
||||
self, data_source_id: str, cursor: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch a database from it's ID via the Notion API."""
|
||||
logger.debug(f"Fetching database for ID '{database_id}'")
|
||||
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
|
||||
"""Query a data source via POST /v1/data_sources/{id}/query."""
|
||||
logger.debug(f"Querying data source '{data_source_id}'")
|
||||
url = f"https://api.notion.com/v1/data_sources/{data_source_id}/query"
|
||||
body = None if not cursor else {"start_cursor": cursor}
|
||||
res = rl_requests.post(
|
||||
block_url,
|
||||
url,
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
@@ -265,25 +313,14 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except Exception as e:
|
||||
json_data = res.json()
|
||||
code = json_data.get("code")
|
||||
# Sep 3 2025 backend changed the error message for this case
|
||||
# TODO: it is also now possible for there to be multiple data sources per database; at present we
|
||||
# just don't handle that. We will need to upgrade the API to the current version + query the
|
||||
# new data sources endpoint to handle that case correctly.
|
||||
if code == "object_not_found" or (
|
||||
code == "validation_error"
|
||||
and "does not contain any data sources" in json_data.get("message", "")
|
||||
):
|
||||
# this happens when a database is not shared with the integration
|
||||
# in this case, we should just ignore the database
|
||||
if res.status_code in (403, 404):
|
||||
logger.error(
|
||||
f"Unable to access database with ID '{database_id}'. "
|
||||
f"This is likely due to the database not being shared "
|
||||
f"Unable to access data source with ID '{data_source_id}'. "
|
||||
f"This is likely due to it not being shared "
|
||||
f"with the Onyx integration. Exact exception:\n{e}"
|
||||
)
|
||||
return {"results": [], "next_cursor": None}
|
||||
logger.exception(f"Error fetching database - {res.json()}")
|
||||
logger.exception(f"Error querying data source - {res.json()}")
|
||||
raise e
|
||||
return res.json()
|
||||
|
||||
@@ -348,8 +385,9 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
# Fallback to workspace if we don't know the parent
|
||||
return self.workspace_id
|
||||
elif parent_type == "data_source_id":
|
||||
# Newer Notion API may use data_source_id for databases
|
||||
return parent.get("database_id") or parent.get("data_source_id")
|
||||
ds_id = parent.get("data_source_id")
|
||||
if ds_id:
|
||||
return self._data_source_to_database_map.get(ds_id, self.workspace_id)
|
||||
elif parent_type in ["page_id", "database_id"]:
|
||||
return parent.get(parent_type)
|
||||
|
||||
@@ -497,18 +535,32 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
if db_node:
|
||||
hierarchy_nodes.append(db_node)
|
||||
|
||||
cursor = None
|
||||
while True:
|
||||
data = self._fetch_database(database_id, cursor)
|
||||
# Discover all data sources under this database, then query each one.
|
||||
# Even legacy single-source databases have one entry in the array.
|
||||
data_sources = self._fetch_data_sources_for_database(database_id)
|
||||
if not data_sources:
|
||||
logger.warning(
|
||||
f"Database '{database_id}' returned zero data sources — "
|
||||
f"no pages will be indexed from this database."
|
||||
)
|
||||
for ds in data_sources:
|
||||
self._data_source_to_database_map[ds.id] = database_id
|
||||
cursor = None
|
||||
while True:
|
||||
data = self._fetch_data_source(ds.id, cursor)
|
||||
|
||||
for result in data["results"]:
|
||||
obj_id = result["id"]
|
||||
obj_type = result["object"]
|
||||
text = self._properties_to_str(result.get("properties", {}))
|
||||
if text:
|
||||
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
|
||||
for result in data["results"]:
|
||||
obj_id = result["id"]
|
||||
obj_type = result["object"]
|
||||
text = self._properties_to_str(result.get("properties", {}))
|
||||
if text:
|
||||
result_blocks.append(
|
||||
NotionBlock(id=obj_id, text=text, prefix="\n")
|
||||
)
|
||||
|
||||
if not self.recursive_index_enabled:
|
||||
continue
|
||||
|
||||
if self.recursive_index_enabled:
|
||||
if obj_type == "page":
|
||||
logger.debug(
|
||||
f"Found page with ID '{obj_id}' in database '{database_id}'"
|
||||
@@ -518,7 +570,6 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
logger.debug(
|
||||
f"Found database with ID '{obj_id}' in database '{database_id}'"
|
||||
)
|
||||
# Get nested database name from properties if available
|
||||
nested_db_title = result.get("title", [])
|
||||
nested_db_name = None
|
||||
if nested_db_title and len(nested_db_title) > 0:
|
||||
@@ -533,10 +584,10 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
result_pages.extend(nested_output.child_page_ids)
|
||||
hierarchy_nodes.extend(nested_output.hierarchy_nodes)
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
|
||||
cursor = data["next_cursor"]
|
||||
cursor = data["next_cursor"]
|
||||
|
||||
return BlockReadOutput(
|
||||
blocks=result_blocks,
|
||||
@@ -807,36 +858,55 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
def _yield_database_hierarchy_nodes(
|
||||
self,
|
||||
) -> Generator[HierarchyNode | Document, None, None]:
|
||||
"""Search for all databases and yield hierarchy nodes for each.
|
||||
"""Search for all data sources and yield hierarchy nodes for their parent databases.
|
||||
|
||||
This must be called BEFORE page indexing so that database hierarchy nodes
|
||||
exist when pages inside databases reference them as parents.
|
||||
|
||||
With the new API, search returns data source objects instead of databases.
|
||||
Multiple data sources can share the same parent database, so we use
|
||||
database_id as the hierarchy node key and deduplicate via
|
||||
_maybe_yield_hierarchy_node.
|
||||
"""
|
||||
query_dict: dict[str, Any] = {
|
||||
"filter": {"property": "object", "value": "database"},
|
||||
"filter": {"property": "object", "value": "data_source"},
|
||||
"page_size": _NOTION_PAGE_SIZE,
|
||||
}
|
||||
pages_seen = 0
|
||||
while pages_seen < _MAX_PAGES:
|
||||
db_res = self._search_notion(query_dict)
|
||||
for db in db_res.results:
|
||||
db_id = db["id"]
|
||||
# Extract title from the title array
|
||||
title_arr = db.get("title", [])
|
||||
db_name = None
|
||||
if title_arr:
|
||||
db_name = " ".join(
|
||||
t.get("plain_text", "") for t in title_arr
|
||||
).strip()
|
||||
if not db_name:
|
||||
for ds in db_res.results:
|
||||
# Extract the parent database_id from the data source's parent
|
||||
ds_parent = ds.get("parent", {})
|
||||
db_id = ds_parent.get("database_id")
|
||||
if not db_id:
|
||||
continue
|
||||
|
||||
# Populate the mapping so _get_parent_raw_id can resolve later
|
||||
ds_id = ds.get("id")
|
||||
if not ds_id:
|
||||
continue
|
||||
self._data_source_to_database_map[ds_id] = db_id
|
||||
|
||||
# Fetch the database to get its actual name and parent
|
||||
try:
|
||||
db_page = self._fetch_database_as_page(db_id)
|
||||
db_name = db_page.database_name or f"Database {db_id}"
|
||||
parent_raw_id = self._get_parent_raw_id(db_page.parent)
|
||||
db_url = (
|
||||
db_page.url or f"https://notion.so/{db_id.replace('-', '')}"
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(
|
||||
f"Could not fetch database '{db_id}', "
|
||||
f"defaulting to workspace root. Error: {e}"
|
||||
)
|
||||
db_name = f"Database {db_id}"
|
||||
parent_raw_id = self.workspace_id
|
||||
db_url = f"https://notion.so/{db_id.replace('-', '')}"
|
||||
|
||||
# Get parent using existing helper
|
||||
parent_raw_id = self._get_parent_raw_id(db.get("parent"))
|
||||
|
||||
# Notion URLs omit dashes from UUIDs
|
||||
db_url = db.get("url") or f"https://notion.so/{db_id.replace('-', '')}"
|
||||
|
||||
# _maybe_yield_hierarchy_node deduplicates by raw_node_id,
|
||||
# so multiple data sources under one database produce one node.
|
||||
node = self._maybe_yield_hierarchy_node(
|
||||
raw_node_id=db_id,
|
||||
raw_parent_id=parent_raw_id or self.workspace_id,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import copy
|
||||
import fnmatch
|
||||
import html
|
||||
import io
|
||||
import os
|
||||
@@ -84,6 +85,44 @@ SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()}
|
||||
|
||||
ASPX_EXTENSION = ".aspx"
|
||||
|
||||
|
||||
def _is_site_excluded(site_url: str, excluded_site_patterns: list[str]) -> bool:
|
||||
"""Check if a site URL matches any of the exclusion glob patterns."""
|
||||
for pattern in excluded_site_patterns:
|
||||
if fnmatch.fnmatch(site_url, pattern) or fnmatch.fnmatch(
|
||||
site_url.rstrip("/"), pattern.rstrip("/")
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_path_excluded(item_path: str, excluded_path_patterns: list[str]) -> bool:
|
||||
"""Check if a drive item path matches any of the exclusion glob patterns.
|
||||
|
||||
item_path is the relative path within a drive, e.g. "Engineering/API/report.docx".
|
||||
Matches are attempted against the full path and the filename alone so that
|
||||
patterns like "*.tmp" match files at any depth.
|
||||
"""
|
||||
filename = item_path.rsplit("/", 1)[-1] if "/" in item_path else item_path
|
||||
for pattern in excluded_path_patterns:
|
||||
if fnmatch.fnmatch(item_path, pattern) or fnmatch.fnmatch(filename, pattern):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _build_item_relative_path(parent_reference_path: str | None, item_name: str) -> str:
|
||||
"""Build the relative path of a drive item from its parentReference.path and name.
|
||||
|
||||
Example: parentReference.path="/drives/abc/root:/Eng/API", name="report.docx"
|
||||
=> "Eng/API/report.docx"
|
||||
"""
|
||||
if parent_reference_path and "root:/" in parent_reference_path:
|
||||
folder = unquote(parent_reference_path.split("root:/", 1)[1])
|
||||
if folder:
|
||||
return f"{folder}/{item_name}"
|
||||
return item_name
|
||||
|
||||
|
||||
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
|
||||
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
|
||||
DEFAULT_SHAREPOINT_DOMAIN_SUFFIX = "sharepoint.com"
|
||||
@@ -478,6 +517,7 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
include_permissions: bool = False,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
access_token: str | None = None,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
|
||||
if not driveitem.name or not driveitem.id:
|
||||
@@ -610,6 +650,7 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
drive_item=sdk_item,
|
||||
drive_name=drive_name,
|
||||
add_prefix=True,
|
||||
treat_sharing_link_as_public=treat_sharing_link_as_public,
|
||||
)
|
||||
else:
|
||||
external_access = ExternalAccess.empty()
|
||||
@@ -644,6 +685,7 @@ def _convert_sitepage_to_document(
|
||||
graph_client: GraphClient,
|
||||
include_permissions: bool = False,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> Document:
|
||||
"""Convert a SharePoint site page to a Document object."""
|
||||
# Extract text content from the site page
|
||||
@@ -773,6 +815,7 @@ def _convert_sitepage_to_document(
|
||||
graph_client=graph_client,
|
||||
site_page=site_page,
|
||||
add_prefix=True,
|
||||
treat_sharing_link_as_public=treat_sharing_link_as_public,
|
||||
)
|
||||
else:
|
||||
external_access = ExternalAccess.empty()
|
||||
@@ -803,6 +846,7 @@ def _convert_driveitem_to_slim_document(
|
||||
ctx: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> SlimDocument:
|
||||
if driveitem.id is None:
|
||||
raise ValueError("DriveItem ID is required")
|
||||
@@ -813,6 +857,7 @@ def _convert_driveitem_to_slim_document(
|
||||
graph_client=graph_client,
|
||||
drive_item=sdk_item,
|
||||
drive_name=drive_name,
|
||||
treat_sharing_link_as_public=treat_sharing_link_as_public,
|
||||
)
|
||||
|
||||
return SlimDocument(
|
||||
@@ -827,6 +872,7 @@ def _convert_sitepage_to_slim_document(
|
||||
ctx: ClientContext | None,
|
||||
graph_client: GraphClient,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> SlimDocument:
|
||||
"""Convert a SharePoint site page to a SlimDocument object."""
|
||||
if site_page.get("id") is None:
|
||||
@@ -836,6 +882,7 @@ def _convert_sitepage_to_slim_document(
|
||||
ctx=ctx,
|
||||
graph_client=graph_client,
|
||||
site_page=site_page,
|
||||
treat_sharing_link_as_public=treat_sharing_link_as_public,
|
||||
)
|
||||
id = site_page.get("id")
|
||||
if id is None:
|
||||
@@ -855,14 +902,20 @@ class SharepointConnector(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
sites: list[str] = [],
|
||||
excluded_sites: list[str] = [],
|
||||
excluded_paths: list[str] = [],
|
||||
include_site_pages: bool = True,
|
||||
include_site_documents: bool = True,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
authority_host: str = DEFAULT_AUTHORITY_HOST,
|
||||
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
|
||||
sharepoint_domain_suffix: str = DEFAULT_SHAREPOINT_DOMAIN_SUFFIX,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.sites = list(sites)
|
||||
self.excluded_sites = [s for p in excluded_sites if (s := p.strip())]
|
||||
self.excluded_paths = [s for p in excluded_paths if (s := p.strip())]
|
||||
self.treat_sharing_link_as_public = treat_sharing_link_as_public
|
||||
self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
|
||||
sites
|
||||
)
|
||||
@@ -1233,6 +1286,29 @@ class SharepointConnector(
|
||||
break
|
||||
sites = sites._get_next().execute_query()
|
||||
|
||||
def _is_driveitem_excluded(self, driveitem: DriveItemData) -> bool:
|
||||
"""Check if a drive item should be excluded based on excluded_paths patterns."""
|
||||
if not self.excluded_paths:
|
||||
return False
|
||||
relative_path = _build_item_relative_path(
|
||||
driveitem.parent_reference_path, driveitem.name
|
||||
)
|
||||
return _is_path_excluded(relative_path, self.excluded_paths)
|
||||
|
||||
def _filter_excluded_sites(
|
||||
self, site_descriptors: list[SiteDescriptor]
|
||||
) -> list[SiteDescriptor]:
|
||||
"""Remove sites matching any excluded_sites glob pattern."""
|
||||
if not self.excluded_sites:
|
||||
return site_descriptors
|
||||
result = []
|
||||
for sd in site_descriptors:
|
||||
if _is_site_excluded(sd.url, self.excluded_sites):
|
||||
logger.info(f"Excluding site by denylist: {sd.url}")
|
||||
continue
|
||||
result.append(sd)
|
||||
return result
|
||||
|
||||
def fetch_sites(self) -> list[SiteDescriptor]:
|
||||
sites = self.graph_client.sites.get_all_sites().execute_query()
|
||||
|
||||
@@ -1249,7 +1325,7 @@ class SharepointConnector(
|
||||
for site in self._handle_paginated_sites(sites)
|
||||
if "-my.sharepoint" not in site.web_url
|
||||
]
|
||||
return site_descriptors
|
||||
return self._filter_excluded_sites(site_descriptors)
|
||||
|
||||
def _fetch_site_pages(
|
||||
self,
|
||||
@@ -1689,8 +1765,14 @@ class SharepointConnector(
|
||||
checkpoint.current_drive_delta_next_link = None
|
||||
checkpoint.seen_document_ids.clear()
|
||||
|
||||
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self.site_descriptors or self.fetch_sites()
|
||||
def _fetch_slim_documents_from_sharepoint(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self._filter_excluded_sites(
|
||||
self.site_descriptors or self.fetch_sites()
|
||||
)
|
||||
|
||||
# Create a temporary checkpoint for hierarchy node tracking
|
||||
temp_checkpoint = SharepointConnectorCheckpoint(has_more=True)
|
||||
@@ -1708,8 +1790,14 @@ class SharepointConnector(
|
||||
# Process site documents if flag is True
|
||||
if self.include_site_documents:
|
||||
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
|
||||
site_descriptor=site_descriptor
|
||||
site_descriptor=site_descriptor,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
if self._is_driveitem_excluded(driveitem):
|
||||
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
|
||||
continue
|
||||
|
||||
if drive_web_url:
|
||||
doc_batch.extend(
|
||||
self._yield_drive_hierarchy_node(
|
||||
@@ -1747,6 +1835,7 @@ class SharepointConnector(
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_url,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1758,7 +1847,9 @@ class SharepointConnector(
|
||||
|
||||
# Process site pages if flag is True
|
||||
if self.include_site_pages:
|
||||
site_pages = self._fetch_site_pages(site_descriptor)
|
||||
site_pages = self._fetch_site_pages(
|
||||
site_descriptor, start=start, end=end
|
||||
)
|
||||
for site_page in site_pages:
|
||||
logger.debug(
|
||||
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
|
||||
@@ -1770,6 +1861,7 @@ class SharepointConnector(
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=site_descriptor.url,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
)
|
||||
if len(doc_batch) >= SLIM_BATCH_SIZE:
|
||||
@@ -2043,7 +2135,9 @@ class SharepointConnector(
|
||||
and not checkpoint.process_site_pages
|
||||
):
|
||||
logger.info("Initializing SharePoint sites for processing")
|
||||
site_descs = self.site_descriptors or self.fetch_sites()
|
||||
site_descs = self._filter_excluded_sites(
|
||||
self.site_descriptors or self.fetch_sites()
|
||||
)
|
||||
checkpoint.cached_site_descriptors = deque(site_descs)
|
||||
|
||||
if not checkpoint.cached_site_descriptors:
|
||||
@@ -2264,6 +2358,10 @@ class SharepointConnector(
|
||||
for driveitem in driveitems:
|
||||
item_count += 1
|
||||
|
||||
if self._is_driveitem_excluded(driveitem):
|
||||
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
|
||||
continue
|
||||
|
||||
if driveitem.id and driveitem.id in checkpoint.seen_document_ids:
|
||||
logger.debug(
|
||||
f"Skipping duplicate document {driveitem.id} ({driveitem.name})"
|
||||
@@ -2318,6 +2416,7 @@ class SharepointConnector(
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_url,
|
||||
graph_api_base=self.graph_api_base,
|
||||
access_token=access_token,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
|
||||
if isinstance(doc_or_failure, Document):
|
||||
@@ -2398,6 +2497,7 @@ class SharepointConnector(
|
||||
include_permissions=include_permissions,
|
||||
# Site pages have the site as their parent
|
||||
parent_hierarchy_raw_node_id=site_descriptor.url,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
@@ -2473,12 +2573,22 @@ class SharepointConnector(
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
|
||||
yield from self._fetch_slim_documents_from_sharepoint()
|
||||
start_dt = (
|
||||
datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
if start is not None
|
||||
else None
|
||||
)
|
||||
end_dt = (
|
||||
datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
|
||||
)
|
||||
yield from self._fetch_slim_documents_from_sharepoint(
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -17,6 +17,7 @@ def get_sharepoint_external_access(
|
||||
drive_name: str | None = None,
|
||||
site_page: dict[str, Any] | None = None,
|
||||
add_prefix: bool = False,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> ExternalAccess:
|
||||
if drive_item and drive_item.id is None:
|
||||
raise ValueError("DriveItem ID is required")
|
||||
@@ -34,7 +35,13 @@ def get_sharepoint_external_access(
|
||||
)
|
||||
|
||||
external_access = get_external_access_func(
|
||||
ctx, graph_client, drive_name, drive_item, site_page, add_prefix
|
||||
ctx,
|
||||
graph_client,
|
||||
drive_name,
|
||||
drive_item,
|
||||
site_page,
|
||||
add_prefix,
|
||||
treat_sharing_link_as_public,
|
||||
)
|
||||
|
||||
return external_access
|
||||
|
||||
@@ -516,6 +516,8 @@ def _get_all_doc_ids(
|
||||
] = default_msg_filter,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
workspace_url: str | None = None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
@@ -546,6 +548,8 @@ def _get_all_doc_ids(
|
||||
client=client,
|
||||
channel=channel,
|
||||
callback=callback,
|
||||
oldest=str(start) if start else None, # 0.0 -> None intentionally
|
||||
latest=str(end) if end is not None else None,
|
||||
)
|
||||
|
||||
for message_batch in channel_message_batches:
|
||||
@@ -847,8 +851,8 @@ class SlackConnector(
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.client is None:
|
||||
@@ -861,6 +865,8 @@ class SlackConnector(
|
||||
msg_filter_func=self.msg_filter_func,
|
||||
callback=callback,
|
||||
workspace_url=self._workspace_url,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _load_from_checkpoint(
|
||||
|
||||
@@ -401,3 +401,16 @@ class SavedSearchDocWithContent(SavedSearchDoc):
|
||||
section in addition to the match_highlights."""
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
class PersonaSearchInfo(BaseModel):
|
||||
"""Snapshot of persona data needed by the search pipeline.
|
||||
|
||||
Extracted from the ORM Persona before the DB session is released so that
|
||||
SearchTool and search_pipeline never lazy-load relationships post-commit.
|
||||
"""
|
||||
|
||||
document_set_names: list[str]
|
||||
search_start_date: datetime | None
|
||||
attached_document_ids: list[str]
|
||||
hierarchy_node_ids: list[int]
|
||||
|
||||
@@ -9,12 +9,12 @@ from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.context.search.preprocessing.access_filters import (
|
||||
build_access_filters_for_user,
|
||||
)
|
||||
from onyx.context.search.retrieval.search_runner import search_chunks
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
@@ -247,8 +247,8 @@ def search_pipeline(
|
||||
document_index: DocumentIndex,
|
||||
# Used for ACLs and federated search, anonymous users only see public docs
|
||||
user: User,
|
||||
# Used for default filters and settings
|
||||
persona: Persona | None,
|
||||
# Pre-extracted persona search configuration (None when no persona)
|
||||
persona_search_info: PersonaSearchInfo | None,
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
@@ -263,24 +263,18 @@ def search_pipeline(
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
persona_document_sets: list[str] | None = (
|
||||
[persona_document_set.name for persona_document_set in persona.document_sets]
|
||||
if persona
|
||||
else None
|
||||
persona_search_info.document_set_names if persona_search_info else None
|
||||
)
|
||||
persona_time_cutoff: datetime | None = (
|
||||
persona.search_start_date if persona else None
|
||||
persona_search_info.search_start_date if persona_search_info else None
|
||||
)
|
||||
|
||||
# Extract assistant knowledge filters from persona
|
||||
attached_document_ids: list[str] | None = (
|
||||
[doc.id for doc in persona.attached_documents]
|
||||
if persona and persona.attached_documents
|
||||
persona_search_info.attached_document_ids or None
|
||||
if persona_search_info
|
||||
else None
|
||||
)
|
||||
hierarchy_node_ids: list[int] | None = (
|
||||
[node.id for node in persona.hierarchy_nodes]
|
||||
if persona and persona.hierarchy_nodes
|
||||
else None
|
||||
persona_search_info.hierarchy_node_ids or None if persona_search_info else None
|
||||
)
|
||||
|
||||
filters = _build_index_filters(
|
||||
|
||||
@@ -14,6 +14,10 @@ from onyx.context.search.utils import get_query_embedding
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.interfaces_new import DocumentIndex as NewDocumentIndex
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.federated_connectors.federated_retrieval import (
|
||||
get_federated_retrieval_functions,
|
||||
@@ -49,7 +53,7 @@ def combine_retrieval_results(
|
||||
return sorted_chunks
|
||||
|
||||
|
||||
def _embed_and_search(
|
||||
def _embed_and_hybrid_search(
|
||||
query_request: ChunkIndexRequest,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session | None = None,
|
||||
@@ -81,6 +85,17 @@ def _embed_and_search(
|
||||
return top_chunks
|
||||
|
||||
|
||||
def _keyword_search(
|
||||
query_request: ChunkIndexRequest,
|
||||
document_index: NewDocumentIndex,
|
||||
) -> list[InferenceChunk]:
|
||||
return document_index.keyword_retrieval(
|
||||
query=query_request.query,
|
||||
filters=query_request.filters,
|
||||
num_to_retrieve=query_request.limit or NUM_RETURNED_HITS,
|
||||
)
|
||||
|
||||
|
||||
def search_chunks(
|
||||
query_request: ChunkIndexRequest,
|
||||
user_id: UUID | None,
|
||||
@@ -128,21 +143,38 @@ def search_chunks(
|
||||
)
|
||||
|
||||
if normal_search_enabled:
|
||||
run_queries.append(
|
||||
(
|
||||
_embed_and_search,
|
||||
(query_request, document_index, db_session, embedding_model),
|
||||
if (
|
||||
query_request.hybrid_alpha is not None
|
||||
and query_request.hybrid_alpha == 0.0
|
||||
and isinstance(document_index, OpenSearchOldDocumentIndex)
|
||||
):
|
||||
# If hybrid alpha is explicitly set to keyword only, do pure keyword
|
||||
# search without generating an embedding. This is currently only
|
||||
# supported with OpenSearchDocumentIndex.
|
||||
opensearch_new_document_index: NewDocumentIndex = document_index._real_index
|
||||
run_queries.append(
|
||||
(
|
||||
lambda: _keyword_search(
|
||||
query_request, opensearch_new_document_index
|
||||
),
|
||||
(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
run_queries.append(
|
||||
(
|
||||
_embed_and_hybrid_search,
|
||||
(query_request, document_index, db_session, embedding_model),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
top_chunks = combine_retrieval_results(parallel_search_results)
|
||||
|
||||
if not top_chunks:
|
||||
logger.debug(
|
||||
f"Hybrid search returned no results for query: {query_request.query}with filters: {query_request.filters}"
|
||||
f"Search returned no results for query: {query_request.query} with filters: {query_request.filters}."
|
||||
)
|
||||
return []
|
||||
|
||||
return top_chunks
|
||||
|
||||
|
||||
@@ -64,6 +64,9 @@ def get_chat_session_by_id(
|
||||
joinedload(ChatSession.persona).options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.attached_documents),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
),
|
||||
joinedload(ChatSession.project),
|
||||
)
|
||||
|
||||
@@ -750,3 +750,31 @@ def resync_cc_pair(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_connector_health_for_metrics(
|
||||
db_session: Session,
|
||||
) -> list: # Returns list of Row tuples
|
||||
"""Return connector health data for Prometheus metrics.
|
||||
|
||||
Each row is (cc_pair_id, status, in_repeated_error_state,
|
||||
last_successful_index_time, name, source).
|
||||
"""
|
||||
return (
|
||||
db_session.query(
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.status,
|
||||
ConnectorCredentialPair.in_repeated_error_state,
|
||||
ConnectorCredentialPair.last_successful_index_time,
|
||||
ConnectorCredentialPair.name,
|
||||
Connector.source,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -1,4 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum as PyEnum
|
||||
from typing import ClassVar
|
||||
|
||||
|
||||
class AccountType(str, PyEnum):
|
||||
"""
|
||||
What kind of account this is — determines whether the user
|
||||
enters the group-based permission system.
|
||||
|
||||
STANDARD + SERVICE_ACCOUNT → participate in group system
|
||||
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
|
||||
"""
|
||||
|
||||
STANDARD = "standard"
|
||||
BOT = "bot"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
SERVICE_ACCOUNT = "service_account"
|
||||
ANONYMOUS = "anonymous"
|
||||
|
||||
|
||||
class GrantSource(str, PyEnum):
|
||||
"""How a permission grant was created."""
|
||||
|
||||
USER = "user"
|
||||
SCIM = "scim"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
@@ -314,3 +341,54 @@ class HookPoint(str, PyEnum):
|
||||
class HookFailStrategy(str, PyEnum):
|
||||
HARD = "hard" # exception propagates, pipeline aborts
|
||||
SOFT = "soft" # log error, return original input, pipeline continues
|
||||
|
||||
|
||||
class Permission(str, PyEnum):
|
||||
"""
|
||||
Permission tokens for group-based authorization.
|
||||
19 tokens total. full_admin_panel_access is an override —
|
||||
if present, any permission check passes.
|
||||
"""
|
||||
|
||||
# Basic (auto-granted to every new group)
|
||||
BASIC_ACCESS = "basic"
|
||||
|
||||
# Read tokens — implied only, never granted directly
|
||||
READ_CONNECTORS = "read:connectors"
|
||||
READ_DOCUMENT_SETS = "read:document_sets"
|
||||
READ_AGENTS = "read:agents"
|
||||
READ_USERS = "read:users"
|
||||
|
||||
# Add / Manage pairs
|
||||
ADD_AGENTS = "add:agents"
|
||||
MANAGE_AGENTS = "manage:agents"
|
||||
MANAGE_DOCUMENT_SETS = "manage:document_sets"
|
||||
ADD_CONNECTORS = "add:connectors"
|
||||
MANAGE_CONNECTORS = "manage:connectors"
|
||||
MANAGE_LLMS = "manage:llms"
|
||||
|
||||
# Toggle tokens
|
||||
READ_AGENT_ANALYTICS = "read:agent_analytics"
|
||||
MANAGE_ACTIONS = "manage:actions"
|
||||
READ_QUERY_HISTORY = "read:query_history"
|
||||
MANAGE_USER_GROUPS = "manage:user_groups"
|
||||
CREATE_USER_API_KEYS = "create:user_api_keys"
|
||||
CREATE_SERVICE_ACCOUNT_API_KEYS = "create:service_account_api_keys"
|
||||
CREATE_SLACK_DISCORD_BOTS = "create:slack_discord_bots"
|
||||
|
||||
# Override — any permission check passes
|
||||
FULL_ADMIN_PANEL_ACCESS = "admin"
|
||||
|
||||
# Permissions that are implied by other grants and must never be stored
|
||||
# directly in the permission_grant table.
|
||||
IMPLIED: ClassVar[frozenset[Permission]]
|
||||
|
||||
|
||||
Permission.IMPLIED = frozenset(
|
||||
{
|
||||
Permission.READ_CONNECTORS,
|
||||
Permission.READ_DOCUMENT_SETS,
|
||||
Permission.READ_AGENTS,
|
||||
Permission.READ_USERS,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -75,6 +75,7 @@ def create_hook__no_commit(
|
||||
fail_strategy: HookFailStrategy,
|
||||
timeout_seconds: float,
|
||||
is_active: bool = False,
|
||||
is_reachable: bool | None = None,
|
||||
creator_id: UUID | None = None,
|
||||
) -> Hook:
|
||||
"""Create a new hook for the given hook point.
|
||||
@@ -100,6 +101,7 @@ def create_hook__no_commit(
|
||||
fail_strategy=fail_strategy,
|
||||
timeout_seconds=timeout_seconds,
|
||||
is_active=is_active,
|
||||
is_reachable=is_reachable,
|
||||
creator_id=creator_id,
|
||||
)
|
||||
# Use a savepoint so that a failed insert only rolls back this operation,
|
||||
|
||||
@@ -2,6 +2,8 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import NamedTuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -28,6 +30,9 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# from sqlalchemy.sql.selectable import Select
|
||||
|
||||
# Comment out unused imports that cause mypy errors
|
||||
@@ -972,3 +977,106 @@ def get_index_attempt_errors_for_cc_pair(
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class ActiveIndexAttemptMetric(NamedTuple):
|
||||
"""Row returned by get_active_index_attempts_for_metrics."""
|
||||
|
||||
status: IndexingStatus
|
||||
source: "DocumentSource"
|
||||
cc_pair_id: int
|
||||
cc_pair_name: str | None
|
||||
attempt_count: int
|
||||
|
||||
|
||||
def get_active_index_attempts_for_metrics(
|
||||
db_session: Session,
|
||||
) -> list[ActiveIndexAttemptMetric]:
|
||||
"""Return non-terminal index attempts grouped by status, source, and connector.
|
||||
|
||||
Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count).
|
||||
"""
|
||||
from onyx.db.models import Connector
|
||||
|
||||
terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
func.count(),
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.filter(IndexAttempt.status.notin_(terminal_statuses))
|
||||
.group_by(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return [ActiveIndexAttemptMetric(*row) for row in rows]
|
||||
|
||||
|
||||
def get_failed_attempt_counts_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: failed_attempt_count} for all connectors.
|
||||
|
||||
When ``since`` is provided, only attempts created after that timestamp
|
||||
are counted. Defaults to the last 90 days to avoid unbounded historical
|
||||
aggregation.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.count(),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.FAILED)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
.all()
|
||||
)
|
||||
return {cc_id: count for cc_id, count in rows}
|
||||
|
||||
|
||||
def get_docs_indexed_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: total_new_docs_indexed} across successful attempts.
|
||||
|
||||
Only counts attempts with status SUCCESS to avoid inflating counts with
|
||||
partial results from failed attempts. When ``since`` is provided, only
|
||||
attempts created after that timestamp are included.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
|
||||
query = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.SUCCESS)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
)
|
||||
rows = query.all()
|
||||
return {cc_id: int(total or 0) for cc_id, total in rows}
|
||||
|
||||
@@ -48,6 +48,7 @@ from sqlalchemy.types import LargeBinary
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy import PrimaryKeyConstraint
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import (
|
||||
ANONYMOUS_USER_UUID,
|
||||
@@ -78,6 +79,8 @@ from onyx.db.enums import (
|
||||
MCPAuthenticationPerformer,
|
||||
MCPTransport,
|
||||
MCPServerStatus,
|
||||
Permission,
|
||||
GrantSource,
|
||||
LLMModelFlowType,
|
||||
ThemePreference,
|
||||
DefaultAppMode,
|
||||
@@ -302,6 +305,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType | None] = mapped_column(
|
||||
Enum(AccountType, native_enum=False), nullable=True
|
||||
)
|
||||
|
||||
"""
|
||||
Preferences probably should be in a separate table at some point, but for now
|
||||
@@ -2645,6 +2651,15 @@ class ChatMessage(Base):
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# For multi-model turns: the user message points to which assistant response
|
||||
# was selected as the preferred one to continue the conversation with.
|
||||
preferred_response_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
# The display name of the model that generated this assistant message
|
||||
model_display_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# What does this message contain
|
||||
reasoning_tokens: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
message: Mapped[str] = mapped_column(Text)
|
||||
@@ -2712,6 +2727,12 @@ class ChatMessage(Base):
|
||||
remote_side="ChatMessage.id",
|
||||
)
|
||||
|
||||
preferred_response: Mapped["ChatMessage | None"] = relationship(
|
||||
"ChatMessage",
|
||||
foreign_keys=[preferred_response_id],
|
||||
remote_side="ChatMessage.id",
|
||||
)
|
||||
|
||||
# Chat messages only need to know their immediate tool call children
|
||||
# If there are nested tool calls, they are stored in the tool_call_children relationship.
|
||||
tool_calls: Mapped[list["ToolCall"] | None] = relationship(
|
||||
@@ -3114,8 +3135,6 @@ class VoiceProvider(Base):
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
@@ -3971,6 +3990,8 @@ class SamlAccount(Base):
|
||||
class User__UserGroup(Base):
|
||||
__tablename__ = "user__user_group"
|
||||
|
||||
__table_args__ = (Index("ix_user__user_group_user_id", "user_id"),)
|
||||
|
||||
is_curator: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
@@ -3981,6 +4002,48 @@ class User__UserGroup(Base):
|
||||
)
|
||||
|
||||
|
||||
class PermissionGrant(Base):
|
||||
__tablename__ = "permission_grant"
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"group_id", "permission", name="uq_permission_grant_group_permission"
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
permission: Mapped[Permission] = mapped_column(
|
||||
Enum(Permission, native_enum=False), nullable=False
|
||||
)
|
||||
grant_source: Mapped[GrantSource] = mapped_column(
|
||||
Enum(GrantSource, native_enum=False), nullable=False
|
||||
)
|
||||
granted_by: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
granted_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
is_deleted: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False, server_default=text("false")
|
||||
)
|
||||
|
||||
group: Mapped["UserGroup"] = relationship(
|
||||
"UserGroup", back_populates="permission_grants"
|
||||
)
|
||||
|
||||
@validates("permission")
|
||||
def _validate_permission(self, _key: str, value: Permission) -> Permission:
|
||||
if value in Permission.IMPLIED:
|
||||
raise ValueError(
|
||||
f"{value!r} is an implied permission and cannot be granted directly"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class UserGroup__ConnectorCredentialPair(Base):
|
||||
__tablename__ = "user_group__connector_credential_pair"
|
||||
|
||||
@@ -4075,6 +4138,8 @@ class UserGroup(Base):
|
||||
is_up_for_deletion: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
# whether this is a default group (e.g. "Basic", "Admins") that cannot be deleted
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Last time a user updated this user group
|
||||
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -4118,6 +4183,9 @@ class UserGroup(Base):
|
||||
accessible_mcp_servers: Mapped[list["MCPServer"]] = relationship(
|
||||
"MCPServer", secondary="mcp_server__user_group", back_populates="user_groups"
|
||||
)
|
||||
permission_grants: Mapped[list["PermissionGrant"]] = relationship(
|
||||
"PermissionGrant", back_populates="group", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
"""Tables related to Token Rate Limiting
|
||||
|
||||
@@ -50,8 +50,18 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_default_behavior_persona(db_session: Session) -> Persona | None:
|
||||
def get_default_behavior_persona(
|
||||
db_session: Session,
|
||||
eager_load_for_tools: bool = False,
|
||||
) -> Persona | None:
|
||||
stmt = select(Persona).where(Persona.id == DEFAULT_PERSONA_ID)
|
||||
if eager_load_for_tools:
|
||||
stmt = stmt.options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.attached_documents),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
)
|
||||
return db_session.scalars(stmt).first()
|
||||
|
||||
|
||||
|
||||
@@ -17,39 +17,30 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
|
||||
"""Fetch all voice providers."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
.order_by(VoiceProvider.name)
|
||||
).all()
|
||||
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
|
||||
db_session: Session, provider_id: int, include_deleted: bool = False
|
||||
db_session: Session, provider_id: int
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by ID."""
|
||||
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(VoiceProvider.deleted.is_(False))
|
||||
return db_session.scalar(stmt)
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default STT provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_stt.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default TTS provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_tts.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
|
||||
)
|
||||
|
||||
|
||||
@@ -58,9 +49,7 @@ def fetch_voice_provider_by_type(
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by type."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.provider_type == provider_type)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
|
||||
)
|
||||
|
||||
|
||||
@@ -119,10 +108,10 @@ def upsert_voice_provider(
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
|
||||
"""Soft-delete a voice provider by ID."""
|
||||
"""Delete a voice provider by ID."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider:
|
||||
provider.deleted = True
|
||||
db_session.delete(provider)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
|
||||
out against a nonexistent Vespa/OpenSearch instance.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from onyx.context.search.models import IndexFilters
|
||||
@@ -66,7 +67,7 @@ class DisabledDocumentIndex(DocumentIndex):
|
||||
# ------------------------------------------------------------------
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
index_batch_params: IndexBatchParams, # noqa: ARG002
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
@@ -206,7 +207,7 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -226,8 +227,8 @@ class Indexable(abc.ABC):
|
||||
it is done automatically outside of this code.
|
||||
|
||||
Parameters:
|
||||
- chunks: Document chunks with all of the information needed for indexing to the document
|
||||
index.
|
||||
- chunks: Document chunks with all of the information needed for
|
||||
indexing to the document index.
|
||||
- tenant_id: The tenant id of the user whose chunks are being indexed
|
||||
- large_chunks_enabled: Whether large chunks are enabled
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -209,10 +210,10 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
|
||||
This is often a batch operation including chunks from multiple
|
||||
documents.
|
||||
@@ -381,6 +382,47 @@ class HybridCapable(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Runs keyword-only search and returns a list of inference chunks.
|
||||
|
||||
Args:
|
||||
query: User query.
|
||||
filters: Filters for things like permissions, source type, time,
|
||||
etc.
|
||||
num_to_retrieve: Number of highest matching chunks to return.
|
||||
|
||||
Returns:
|
||||
Score-ranked (highest first) list of highest matching chunks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query_embedding: Embedding,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Runs semantic-only search and returns a list of inference chunks.
|
||||
|
||||
Args:
|
||||
query_embedding: Vector representation of the query. Must be of the
|
||||
correct dimensionality for the primary index.
|
||||
filters: Filters for things like permissions, source type, time,
|
||||
etc.
|
||||
num_to_retrieve: Number of highest matching chunks to return.
|
||||
|
||||
Returns:
|
||||
Score-ranked (highest first) list of highest matching chunks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomCapable(abc.ABC):
|
||||
"""
|
||||
|
||||
@@ -18,10 +18,13 @@ from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME
|
||||
from onyx.configs.app_configs import OPENSEARCH_HOST
|
||||
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import OpenSearchSearchType
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.search import DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW
|
||||
from onyx.server.metrics.opensearch_search import observe_opensearch_search
|
||||
from onyx.server.metrics.opensearch_search import track_opensearch_search_in_progress
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
@@ -256,7 +259,6 @@ class OpenSearchClient(AbstractContextManager):
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
@@ -304,6 +306,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
emit_metrics: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
host=host,
|
||||
@@ -315,6 +318,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
timeout=timeout,
|
||||
)
|
||||
self._index_name = index_name
|
||||
self._emit_metrics = emit_metrics
|
||||
logger.debug(
|
||||
f"OpenSearch client created successfully for index {self._index_name}."
|
||||
)
|
||||
@@ -834,7 +838,10 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search(
|
||||
self, body: dict[str, Any], search_pipeline_id: str | None
|
||||
self,
|
||||
body: dict[str, Any],
|
||||
search_pipeline_id: str | None,
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
|
||||
) -> list[SearchHit[DocumentChunkWithoutVectors]]:
|
||||
"""Searches the index.
|
||||
|
||||
@@ -852,6 +859,8 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
documentation for more information on search request bodies.
|
||||
search_pipeline_id: The ID of the search pipeline to use. If None,
|
||||
the default search pipeline will be used.
|
||||
search_type: Label for Prometheus metrics. Does not affect search
|
||||
behavior.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error searching the index.
|
||||
@@ -864,21 +873,27 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
)
|
||||
result: dict[str, Any]
|
||||
params = {"phase_took": "true"}
|
||||
if search_pipeline_id:
|
||||
result = self._client.search(
|
||||
index=self._index_name,
|
||||
search_pipeline=search_pipeline_id,
|
||||
body=body,
|
||||
params=params,
|
||||
)
|
||||
else:
|
||||
result = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
ctx = self._get_emit_metrics_context_manager(search_type)
|
||||
t0 = time.perf_counter()
|
||||
with ctx:
|
||||
if search_pipeline_id:
|
||||
result = self._client.search(
|
||||
index=self._index_name,
|
||||
search_pipeline=search_pipeline_id,
|
||||
body=body,
|
||||
params=params,
|
||||
)
|
||||
else:
|
||||
result = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
client_duration_s = time.perf_counter() - t0
|
||||
|
||||
hits, time_took, timed_out, phase_took, profile = (
|
||||
self._get_hits_and_profile_from_search_result(result)
|
||||
)
|
||||
if self._emit_metrics:
|
||||
observe_opensearch_search(search_type, client_duration_s, time_took)
|
||||
self._log_search_result_perf(
|
||||
time_took=time_took,
|
||||
timed_out=timed_out,
|
||||
@@ -914,7 +929,11 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
return search_hits
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search_for_document_ids(self, body: dict[str, Any]) -> list[str]:
|
||||
def search_for_document_ids(
|
||||
self,
|
||||
body: dict[str, Any],
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
|
||||
) -> list[str]:
|
||||
"""Searches the index and returns only document chunk IDs.
|
||||
|
||||
In order to take advantage of the performance benefits of only returning
|
||||
@@ -931,6 +950,8 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
documentation for more information on search request bodies.
|
||||
TODO(andrei): Make this a more deep interface; callers shouldn't
|
||||
need to know to set _source: False for example.
|
||||
search_type: Label for Prometheus metrics. Does not affect search
|
||||
behavior.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error searching the index.
|
||||
@@ -948,13 +969,19 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
)
|
||||
|
||||
params = {"phase_took": "true"}
|
||||
result: dict[str, Any] = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
ctx = self._get_emit_metrics_context_manager(search_type)
|
||||
t0 = time.perf_counter()
|
||||
with ctx:
|
||||
result: dict[str, Any] = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
client_duration_s = time.perf_counter() - t0
|
||||
|
||||
hits, time_took, timed_out, phase_took, profile = (
|
||||
self._get_hits_and_profile_from_search_result(result)
|
||||
)
|
||||
if self._emit_metrics:
|
||||
observe_opensearch_search(search_type, client_duration_s, time_took)
|
||||
self._log_search_result_perf(
|
||||
time_took=time_took,
|
||||
timed_out=timed_out,
|
||||
@@ -1071,6 +1098,20 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
if raise_on_timeout:
|
||||
raise RuntimeError(error_str)
|
||||
|
||||
def _get_emit_metrics_context_manager(
|
||||
self, search_type: OpenSearchSearchType
|
||||
) -> AbstractContextManager[None]:
|
||||
"""
|
||||
Returns a context manager that tracks in-flight OpenSearch searches via
|
||||
a Gauge if emit_metrics is True, otherwise returns a null context
|
||||
manager.
|
||||
"""
|
||||
return (
|
||||
track_opensearch_search_in_progress(search_type)
|
||||
if self._emit_metrics
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
|
||||
def wait_for_opensearch_with_timeout(
|
||||
wait_interval_s: int = 5,
|
||||
|
||||
@@ -53,6 +53,18 @@ DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
|
||||
EF_SEARCH = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
|
||||
|
||||
|
||||
class OpenSearchSearchType(str, Enum):
|
||||
"""Search type label used for Prometheus metrics."""
|
||||
|
||||
HYBRID = "hybrid"
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
RANDOM = "random"
|
||||
ID_RETRIEVAL = "id_retrieval"
|
||||
DOCUMENT_IDS = "document_ids"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class HybridSearchSubqueryConfiguration(Enum):
|
||||
TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 1
|
||||
# Current default.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
@@ -43,6 +43,7 @@ from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import SearchHit
|
||||
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
|
||||
from onyx.document_index.opensearch.constants import OpenSearchSearchType
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
|
||||
@@ -350,7 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -646,10 +647,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata, # noqa: ARG002
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
|
||||
Groups chunks by document ID and for each document, deletes existing
|
||||
chunks and indexes the new chunks in bulk.
|
||||
@@ -672,29 +673,34 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document is newly indexed or had already existed and was just
|
||||
updated.
|
||||
"""
|
||||
# Group chunks by document ID.
|
||||
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
|
||||
list
|
||||
total_chunks = sum(
|
||||
cc.new_chunk_cnt
|
||||
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
|
||||
)
|
||||
for chunk in chunks:
|
||||
doc_id_to_chunks[chunk.source_document.id].append(chunk)
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
|
||||
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
|
||||
f"documents for index {self._index_name}."
|
||||
)
|
||||
|
||||
document_indexing_results: list[DocumentInsertionRecord] = []
|
||||
# Try to index per-document.
|
||||
for _, chunks in doc_id_to_chunks.items():
|
||||
deleted_doc_ids: set[str] = set()
|
||||
# Buffer chunks per document as they arrive from the iterable.
|
||||
# When the document ID changes flush the buffered chunks.
|
||||
current_doc_id: str | None = None
|
||||
current_chunks: list[DocMetadataAwareIndexChunk] = []
|
||||
|
||||
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
|
||||
assert len(doc_chunks) > 0, "doc_chunks is empty"
|
||||
|
||||
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
|
||||
# Do this before deleting existing chunks to reduce the amount of
|
||||
# time the document index has no content for a given document, and
|
||||
# to reduce the chance of entering a state where we delete chunks,
|
||||
# then some error happens, and never successfully index new chunks.
|
||||
# Since we are doing this in batches, an error occurring midway
|
||||
# can result in a state where chunks are deleted and not all the
|
||||
# new chunks have been indexed.
|
||||
chunk_batch: list[DocumentChunk] = [
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk)
|
||||
for chunk in doc_chunks
|
||||
]
|
||||
onyx_document: Document = chunks[0].source_document
|
||||
onyx_document: Document = doc_chunks[0].source_document
|
||||
# First delete the doc's chunks from the index. This is so that
|
||||
# there are no dangling chunks in the index, in the event that the
|
||||
# new document's content contains fewer chunks than the previous
|
||||
@@ -703,22 +709,40 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# if the chunk count has actually decreased. This assumes that
|
||||
# overlapping chunks are perfectly overwritten. If we can't
|
||||
# guarantee that then we need the code as-is.
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed.
|
||||
document_insertion_record = DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
if onyx_document.id not in deleted_doc_ids:
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
deleted_doc_ids.add(onyx_document.id)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed. We record the result before bulk_index_documents
|
||||
# runs. If indexing raises, this entire result list is discarded
|
||||
# by the caller's retry logic, so early recording is safe.
|
||||
document_indexing_results.append(
|
||||
DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
document_indexing_results.append(document_insertion_record)
|
||||
|
||||
for chunk in chunks:
|
||||
doc_id = chunk.source_document.id
|
||||
if doc_id != current_doc_id:
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
current_doc_id = doc_id
|
||||
current_chunks = [chunk]
|
||||
else:
|
||||
current_chunks.append(chunk)
|
||||
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
|
||||
return document_indexing_results
|
||||
|
||||
@@ -900,6 +924,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.ID_RETRIEVAL,
|
||||
)
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
@@ -923,6 +948,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
# TODO(andrei): There is some duplicated logic in this function with
|
||||
# others in this file.
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
@@ -948,6 +975,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=normalization_pipeline_name,
|
||||
search_type=OpenSearchSearchType.HYBRID,
|
||||
)
|
||||
|
||||
# Good place for a breakpoint to inspect the search hits if you have
|
||||
@@ -970,6 +998,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
# TODO(andrei): There is some duplicated logic in this function with
|
||||
# others in this file.
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
@@ -989,6 +1019,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.KEYWORD,
|
||||
)
|
||||
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
@@ -1009,6 +1040,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
# TODO(andrei): There is some duplicated logic in this function with
|
||||
# others in this file.
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
@@ -1028,6 +1061,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.SEMANTIC,
|
||||
)
|
||||
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
@@ -1059,6 +1093,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.RANDOM,
|
||||
)
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
|
||||
@@ -3,6 +3,8 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import TypeAlias
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
|
||||
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
|
||||
@@ -48,13 +50,21 @@ from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
|
||||
# Normalization pipelines combine document scores from multiple query clauses.
|
||||
# The number and ordering of weights should match the query clauses. The values
|
||||
# of the weights should sum to 1.
|
||||
# See https://docs.opensearch.org/latest/query-dsl/term/terms/.
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY = 65_536
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
TermsQuery: TypeAlias = dict[str, dict[str, list[_T]]]
|
||||
TermQuery: TypeAlias = dict[str, dict[str, dict[str, _T]]]
|
||||
|
||||
|
||||
# TODO(andrei): Turn all magic dictionaries to pydantic models.
|
||||
|
||||
|
||||
# Normalization pipelines combine document scores from multiple query clauses.
|
||||
# The number and ordering of weights should match the query clauses. The values
|
||||
# of the weights should sum to 1.
|
||||
def _get_hybrid_search_normalization_weights() -> list[float]:
|
||||
if (
|
||||
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
|
||||
@@ -316,6 +326,9 @@ class DocumentQuery:
|
||||
it MUST be supplied in addition to a search pipeline. The results from
|
||||
hybrid search are not meaningful without that step.
|
||||
|
||||
TODO(andrei): There is some duplicated logic in this function with
|
||||
others in this file.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
query_vector: The vector embedding of the text to query for.
|
||||
@@ -419,6 +432,9 @@ class DocumentQuery:
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
TODO(andrei): There is some duplicated logic in this function with
|
||||
others in this file.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
@@ -498,6 +514,9 @@ class DocumentQuery:
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
TODO(andrei): There is some duplicated logic in this function with
|
||||
others in this file.
|
||||
|
||||
Args:
|
||||
query_embedding: The vector embedding of the text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
@@ -763,8 +782,9 @@ class DocumentQuery:
|
||||
TITLE_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
"operator": "or",
|
||||
# The title fields are strongly discounted as they are included in the content.
|
||||
# It just acts as a minor boost
|
||||
# The title fields are strongly discounted as
|
||||
# they are included in the content. This just
|
||||
# acts as a minor boost.
|
||||
"boost": 0.1,
|
||||
}
|
||||
}
|
||||
@@ -779,6 +799,9 @@ class DocumentQuery:
|
||||
}
|
||||
},
|
||||
{
|
||||
# Analyzes the query and returns results which match any
|
||||
# of the query's terms. More matches result in higher
|
||||
# scores.
|
||||
"match": {
|
||||
CONTENT_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
@@ -788,18 +811,21 @@ class DocumentQuery:
|
||||
}
|
||||
},
|
||||
{
|
||||
# Matches an exact phrase in a specified order.
|
||||
"match_phrase": {
|
||||
CONTENT_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
# The number of words permitted between words of
|
||||
# a query phrase and still result in a match.
|
||||
"slop": 1,
|
||||
"boost": 1.5,
|
||||
}
|
||||
}
|
||||
},
|
||||
],
|
||||
# Ensure at least one term from the query is present in the
|
||||
# document. This defaults to 1, unless a filter or must clause
|
||||
# is supplied, in which case it defaults to 0.
|
||||
# Ensures at least one match subquery from the query is present
|
||||
# in the document. This defaults to 1, unless a filter or must
|
||||
# clause is supplied, in which case it defaults to 0.
|
||||
"minimum_should_match": 1,
|
||||
}
|
||||
}
|
||||
@@ -833,7 +859,14 @@ class DocumentQuery:
|
||||
The "filter" key applies a logical AND operator to its elements, so
|
||||
every subfilter must evaluate to true in order for the document to be
|
||||
retrieved. This function returns a list of such subfilters.
|
||||
See https://docs.opensearch.org/latest/query-dsl/compound/bool/
|
||||
See https://docs.opensearch.org/latest/query-dsl/compound/bool/.
|
||||
|
||||
TODO(ENG-3874): The terms queries returned by this function can be made
|
||||
more performant for large cardinality sets by sorting the values by
|
||||
their UTF-8 byte order.
|
||||
|
||||
TODO(ENG-3875): This function can take even better advantage of filter
|
||||
caching by grouping "static" filters together into one sub-clause.
|
||||
|
||||
Args:
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
@@ -878,6 +911,14 @@ class DocumentQuery:
|
||||
the assistant. Matches chunks where ancestor_hierarchy_node_ids
|
||||
contains any of these values.
|
||||
|
||||
Raises:
|
||||
ValueError: document_id and attached_document_ids were supplied
|
||||
together. This is not allowed because they operate on the same
|
||||
schema field, and it does not semantically make sense to use
|
||||
them together.
|
||||
ValueError: Too many of one of the collection arguments was
|
||||
supplied.
|
||||
|
||||
Returns:
|
||||
A list of filters to be passed into the "filter" key of a search
|
||||
query.
|
||||
@@ -885,61 +926,156 @@ class DocumentQuery:
|
||||
|
||||
def _get_acl_visibility_filter(
|
||||
access_control_list: list[str],
|
||||
) -> dict[str, Any]:
|
||||
) -> dict[str, dict[str, list[TermQuery[bool] | TermsQuery[str]] | int]]:
|
||||
"""Returns a filter for the access control list.
|
||||
|
||||
Since this returns an isolated bool should clause, it can be cached
|
||||
in OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
access_control_list: The access control list to restrict
|
||||
documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of access control list entries is greater
|
||||
than MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
|
||||
Returns:
|
||||
A filter for the access control list.
|
||||
"""
|
||||
# Logical OR operator on its elements.
|
||||
acl_visibility_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
acl_visibility_filter["bool"]["should"].append(
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}}
|
||||
)
|
||||
for acl in access_control_list:
|
||||
acl_subclause: dict[str, Any] = {
|
||||
"term": {ACCESS_CONTROL_LIST_FIELD_NAME: {"value": acl}}
|
||||
acl_visibility_filter: dict[str, dict[str, Any]] = {
|
||||
"bool": {
|
||||
"should": [{"term": {PUBLIC_FIELD_NAME: {"value": True}}}],
|
||||
"minimum_should_match": 1,
|
||||
}
|
||||
}
|
||||
if access_control_list:
|
||||
if len(access_control_list) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many access control list entries: {len(access_control_list)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Use terms instead of a list of term within a should clause
|
||||
# because Lucene will optimize the filtering for large sets of
|
||||
# terms. Small sets of terms are not expected to perform any
|
||||
# differently than individual term clauses.
|
||||
acl_subclause: TermsQuery[str] = {
|
||||
"terms": {ACCESS_CONTROL_LIST_FIELD_NAME: list(access_control_list)}
|
||||
}
|
||||
acl_visibility_filter["bool"]["should"].append(acl_subclause)
|
||||
return acl_visibility_filter
|
||||
|
||||
def _get_source_type_filter(
|
||||
source_types: list[DocumentSource],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
source_type_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for source_type in source_types:
|
||||
source_type_filter["bool"]["should"].append(
|
||||
{"term": {SOURCE_TYPE_FIELD_NAME: {"value": source_type.value}}}
|
||||
) -> TermsQuery[str]:
|
||||
"""Returns a filter for the source types.
|
||||
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
source_types: The source types to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of source types is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the source types.
|
||||
"""
|
||||
if not source_types:
|
||||
raise ValueError(
|
||||
"source_types cannot be empty if trying to create a source type filter."
|
||||
)
|
||||
return source_type_filter
|
||||
|
||||
def _get_tag_filter(tags: list[Tag]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
tag_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for tag in tags:
|
||||
# Kind of an abstraction leak, see
|
||||
# convert_metadata_dict_to_list_of_strings for why metadata list
|
||||
# entries are expected to look this way.
|
||||
tag_str = f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}"
|
||||
tag_filter["bool"]["should"].append(
|
||||
{"term": {METADATA_LIST_FIELD_NAME: {"value": tag_str}}}
|
||||
if len(source_types) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many source types: {len(source_types)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
return tag_filter
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {
|
||||
"terms": {
|
||||
SOURCE_TYPE_FIELD_NAME: [
|
||||
source_type.value for source_type in source_types
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
def _get_document_set_filter(document_sets: list[str]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
document_set_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for document_set in document_sets:
|
||||
document_set_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_SETS_FIELD_NAME: {"value": document_set}}}
|
||||
def _get_tag_filter(tags: list[Tag]) -> TermsQuery[str]:
|
||||
"""Returns a filter for the tags.
|
||||
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
tags: The tags to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of tags is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the tags.
|
||||
"""
|
||||
if not tags:
|
||||
raise ValueError(
|
||||
"tags cannot be empty if trying to create a tag filter."
|
||||
)
|
||||
return document_set_filter
|
||||
if len(tags) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many tags: {len(tags)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Kind of an abstraction leak, see
|
||||
# convert_metadata_dict_to_list_of_strings for why metadata list
|
||||
# entries are expected to look this way.
|
||||
tag_str_list = [
|
||||
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in tags
|
||||
]
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {"terms": {METADATA_LIST_FIELD_NAME: tag_str_list}}
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
user_project_filter["bool"]["should"].append(
|
||||
{"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
|
||||
)
|
||||
return user_project_filter
|
||||
def _get_document_set_filter(document_sets: list[str]) -> TermsQuery[str]:
|
||||
"""Returns a filter for the document sets.
|
||||
|
||||
def _get_persona_filter(persona_id: int) -> dict[str, Any]:
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
document_sets: The document sets to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of document sets is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the document sets.
|
||||
"""
|
||||
if not document_sets:
|
||||
raise ValueError(
|
||||
"document_sets cannot be empty if trying to create a document set filter."
|
||||
)
|
||||
if len(document_sets) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many document sets: {len(document_sets)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {"terms": {DOCUMENT_SETS_FIELD_NAME: list(document_sets)}}
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> TermQuery[int]:
|
||||
return {"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
|
||||
|
||||
def _get_persona_filter(persona_id: int) -> TermQuery[int]:
|
||||
return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}}
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
@@ -947,7 +1083,9 @@ class DocumentQuery:
|
||||
# document data.
|
||||
time_cutoff = set_or_convert_timezone_to_utc(time_cutoff)
|
||||
# Logical OR operator on its elements.
|
||||
time_cutoff_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
time_cutoff_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"range": {
|
||||
@@ -982,25 +1120,77 @@ class DocumentQuery:
|
||||
|
||||
def _get_attached_document_id_filter(
|
||||
doc_ids: list[str],
|
||||
) -> dict[str, Any]:
|
||||
"""Filter for documents explicitly attached to an assistant."""
|
||||
# Logical OR operator on its elements.
|
||||
doc_id_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for doc_id in doc_ids:
|
||||
doc_id_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": doc_id}}}
|
||||
) -> TermsQuery[str]:
|
||||
"""
|
||||
Returns a filter for documents explicitly attached to an assistant.
|
||||
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
doc_ids: The document IDs to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of document IDs is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the document IDs.
|
||||
"""
|
||||
if not doc_ids:
|
||||
raise ValueError(
|
||||
"doc_ids cannot be empty if trying to create a document ID filter."
|
||||
)
|
||||
return doc_id_filter
|
||||
if len(doc_ids) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many document IDs: {len(doc_ids)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {"terms": {DOCUMENT_ID_FIELD_NAME: list(doc_ids)}}
|
||||
|
||||
def _get_hierarchy_node_filter(
|
||||
node_ids: list[int],
|
||||
) -> dict[str, Any]:
|
||||
"""Filter for chunks whose ancestors include any of the given hierarchy nodes.
|
||||
|
||||
Uses a terms query to check if ancestor_hierarchy_node_ids contains
|
||||
any of the specified node IDs.
|
||||
) -> TermsQuery[int]:
|
||||
"""
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
|
||||
Returns a filter for chunks whose ancestors include any of the given
|
||||
hierarchy nodes.
|
||||
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
node_ids: The hierarchy node IDs to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of hierarchy node IDs is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the hierarchy node IDs.
|
||||
"""
|
||||
if not node_ids:
|
||||
raise ValueError(
|
||||
"node_ids cannot be empty if trying to create a hierarchy node ID filter."
|
||||
)
|
||||
if len(node_ids) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many hierarchy node IDs: {len(node_ids)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: list(node_ids)}}
|
||||
|
||||
if document_id is not None and attached_document_ids is not None:
|
||||
raise ValueError(
|
||||
"document_id and attached_document_ids cannot be used together."
|
||||
)
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
@@ -1045,6 +1235,9 @@ class DocumentQuery:
|
||||
)
|
||||
|
||||
if has_knowledge_scope:
|
||||
# Since this returns an isolated bool should clause, it can be
|
||||
# cached in OpenSearch independently of other clauses in
|
||||
# _get_search_filters.
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
import time
|
||||
import urllib
|
||||
import zipfile
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@@ -461,7 +462,7 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
@@ -318,7 +320,7 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
|
||||
@@ -338,22 +340,31 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
# Vespa has restrictions on valid characters, yet document IDs come from
|
||||
# external w.r.t. this class. We need to sanitize them.
|
||||
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
|
||||
clean_chunk_id_copy(chunk) for chunk in chunks
|
||||
]
|
||||
assert len(cleaned_chunks) == len(
|
||||
chunks
|
||||
), "Bug: Cleaned chunks and input chunks have different lengths."
|
||||
#
|
||||
# Instead of materializing all cleaned chunks upfront, we stream them
|
||||
# through a generator that cleans IDs and builds the original-ID mapping
|
||||
# incrementally as chunks flow into Vespa.
|
||||
def _clean_and_track(
|
||||
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
|
||||
id_map: dict[str, str],
|
||||
seen_ids: set[str],
|
||||
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
|
||||
"""Cleans chunk IDs and builds the original-ID mapping
|
||||
incrementally as chunks flow through, avoiding a separate
|
||||
materialization pass."""
|
||||
for chunk in chunks_iter:
|
||||
original_id = chunk.source_document.id
|
||||
cleaned = clean_chunk_id_copy(chunk)
|
||||
cleaned_id = cleaned.source_document.id
|
||||
# Needed so the final DocumentInsertionRecord returned can have
|
||||
# the original document ID. cleaned_chunks might not contain IDs
|
||||
# exactly as callers supplied them.
|
||||
id_map[cleaned_id] = original_id
|
||||
seen_ids.add(cleaned_id)
|
||||
yield cleaned
|
||||
|
||||
# Needed so the final DocumentInsertionRecord returned can have the
|
||||
# original document ID. cleaned_chunks might not contain IDs exactly as
|
||||
# callers supplied them.
|
||||
new_document_id_to_original_document_id: dict[str, str] = dict()
|
||||
for i, cleaned_chunk in enumerate(cleaned_chunks):
|
||||
old_chunk = chunks[i]
|
||||
new_document_id_to_original_document_id[
|
||||
cleaned_chunk.source_document.id
|
||||
] = old_chunk.source_document.id
|
||||
new_document_id_to_original_document_id: dict[str, str] = {}
|
||||
all_cleaned_doc_ids: set[str] = set()
|
||||
|
||||
existing_docs: set[str] = set()
|
||||
|
||||
@@ -409,7 +420,13 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
# Insert new Vespa documents.
|
||||
# Insert new Vespa documents, streaming through the cleaning
|
||||
# pipeline so chunks are never fully materialized.
|
||||
cleaned_chunks = _clean_and_track(
|
||||
chunks,
|
||||
new_document_id_to_original_document_id,
|
||||
all_cleaned_doc_ids,
|
||||
)
|
||||
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
|
||||
batch_index_vespa_chunks(
|
||||
chunks=chunk_batch,
|
||||
@@ -419,10 +436,6 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
all_cleaned_doc_ids: set[str] = {
|
||||
chunk.source_document.id for chunk in cleaned_chunks
|
||||
}
|
||||
|
||||
return [
|
||||
DocumentInsertionRecord(
|
||||
document_id=new_document_id_to_original_document_id[cleaned_doc_id],
|
||||
@@ -610,6 +623,22 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
return cleanup_content_for_chunks(query_vespa(params))
|
||||
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
raise NotImplementedError
|
||||
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query_embedding: Embedding,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
raise NotImplementedError
|
||||
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters,
|
||||
|
||||
@@ -15,7 +15,7 @@ Usage (Celery tasks and FastAPI handlers):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is a validated Pydantic model instance (spec.response_model)
|
||||
# result is a validated Pydantic model instance (response_type)
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
|
||||
@@ -91,6 +91,8 @@ class HookResponse(BaseModel):
|
||||
# Nullable to match the DB column — endpoint_url is required on creation but
|
||||
# future hook point types may not use an external endpoint (e.g. built-in handlers).
|
||||
endpoint_url: str | None
|
||||
# Partially-masked API key (e.g. "abcd••••••••wxyz"), or None if no key is set.
|
||||
api_key_masked: str | None
|
||||
fail_strategy: HookFailStrategy
|
||||
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
|
||||
is_active: bool
|
||||
|
||||
@@ -26,6 +26,8 @@ class DocumentIngestionSpec(HookPointSpec):
|
||||
default_timeout_seconds = 30.0
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
# TODO(Bo-Onyx): update later
|
||||
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.ue263ual5vdi"
|
||||
|
||||
payload_model = DocumentIngestionPayload
|
||||
response_model = DocumentIngestionResponse
|
||||
|
||||
@@ -65,6 +65,8 @@ class QueryProcessingSpec(HookPointSpec):
|
||||
"The query will be blocked and the user will see an error message."
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
# TODO(Bo-Onyx): update later
|
||||
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.g2r1a1699u87"
|
||||
|
||||
payload_model = QueryProcessingPayload
|
||||
response_model = QueryProcessingResponse
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -173,8 +174,10 @@ class UserFileIndexingAdapter:
|
||||
[chunk.content for chunk in user_file_chunks]
|
||||
)
|
||||
user_file_id_to_raw_text[str(user_file_id)] = combined_content
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
|
||||
token_count: int = (
|
||||
count_tokens(combined_content, llm_tokenizer)
|
||||
if llm_tokenizer
|
||||
else 0
|
||||
)
|
||||
user_file_id_to_token_count[str(user_file_id)] = token_count
|
||||
else:
|
||||
|
||||
@@ -25,6 +25,7 @@ class LlmProviderNames(str, Enum):
|
||||
LM_STUDIO = "lm_studio"
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
BIFROST = "bifrost"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Needed so things like:
|
||||
@@ -44,6 +45,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
]
|
||||
|
||||
|
||||
@@ -61,6 +63,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: "Ollama",
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
LlmProviderNames.BIFROST: "Bifrost",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -112,6 +115,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.VERTEX_AI,
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -290,6 +290,17 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
|
||||
|
||||
# Bifrost: OpenAI-compatible proxy that expects model names in
|
||||
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
|
||||
# We route through LiteLLM's openai provider with the Bifrost base URL,
|
||||
# and ensure /v1 is appended.
|
||||
if model_provider == LlmProviderNames.BIFROST:
|
||||
self._custom_llm_provider = "openai"
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
model_kwargs["api_base"] = self._api_base
|
||||
|
||||
# This is needed for Ollama to do proper function calling
|
||||
if model_provider == LlmProviderNames.OLLAMA_CHAT and api_base is not None:
|
||||
model_kwargs["api_base"] = api_base
|
||||
@@ -401,14 +412,20 @@ class LitellmLLM(LLM):
|
||||
optional_kwargs: dict[str, Any] = {}
|
||||
|
||||
# Model name
|
||||
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
|
||||
model_provider = (
|
||||
f"{self.config.model_provider}/responses"
|
||||
if is_openai_model # Uses litellm's completions -> responses bridge
|
||||
else self.config.model_provider
|
||||
)
|
||||
model = (
|
||||
f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
)
|
||||
if is_bifrost:
|
||||
# Bifrost expects model names in provider/model format
|
||||
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
|
||||
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
|
||||
# so LiteLLM doesn't try to route based on the provider prefix.
|
||||
model = self.config.deployment_name or self.config.model_name
|
||||
else:
|
||||
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
|
||||
# Tool choice
|
||||
if is_claude_model and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
@@ -483,10 +500,11 @@ class LitellmLLM(LLM):
|
||||
if structured_response_format:
|
||||
optional_kwargs["response_format"] = structured_response_format
|
||||
|
||||
if not (is_claude_model or is_ollama or is_mistral):
|
||||
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
|
||||
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
|
||||
# However, this param breaks Anthropic and Mistral models,
|
||||
# so it must be conditionally included.
|
||||
# so it must be conditionally included unless the request is
|
||||
# routed through Bifrost's OpenAI-compatible endpoint.
|
||||
# Additionally, tool_choice is not supported by Ollama and causes warnings if included.
|
||||
# See also, https://github.com/ollama/ollama/issues/11171
|
||||
optional_kwargs["allowed_openai_params"] = ["tool_choice"]
|
||||
|
||||
@@ -11,6 +11,7 @@ class LLMOverride(BaseModel):
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
display_name: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
@@ -13,6 +13,8 @@ LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
|
||||
|
||||
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
|
||||
|
||||
BIFROST_PROVIDER_NAME = "bifrost"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.llm.well_known_providers.auto_update_service import (
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
@@ -49,6 +50,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
|
||||
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -175,6 +175,32 @@ def get_tokenizer(
|
||||
return _check_tokenizer_cache(provider_type, model_name)
|
||||
|
||||
|
||||
# Max characters per encode() call.
|
||||
_ENCODE_CHUNK_SIZE = 500_000
|
||||
|
||||
|
||||
def count_tokens(
|
||||
text: str,
|
||||
tokenizer: BaseTokenizer,
|
||||
token_limit: int | None = None,
|
||||
) -> int:
|
||||
"""Count tokens, chunking the input to avoid tiktoken stack overflow.
|
||||
|
||||
If token_limit is provided and the text is large enough to require
|
||||
multiple chunks (> 500k chars), stops early once the count exceeds it.
|
||||
When early-exiting, the returned value exceeds token_limit but may be
|
||||
less than the true full token count.
|
||||
"""
|
||||
if len(text) <= _ENCODE_CHUNK_SIZE:
|
||||
return len(tokenizer.encode(text))
|
||||
total = 0
|
||||
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
|
||||
total += len(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
|
||||
if token_limit is not None and total > token_limit:
|
||||
return total # Already over — skip remaining chunks
|
||||
return total
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: BaseTokenizer
|
||||
) -> str:
|
||||
|
||||
@@ -690,9 +690,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@dotenvx/dotenvx/node_modules/picomatch": {
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
@@ -3844,9 +3844,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@ts-morph/common/node_modules/brace-expansion": {
|
||||
"version": "5.0.3",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.3.tgz",
|
||||
"integrity": "sha512-fy6KJm2RawA5RcHkLa1z/ScpBeA762UF9KmZQxwIbDtRJrgLzM10depAiEQ+CXYcoiqW1/m96OAAoke2nE9EeA==",
|
||||
"version": "5.0.5",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.5.tgz",
|
||||
"integrity": "sha512-VZznLgtwhn+Mact9tfiwx64fA9erHH/MCXEUfB/0bX/6Fz6ny5EGTXYltMocqg4xFAQZtnO3DHWWXi8RiuN7cQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"balanced-match": "^4.0.2"
|
||||
@@ -4224,9 +4224,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": {
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
|
||||
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
|
||||
"version": "2.0.3",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz",
|
||||
"integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -5007,9 +5007,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/brace-expansion": {
|
||||
"version": "1.1.12",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
|
||||
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
|
||||
"version": "1.1.13",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz",
|
||||
"integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -9537,9 +9537,9 @@
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "2.3.1",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz",
|
||||
"integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==",
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=8.6"
|
||||
@@ -11118,9 +11118,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/tinyglobby/node_modules/picomatch": {
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
||||
@@ -44,11 +44,12 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
|
||||
"""Raise OnyxError if endpoint_url could be used for SSRF.
|
||||
|
||||
Delegates to validate_outbound_http_url with https_only=True.
|
||||
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
|
||||
"""
|
||||
try:
|
||||
validate_outbound_http_url(endpoint_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -62,6 +63,9 @@ def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookRespo
|
||||
name=hook.name,
|
||||
hook_point=hook.hook_point,
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key_masked=(
|
||||
hook.api_key.get_value(apply_mask=True) if hook.api_key else None
|
||||
),
|
||||
fail_strategy=hook.fail_strategy,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
is_active=hook.is_active,
|
||||
@@ -119,9 +123,8 @@ def _validate_endpoint(
|
||||
(not reachable — indicates the api_key is invalid).
|
||||
|
||||
Timeout handling:
|
||||
- ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
|
||||
(operator should consider increasing timeout_seconds).
|
||||
- Any httpx.TimeoutException (ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout) →
|
||||
timeout (operator should consider increasing timeout_seconds).
|
||||
- All other exceptions → cannot_connect.
|
||||
"""
|
||||
_check_ssrf_safety(endpoint_url)
|
||||
@@ -138,19 +141,11 @@ def _validate_endpoint(
|
||||
)
|
||||
return HookValidateResponse(status=HookValidateStatus.passed)
|
||||
except httpx.TimeoutException as exc:
|
||||
# ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
|
||||
if isinstance(exc, httpx.ConnectTimeout):
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connect timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
# Any timeout (connect, read, or write) means the configured timeout_seconds
|
||||
# is too low for this endpoint. Report as timeout so the UI directs the user
|
||||
# to increase the timeout setting.
|
||||
logger.warning(
|
||||
"Hook endpoint validation: read/write timeout for %s",
|
||||
"Hook endpoint validation: timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
@@ -220,8 +215,8 @@ def create_hook(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
"""Create a new hook. The endpoint is validated before persisting — creation fails if
|
||||
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
|
||||
use POST /{hook_id}/activate once ready to receive traffic."""
|
||||
the endpoint cannot be reached or the api_key is invalid. Hooks are created active.
|
||||
"""
|
||||
spec = get_hook_point_spec(req.hook_point)
|
||||
api_key = req.api_key.get_secret_value() if req.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
@@ -240,9 +235,10 @@ def create_hook(
|
||||
api_key=api_key,
|
||||
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
|
||||
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
|
||||
is_active=True,
|
||||
is_reachable=True,
|
||||
creator_id=user.id,
|
||||
)
|
||||
hook.is_reachable = True
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook, creator_email=user.email)
|
||||
|
||||
|
||||
@@ -9,20 +9,15 @@ from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import SKIP_USERFILE_THRESHOLD
|
||||
from shared_configs.configs import SKIP_USERFILE_THRESHOLD_TENANT_LIST
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -161,8 +156,8 @@ def categorize_uploaded_files(
|
||||
document formats (.pdf, .docx, …) and falls back to a text-detection
|
||||
heuristic for unknown extensions (.py, .js, .rs, …).
|
||||
- Uses default tokenizer to compute token length.
|
||||
- If token length > threshold, reject file (unless threshold skip is enabled).
|
||||
- If text cannot be extracted, reject file.
|
||||
- If token length exceeds the admin-configured threshold, reject file.
|
||||
- If extension unsupported or text cannot be extracted, reject file.
|
||||
- Otherwise marked as acceptable.
|
||||
"""
|
||||
|
||||
@@ -173,36 +168,33 @@ def categorize_uploaded_files(
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
# Check global skip flag (works for both single-tenant and multi-tenant)
|
||||
if SKIP_USERFILE_THRESHOLD:
|
||||
skip_threshold = True
|
||||
logger.info("Skipping userfile threshold check (global setting)")
|
||||
# Check tenant-specific skip list (only applicable in multi-tenant)
|
||||
elif MULTI_TENANT and SKIP_USERFILE_THRESHOLD_TENANT_LIST:
|
||||
try:
|
||||
current_tenant_id = get_current_tenant_id()
|
||||
skip_threshold = current_tenant_id in SKIP_USERFILE_THRESHOLD_TENANT_LIST
|
||||
if skip_threshold:
|
||||
logger.info(
|
||||
f"Skipping userfile threshold check for tenant: {current_tenant_id}"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to get current tenant ID: {str(e)}")
|
||||
# Derive limits from admin-configurable settings.
|
||||
# For upload size: load_settings() resolves 0/None to a positive default.
|
||||
# For token threshold: 0 means "no limit" (converted to None below).
|
||||
settings = load_settings()
|
||||
max_upload_size_mb = (
|
||||
settings.user_file_max_upload_size_mb
|
||||
) # always positive after load_settings()
|
||||
max_upload_size_bytes = (
|
||||
max_upload_size_mb * 1024 * 1024 if max_upload_size_mb else None
|
||||
)
|
||||
token_threshold_k = settings.file_token_count_threshold_k
|
||||
token_threshold = (
|
||||
token_threshold_k * 1000 if token_threshold_k else None
|
||||
) # 0 → None = no limit
|
||||
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
|
||||
# Size limit is a hard safety cap and is enforced even when token
|
||||
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
|
||||
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
|
||||
# Size limit is a hard safety cap.
|
||||
if max_upload_size_bytes is not None and is_upload_too_large(
|
||||
upload, max_upload_size_bytes
|
||||
):
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
|
||||
reason=f"Exceeds {max_upload_size_mb} MB file size limit",
|
||||
)
|
||||
)
|
||||
continue
|
||||
@@ -224,11 +216,11 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
|
||||
if token_threshold is not None and token_count > token_threshold:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
|
||||
reason=f"Exceeds {token_threshold_k}K token limit",
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -269,12 +261,14 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
token_count = len(tokenizer.encode(text_content))
|
||||
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
|
||||
token_count = count_tokens(
|
||||
text_content, tokenizer, token_limit=token_threshold
|
||||
)
|
||||
if token_threshold is not None and token_count > token_threshold:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
|
||||
reason=f"Exceeds {token_threshold_k}K token limit",
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -57,6 +57,8 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import BifrostFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BifrostModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelDetails
|
||||
@@ -1422,11 +1424,26 @@ def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
return _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="LiteLLM proxy",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
def _get_openai_compatible_models_response(
|
||||
url: str,
|
||||
source_name: str,
|
||||
api_key: str | None = None,
|
||||
) -> dict:
|
||||
"""Fetch model metadata from an OpenAI-compatible `/models` endpoint."""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
if not api_key:
|
||||
headers.pop("Authorization")
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
@@ -1436,20 +1453,125 @@ def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
if e.response.status_code == 401:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
|
||||
f"Authentication failed: invalid or missing API key for {source_name}.",
|
||||
)
|
||||
elif e.response.status_code == 404:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"LiteLLM models endpoint not found at {url}. Please verify the API base URL.",
|
||||
f"{source_name} models endpoint not found at {url}. Please verify the API base URL.",
|
||||
)
|
||||
else:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
f"Failed to fetch {source_name} models: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
except httpx.RequestError as e:
|
||||
logger.warning(
|
||||
"Failed to fetch models from OpenAI-compatible endpoint",
|
||||
extra={"source": source_name, "url": url, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
f"Failed to fetch {source_name} models: {e}",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Received invalid model response from OpenAI-compatible endpoint",
|
||||
extra={"source": source_name, "url": url, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch {source_name} models: {e}",
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/bifrost/available-models")
|
||||
def get_bifrost_available_models(
|
||||
request: BifrostModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BifrostFinalModelResponse]:
|
||||
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
|
||||
response_json = _get_bifrost_models_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your Bifrost endpoint",
|
||||
)
|
||||
|
||||
results: list[BifrostFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_id = model.get("id", "")
|
||||
model_name = model.get("name", model_id)
|
||||
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
# Skip embedding models
|
||||
if is_embedding_model(model_id):
|
||||
continue
|
||||
|
||||
results.append(
|
||||
BifrostFinalModelResponse(
|
||||
name=model_id,
|
||||
display_name=model_name,
|
||||
max_input_tokens=model.get("context_length"),
|
||||
supports_image_input=infer_vision_support(model_id),
|
||||
supports_reasoning=is_reasoning_model(model_id, model_name),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse Bifrost model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from Bifrost",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="Bifrost",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> dict:
|
||||
"""Perform GET to Bifrost /v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
# Ensure we hit /v1/models
|
||||
if cleaned_api_base.endswith("/v1"):
|
||||
url = f"{cleaned_api_base}/models"
|
||||
else:
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
return _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="Bifrost",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
@@ -449,3 +449,18 @@ class LitellmModelDetails(BaseModel):
|
||||
class LitellmFinalModelResponse(BaseModel):
|
||||
provider_name: str # Provider name (e.g. "openai")
|
||||
model_name: str # Model ID (e.g. "gpt-4o")
|
||||
|
||||
|
||||
# Bifrost dynamic models fetch
|
||||
class BifrostModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
api_key: str | None = None
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class BifrostFinalModelResponse(BaseModel):
|
||||
name: str # Model ID in provider/model format (e.g. "anthropic/claude-sonnet-4-6")
|
||||
display_name: str # Human-readable name from Bifrost API
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
@@ -25,6 +25,7 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.BIFROST,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -50,6 +51,25 @@ BEDROCK_VISION_MODELS = frozenset(
|
||||
}
|
||||
)
|
||||
|
||||
# Known Bifrost/OpenAI-compatible vision-capable model families where the
|
||||
# source API does not expose this metadata directly.
|
||||
BIFROST_VISION_MODEL_FAMILIES = frozenset(
|
||||
{
|
||||
"anthropic/claude-3",
|
||||
"anthropic/claude-4",
|
||||
"amazon/nova-pro",
|
||||
"amazon/nova-lite",
|
||||
"amazon/nova-premier",
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4.1",
|
||||
"google/gemini",
|
||||
"meta-llama/llama-3.2",
|
||||
"mistral/pixtral",
|
||||
"qwen/qwen2.5-vl",
|
||||
"qwen/qwen-vl",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def is_valid_bedrock_model(
|
||||
model_id: str,
|
||||
@@ -76,11 +96,18 @@ def is_valid_bedrock_model(
|
||||
def infer_vision_support(model_id: str) -> bool:
|
||||
"""Infer vision support from model ID when base model metadata unavailable.
|
||||
|
||||
Used for cross-region inference profiles when the base model isn't
|
||||
available in the user's region.
|
||||
Used for providers like Bedrock and Bifrost where vision support may
|
||||
need to be inferred from vendor/model naming conventions.
|
||||
"""
|
||||
model_id_lower = model_id.lower()
|
||||
return any(vision_model in model_id_lower for vision_model in BEDROCK_VISION_MODELS)
|
||||
if any(vision_model in model_id_lower for vision_model in BEDROCK_VISION_MODELS):
|
||||
return True
|
||||
|
||||
normalized_model_id = model_id_lower.replace(".", "/")
|
||||
return any(
|
||||
vision_model in normalized_model_id
|
||||
for vision_model in BIFROST_VISION_MODEL_FAMILIES
|
||||
)
|
||||
|
||||
|
||||
def generate_bedrock_display_name(model_id: str) -> str:
|
||||
@@ -322,7 +349,7 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
|
||||
- Ollama: "llama3:70b" → "Meta"
|
||||
- Ollama: "qwen2.5:7b" → "Alibaba"
|
||||
"""
|
||||
if provider == LlmProviderNames.OPENROUTER:
|
||||
if provider in (LlmProviderNames.OPENROUTER, LlmProviderNames.BIFROST):
|
||||
# Format: "vendor/model-name" e.g., "anthropic/claude-3-5-sonnet"
|
||||
if "/" in model_name:
|
||||
vendor_key = model_name.split("/")[0].lower()
|
||||
|
||||
207
backend/onyx/server/metrics/celery_task_metrics.py
Normal file
207
backend/onyx/server/metrics/celery_task_metrics.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Generic Celery task lifecycle Prometheus metrics.
|
||||
|
||||
Provides signal handlers that track task started/completed/failed counts,
|
||||
active task gauge, task duration histograms, and retry/reject/revoke counts.
|
||||
These fire for ALL tasks on the worker — no per-connector enrichment
|
||||
(see indexing_task_metrics.py for that).
|
||||
|
||||
Usage in a worker app module:
|
||||
from onyx.server.metrics.celery_task_metrics import (
|
||||
on_celery_task_prerun,
|
||||
on_celery_task_postrun,
|
||||
on_celery_task_retry,
|
||||
on_celery_task_revoked,
|
||||
on_celery_task_rejected,
|
||||
)
|
||||
# Call from the worker's existing signal handlers
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from celery import Task
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
TASK_STARTED = Counter(
|
||||
"onyx_celery_task_started_total",
|
||||
"Total Celery tasks started",
|
||||
["task_name", "queue"],
|
||||
)
|
||||
|
||||
TASK_COMPLETED = Counter(
|
||||
"onyx_celery_task_completed_total",
|
||||
"Total Celery tasks completed",
|
||||
["task_name", "queue", "outcome"],
|
||||
)
|
||||
|
||||
TASK_DURATION = Histogram(
|
||||
"onyx_celery_task_duration_seconds",
|
||||
"Celery task execution duration in seconds",
|
||||
["task_name", "queue"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
)
|
||||
|
||||
TASKS_ACTIVE = Gauge(
|
||||
"onyx_celery_tasks_active",
|
||||
"Currently executing Celery tasks",
|
||||
["task_name", "queue"],
|
||||
)
|
||||
|
||||
TASK_RETRIED = Counter(
|
||||
"onyx_celery_task_retried_total",
|
||||
"Total Celery tasks retried",
|
||||
["task_name", "queue"],
|
||||
)
|
||||
|
||||
TASK_REVOKED = Counter(
|
||||
"onyx_celery_task_revoked_total",
|
||||
"Total Celery tasks revoked (cancelled)",
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
TASK_REJECTED = Counter(
|
||||
"onyx_celery_task_rejected_total",
|
||||
"Total Celery tasks rejected by worker",
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
# task_id → (monotonic start time, metric labels)
|
||||
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
|
||||
|
||||
# Lock protecting _task_start_times — prerun, postrun, and eviction may
|
||||
# run concurrently on thread-pool workers.
|
||||
_task_start_times_lock = threading.Lock()
|
||||
|
||||
# Entries older than this are evicted on each prerun to prevent unbounded
|
||||
# growth when tasks are killed (SIGTERM, OOM) and postrun never fires.
|
||||
_MAX_START_TIME_AGE_SECONDS = 3600 # 1 hour
|
||||
|
||||
|
||||
def _evict_stale_start_times() -> None:
|
||||
"""Remove _task_start_times entries older than _MAX_START_TIME_AGE_SECONDS.
|
||||
|
||||
Must be called while holding _task_start_times_lock.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
stale_ids = [
|
||||
tid
|
||||
for tid, (start, _labels) in _task_start_times.items()
|
||||
if now - start > _MAX_START_TIME_AGE_SECONDS
|
||||
]
|
||||
for tid in stale_ids:
|
||||
entry = _task_start_times.pop(tid, None)
|
||||
if entry is not None:
|
||||
_labels = entry[1]
|
||||
# Decrement active gauge for evicted tasks — these tasks were
|
||||
# started but never completed (killed, OOM, etc.).
|
||||
active_gauge = TASKS_ACTIVE.labels(**_labels)
|
||||
if active_gauge._value.get() > 0:
|
||||
active_gauge.dec()
|
||||
|
||||
|
||||
def _get_task_labels(task: Task) -> dict[str, str]:
|
||||
"""Extract task_name and queue labels from a Celery Task instance."""
|
||||
task_name = task.name or "unknown"
|
||||
queue = "unknown"
|
||||
try:
|
||||
delivery_info = task.request.delivery_info
|
||||
if delivery_info:
|
||||
queue = delivery_info.get("routing_key") or "unknown"
|
||||
except AttributeError:
|
||||
pass
|
||||
return {"task_name": task_name, "queue": queue}
|
||||
|
||||
|
||||
def on_celery_task_prerun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
) -> None:
|
||||
"""Record task start. Call from the worker's task_prerun signal handler."""
|
||||
if task is None or task_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
labels = _get_task_labels(task)
|
||||
TASK_STARTED.labels(**labels).inc()
|
||||
TASKS_ACTIVE.labels(**labels).inc()
|
||||
with _task_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_task_start_times[task_id] = (time.monotonic(), labels)
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_postrun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
state: str | None,
|
||||
) -> None:
|
||||
"""Record task completion. Call from the worker's task_postrun signal handler."""
|
||||
if task is None or task_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
labels = _get_task_labels(task)
|
||||
outcome = "success" if state == "SUCCESS" else "failure"
|
||||
TASK_COMPLETED.labels(**labels, outcome=outcome).inc()
|
||||
|
||||
# Guard against going below 0 if postrun fires without a matching
|
||||
# prerun (e.g. after a worker restart or stale entry eviction).
|
||||
active_gauge = TASKS_ACTIVE.labels(**labels)
|
||||
if active_gauge._value.get() > 0:
|
||||
active_gauge.dec()
|
||||
|
||||
with _task_start_times_lock:
|
||||
entry = _task_start_times.pop(task_id, None)
|
||||
if entry is not None:
|
||||
start_time, _stored_labels = entry
|
||||
TASK_DURATION.labels(**labels).observe(time.monotonic() - start_time)
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task postrun metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_retry(
|
||||
_task_id: str | None,
|
||||
task: Task | None,
|
||||
) -> None:
|
||||
"""Record task retry. Call from the worker's task_retry signal handler."""
|
||||
if task is None:
|
||||
return
|
||||
try:
|
||||
labels = _get_task_labels(task)
|
||||
TASK_RETRIED.labels(**labels).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task retry metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_revoked(
|
||||
_task_id: str | None,
|
||||
task_name: str | None = None,
|
||||
) -> None:
|
||||
"""Record task revocation. The revoked signal doesn't provide a Task
|
||||
instance, only the task name via sender."""
|
||||
if task_name is None:
|
||||
return
|
||||
try:
|
||||
TASK_REVOKED.labels(task_name=task_name).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task revoked metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_rejected(
|
||||
_task_id: str | None,
|
||||
task_name: str | None = None,
|
||||
) -> None:
|
||||
"""Record task rejection."""
|
||||
if task_name is None:
|
||||
return
|
||||
try:
|
||||
TASK_REJECTED.labels(task_name=task_name).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task rejected metrics", exc_info=True)
|
||||
593
backend/onyx/server/metrics/indexing_pipeline.py
Normal file
593
backend/onyx/server/metrics/indexing_pipeline.py
Normal file
@@ -0,0 +1,593 @@
|
||||
"""Prometheus collectors for Celery queue depths and indexing pipeline state.
|
||||
|
||||
These collectors query Redis and Postgres at scrape time (the Collector pattern),
|
||||
so metrics are always fresh when Prometheus scrapes /metrics. They run inside the
|
||||
monitoring celery worker which already has Redis and DB access.
|
||||
|
||||
To avoid hammering Redis/Postgres on every 15s scrape, results are cached with
|
||||
a configurable TTL (default 30s). This means metrics may be up to TTL seconds
|
||||
stale, which is fine for monitoring dashboards.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from prometheus_client.core import GaugeMetricFamily
|
||||
from prometheus_client.registry import Collector
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Default cache TTL in seconds. Scrapes hitting within this window return
|
||||
# the previous result without re-querying Redis/Postgres.
|
||||
_DEFAULT_CACHE_TTL = 30.0
|
||||
|
||||
_QUEUE_LABEL_MAP: dict[str, str] = {
|
||||
OnyxCeleryQueues.PRIMARY: "primary",
|
||||
OnyxCeleryQueues.DOCPROCESSING: "docprocessing",
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING: "docfetching",
|
||||
OnyxCeleryQueues.VESPA_METADATA_SYNC: "vespa_metadata_sync",
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION: "connector_deletion",
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING: "connector_pruning",
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC: "permissions_sync",
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC: "external_group_sync",
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT: "permissions_upsert",
|
||||
OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING: "hierarchy_fetching",
|
||||
OnyxCeleryQueues.LLM_MODEL_UPDATE: "llm_model_update",
|
||||
OnyxCeleryQueues.CHECKPOINT_CLEANUP: "checkpoint_cleanup",
|
||||
OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP: "index_attempt_cleanup",
|
||||
OnyxCeleryQueues.CSV_GENERATION: "csv_generation",
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING: "user_file_processing",
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC: "user_file_project_sync",
|
||||
OnyxCeleryQueues.USER_FILE_DELETE: "user_file_delete",
|
||||
OnyxCeleryQueues.MONITORING: "monitoring",
|
||||
OnyxCeleryQueues.SANDBOX: "sandbox",
|
||||
OnyxCeleryQueues.OPENSEARCH_MIGRATION: "opensearch_migration",
|
||||
}
|
||||
|
||||
# Queues where prefetched (unacked) task counts are meaningful
|
||||
_UNACKED_QUEUES: list[str] = [
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
OnyxCeleryQueues.DOCPROCESSING,
|
||||
]
|
||||
|
||||
|
||||
class _CachedCollector(Collector):
|
||||
"""Base collector with TTL-based caching.
|
||||
|
||||
Subclasses implement ``_collect_fresh()`` to query the actual data source.
|
||||
The base ``collect()`` returns cached results if the TTL hasn't expired,
|
||||
avoiding repeated queries when Prometheus scrapes frequently.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
self._cache_ttl = cache_ttl
|
||||
self._cached_result: list[GaugeMetricFamily] | None = None
|
||||
self._last_collect_time: float = 0.0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def collect(self) -> list[GaugeMetricFamily]:
|
||||
with self._lock:
|
||||
now = time.monotonic()
|
||||
if (
|
||||
now - self._last_collect_time < self._cache_ttl
|
||||
and self._cached_result is not None
|
||||
):
|
||||
return self._cached_result
|
||||
|
||||
try:
|
||||
result = self._collect_fresh()
|
||||
self._cached_result = result
|
||||
self._last_collect_time = now
|
||||
return result
|
||||
except Exception:
|
||||
logger.exception(f"Error in {type(self).__name__}.collect()")
|
||||
# Return stale cache on error rather than nothing — avoids
|
||||
# metrics disappearing during transient failures.
|
||||
return self._cached_result if self._cached_result is not None else []
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
raise NotImplementedError
|
||||
|
||||
def describe(self) -> list[GaugeMetricFamily]:
|
||||
return []
|
||||
|
||||
|
||||
class QueueDepthCollector(_CachedCollector):
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._celery_app: Any | None = None
|
||||
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
depth = GaugeMetricFamily(
|
||||
"onyx_queue_depth",
|
||||
"Number of tasks waiting in Celery queue",
|
||||
labels=["queue"],
|
||||
)
|
||||
unacked = GaugeMetricFamily(
|
||||
"onyx_queue_unacked",
|
||||
"Number of prefetched (unacked) tasks for queue",
|
||||
labels=["queue"],
|
||||
)
|
||||
queue_age = GaugeMetricFamily(
|
||||
"onyx_queue_oldest_task_age_seconds",
|
||||
"Age of the oldest task in the queue (seconds since enqueue)",
|
||||
labels=["queue"],
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
|
||||
for queue_name, label in _QUEUE_LABEL_MAP.items():
|
||||
length = celery_get_queue_length(queue_name, redis_client)
|
||||
depth.add_metric([label], length)
|
||||
|
||||
# Peek at the oldest message to get its age
|
||||
if length > 0:
|
||||
age = self._get_oldest_message_age(redis_client, queue_name, now)
|
||||
if age is not None:
|
||||
queue_age.add_metric([label], age)
|
||||
|
||||
for queue_name in _UNACKED_QUEUES:
|
||||
label = _QUEUE_LABEL_MAP[queue_name]
|
||||
task_ids = celery_get_unacked_task_ids(queue_name, redis_client)
|
||||
unacked.add_metric([label], len(task_ids))
|
||||
|
||||
return [depth, unacked, queue_age]
|
||||
|
||||
@staticmethod
|
||||
def _get_oldest_message_age(
|
||||
redis_client: Redis, queue_name: str, now: float
|
||||
) -> float | None:
|
||||
"""Peek at the oldest (tail) message in a Redis list queue
|
||||
and extract its timestamp to compute age.
|
||||
|
||||
Note: If the Celery message contains neither ``properties.timestamp``
|
||||
nor ``headers.timestamp``, no age metric is emitted for this queue.
|
||||
This can happen with custom task producers or non-standard Celery
|
||||
protocol versions. The metric will simply be absent rather than
|
||||
inaccurate, which is the safest behavior for alerting.
|
||||
"""
|
||||
try:
|
||||
raw: bytes | str | None = redis_client.lindex(queue_name, -1) # type: ignore[assignment]
|
||||
if raw is None:
|
||||
return None
|
||||
msg = json.loads(raw)
|
||||
# Check for ETA tasks first — they are intentionally delayed,
|
||||
# so reporting their queue age would be misleading.
|
||||
headers = msg.get("headers", {})
|
||||
if headers.get("eta") is not None:
|
||||
return None
|
||||
# Celery v2 protocol: timestamp in properties
|
||||
props = msg.get("properties", {})
|
||||
ts = props.get("timestamp")
|
||||
if ts is not None:
|
||||
return now - float(ts)
|
||||
# Fallback: some Celery configurations place the timestamp in
|
||||
# headers instead of properties.
|
||||
ts = headers.get("timestamp")
|
||||
if ts is not None:
|
||||
return now - float(ts)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
class IndexAttemptCollector(_CachedCollector):
|
||||
"""Queries Postgres for index attempt state on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._configured: bool = False
|
||||
self._terminal_statuses: list = []
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Call once DB engine is initialized."""
|
||||
from onyx.db.enums import IndexingStatus
|
||||
|
||||
self._terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
self._configured = True
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if not self._configured:
|
||||
return []
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.index_attempt import get_active_index_attempts_for_metrics
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
attempts_gauge = GaugeMetricFamily(
|
||||
"onyx_index_attempts_active",
|
||||
"Number of non-terminal index attempts",
|
||||
labels=[
|
||||
"status",
|
||||
"source",
|
||||
"tenant_id",
|
||||
"connector_name",
|
||||
"cc_pair_id",
|
||||
],
|
||||
)
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tid in tenant_ids:
|
||||
# Defensive guard — get_all_tenant_ids() should never yield None,
|
||||
# but we guard here for API stability in case the contract changes.
|
||||
if tid is None:
|
||||
continue
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
|
||||
try:
|
||||
with get_session_with_current_tenant() as session:
|
||||
rows = get_active_index_attempts_for_metrics(session)
|
||||
|
||||
for status, source, cc_id, cc_name, count in rows:
|
||||
name_val = cc_name or f"cc_pair_{cc_id}"
|
||||
attempts_gauge.add_metric(
|
||||
[
|
||||
status.value,
|
||||
source.value,
|
||||
tid,
|
||||
name_val,
|
||||
str(cc_id),
|
||||
],
|
||||
count,
|
||||
)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return [attempts_gauge]
|
||||
|
||||
|
||||
class ConnectorHealthCollector(_CachedCollector):
|
||||
"""Queries Postgres for connector health state on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._configured: bool = False
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Call once DB engine is initialized."""
|
||||
self._configured = True
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if not self._configured:
|
||||
return []
|
||||
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_health_for_metrics,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.index_attempt import get_docs_indexed_by_cc_pair
|
||||
from onyx.db.index_attempt import get_failed_attempt_counts_by_cc_pair
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
staleness_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_last_success_age_seconds",
|
||||
"Seconds since last successful index for this connector",
|
||||
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
|
||||
)
|
||||
error_state_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_in_error_state",
|
||||
"Whether the connector is in a repeated error state (1=yes, 0=no)",
|
||||
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
|
||||
)
|
||||
by_status_gauge = GaugeMetricFamily(
|
||||
"onyx_connectors_by_status",
|
||||
"Number of connectors grouped by status",
|
||||
labels=["tenant_id", "status"],
|
||||
)
|
||||
error_total_gauge = GaugeMetricFamily(
|
||||
"onyx_connectors_in_error_total",
|
||||
"Total number of connectors in repeated error state",
|
||||
labels=["tenant_id"],
|
||||
)
|
||||
per_connector_labels = [
|
||||
"tenant_id",
|
||||
"source",
|
||||
"cc_pair_id",
|
||||
"connector_name",
|
||||
]
|
||||
docs_success_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_docs_indexed",
|
||||
"Total new documents indexed (90-day rolling sum) per connector",
|
||||
labels=per_connector_labels,
|
||||
)
|
||||
docs_error_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_error_count",
|
||||
"Total number of failed index attempts per connector",
|
||||
labels=per_connector_labels,
|
||||
)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tid in tenant_ids:
|
||||
# Defensive guard — get_all_tenant_ids() should never yield None,
|
||||
# but we guard here for API stability in case the contract changes.
|
||||
if tid is None:
|
||||
continue
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
|
||||
try:
|
||||
with get_session_with_current_tenant() as session:
|
||||
pairs = get_connector_health_for_metrics(session)
|
||||
error_counts_by_cc = get_failed_attempt_counts_by_cc_pair(session)
|
||||
docs_by_cc = get_docs_indexed_by_cc_pair(session)
|
||||
|
||||
status_counts: dict[str, int] = {}
|
||||
error_count = 0
|
||||
|
||||
for (
|
||||
cc_id,
|
||||
status,
|
||||
in_error,
|
||||
last_success,
|
||||
cc_name,
|
||||
source,
|
||||
) in pairs:
|
||||
cc_id_str = str(cc_id)
|
||||
source_val = source.value
|
||||
name_val = cc_name or f"cc_pair_{cc_id}"
|
||||
label_vals = [tid, source_val, cc_id_str, name_val]
|
||||
|
||||
if last_success is not None:
|
||||
# Both `now` and `last_success` are timezone-aware
|
||||
# (the DB column uses DateTime(timezone=True)),
|
||||
# so subtraction is safe.
|
||||
age = (now - last_success).total_seconds()
|
||||
staleness_gauge.add_metric(label_vals, age)
|
||||
|
||||
error_state_gauge.add_metric(
|
||||
label_vals,
|
||||
1.0 if in_error else 0.0,
|
||||
)
|
||||
if in_error:
|
||||
error_count += 1
|
||||
|
||||
docs_success_gauge.add_metric(
|
||||
label_vals,
|
||||
docs_by_cc.get(cc_id, 0),
|
||||
)
|
||||
|
||||
docs_error_gauge.add_metric(
|
||||
label_vals,
|
||||
error_counts_by_cc.get(cc_id, 0),
|
||||
)
|
||||
|
||||
status_val = status.value
|
||||
status_counts[status_val] = status_counts.get(status_val, 0) + 1
|
||||
|
||||
for status_val, count in status_counts.items():
|
||||
by_status_gauge.add_metric([tid, status_val], count)
|
||||
|
||||
error_total_gauge.add_metric([tid], error_count)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return [
|
||||
staleness_gauge,
|
||||
error_state_gauge,
|
||||
by_status_gauge,
|
||||
error_total_gauge,
|
||||
docs_success_gauge,
|
||||
docs_error_gauge,
|
||||
]
|
||||
|
||||
|
||||
class RedisHealthCollector(_CachedCollector):
|
||||
"""Collects Redis server health metrics (memory, clients, etc.)."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._celery_app: Any | None = None
|
||||
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
memory_used = GaugeMetricFamily(
|
||||
"onyx_redis_memory_used_bytes",
|
||||
"Redis used memory in bytes",
|
||||
)
|
||||
memory_peak = GaugeMetricFamily(
|
||||
"onyx_redis_memory_peak_bytes",
|
||||
"Redis peak used memory in bytes",
|
||||
)
|
||||
memory_frag = GaugeMetricFamily(
|
||||
"onyx_redis_memory_fragmentation_ratio",
|
||||
"Redis memory fragmentation ratio (>1.5 indicates fragmentation)",
|
||||
)
|
||||
connected_clients = GaugeMetricFamily(
|
||||
"onyx_redis_connected_clients",
|
||||
"Number of connected Redis clients",
|
||||
)
|
||||
|
||||
try:
|
||||
mem_info: dict = redis_client.info("memory") # type: ignore[assignment]
|
||||
memory_used.add_metric([], mem_info.get("used_memory", 0))
|
||||
memory_peak.add_metric([], mem_info.get("used_memory_peak", 0))
|
||||
frag = mem_info.get("mem_fragmentation_ratio")
|
||||
if frag is not None:
|
||||
memory_frag.add_metric([], frag)
|
||||
|
||||
client_info: dict = redis_client.info("clients") # type: ignore[assignment]
|
||||
connected_clients.add_metric([], client_info.get("connected_clients", 0))
|
||||
except Exception:
|
||||
logger.debug("Failed to collect Redis health metrics", exc_info=True)
|
||||
|
||||
return [memory_used, memory_peak, memory_frag, connected_clients]
|
||||
|
||||
|
||||
class WorkerHeartbeatMonitor:
|
||||
"""Monitors Celery worker health via the event stream.
|
||||
|
||||
Subscribes to ``worker-heartbeat``, ``worker-online``, and
|
||||
``worker-offline`` events via a single persistent connection.
|
||||
Runs in a daemon thread started once during worker setup.
|
||||
"""
|
||||
|
||||
# Consider a worker down if no heartbeat received for this long.
|
||||
_HEARTBEAT_TIMEOUT_SECONDS = 120.0
|
||||
|
||||
def __init__(self, celery_app: Any) -> None:
|
||||
self._app = celery_app
|
||||
self._worker_last_seen: dict[str, float] = {}
|
||||
self._lock = threading.Lock()
|
||||
self._running = False
|
||||
self._thread: threading.Thread | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the background event listener thread.
|
||||
|
||||
Safe to call multiple times — only starts one thread.
|
||||
"""
|
||||
if self._thread is not None and self._thread.is_alive():
|
||||
return
|
||||
self._running = True
|
||||
self._thread = threading.Thread(target=self._listen, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("WorkerHeartbeatMonitor started")
|
||||
|
||||
def stop(self) -> None:
|
||||
self._running = False
|
||||
|
||||
def _listen(self) -> None:
|
||||
"""Background loop: connect to event stream and process heartbeats."""
|
||||
while self._running:
|
||||
try:
|
||||
with self._app.connection() as conn:
|
||||
recv = self._app.events.Receiver(
|
||||
conn,
|
||||
handlers={
|
||||
"worker-heartbeat": self._on_heartbeat,
|
||||
"worker-online": self._on_heartbeat,
|
||||
"worker-offline": self._on_offline,
|
||||
},
|
||||
)
|
||||
recv.capture(
|
||||
limit=None, timeout=self._HEARTBEAT_TIMEOUT_SECONDS, wakeup=True
|
||||
)
|
||||
except Exception:
|
||||
if self._running:
|
||||
logger.debug(
|
||||
"Heartbeat listener disconnected, reconnecting in 5s",
|
||||
exc_info=True,
|
||||
)
|
||||
time.sleep(5.0)
|
||||
else:
|
||||
# capture() returned normally (timeout with no events); reconnect
|
||||
if self._running:
|
||||
logger.debug("Heartbeat capture timed out, reconnecting")
|
||||
time.sleep(5.0)
|
||||
|
||||
def _on_heartbeat(self, event: dict[str, Any]) -> None:
|
||||
hostname = event.get("hostname")
|
||||
if hostname:
|
||||
with self._lock:
|
||||
self._worker_last_seen[hostname] = time.monotonic()
|
||||
|
||||
def _on_offline(self, event: dict[str, Any]) -> None:
|
||||
hostname = event.get("hostname")
|
||||
if hostname:
|
||||
with self._lock:
|
||||
self._worker_last_seen.pop(hostname, None)
|
||||
|
||||
def get_worker_status(self) -> dict[str, bool]:
|
||||
"""Return {hostname: is_alive} for all known workers.
|
||||
|
||||
Thread-safe. Called by WorkerHealthCollector on each scrape.
|
||||
Also prunes workers that have been dead longer than 2x the
|
||||
heartbeat timeout to prevent unbounded growth.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
prune_threshold = self._HEARTBEAT_TIMEOUT_SECONDS * 2
|
||||
with self._lock:
|
||||
# Prune workers that have been gone for 2x the timeout
|
||||
stale = [
|
||||
h
|
||||
for h, ts in self._worker_last_seen.items()
|
||||
if (now - ts) > prune_threshold
|
||||
]
|
||||
for h in stale:
|
||||
del self._worker_last_seen[h]
|
||||
|
||||
result: dict[str, bool] = {}
|
||||
for hostname, last_seen in self._worker_last_seen.items():
|
||||
alive = (now - last_seen) < self._HEARTBEAT_TIMEOUT_SECONDS
|
||||
result[hostname] = alive
|
||||
return result
|
||||
|
||||
|
||||
class WorkerHealthCollector(_CachedCollector):
|
||||
"""Collects Celery worker health from the heartbeat monitor.
|
||||
|
||||
Reads worker status from ``WorkerHeartbeatMonitor`` which listens
|
||||
to the Celery event stream via a single persistent connection.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = 30.0) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._monitor: WorkerHeartbeatMonitor | None = None
|
||||
|
||||
def set_monitor(self, monitor: WorkerHeartbeatMonitor) -> None:
|
||||
"""Set the heartbeat monitor instance."""
|
||||
self._monitor = monitor
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._monitor is None:
|
||||
return []
|
||||
|
||||
active_workers = GaugeMetricFamily(
|
||||
"onyx_celery_active_worker_count",
|
||||
"Number of active Celery workers with recent heartbeats",
|
||||
)
|
||||
worker_up = GaugeMetricFamily(
|
||||
"onyx_celery_worker_up",
|
||||
"Whether a specific Celery worker is alive (1=up, 0=down)",
|
||||
labels=["worker"],
|
||||
)
|
||||
|
||||
try:
|
||||
status = self._monitor.get_worker_status()
|
||||
alive_count = sum(1 for alive in status.values() if alive)
|
||||
active_workers.add_metric([], alive_count)
|
||||
|
||||
for hostname in sorted(status):
|
||||
# Use short name (before @) for single-host deployments,
|
||||
# full hostname when multiple hosts share a worker type.
|
||||
label = hostname.split("@")[0]
|
||||
worker_up.add_metric([label], 1 if status[hostname] else 0)
|
||||
except Exception:
|
||||
logger.debug("Failed to collect worker health metrics", exc_info=True)
|
||||
|
||||
return [active_workers, worker_up]
|
||||
63
backend/onyx/server/metrics/indexing_pipeline_setup.py
Normal file
63
backend/onyx/server/metrics/indexing_pipeline_setup.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Setup function for indexing pipeline Prometheus collectors.
|
||||
|
||||
Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
"""
|
||||
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHeartbeatMonitor
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Module-level singletons — these are lightweight objects (no connections or DB
|
||||
# state) until configure() / set_celery_app() is called. Keeping them at
|
||||
# module level ensures they survive the lifetime of the worker process and are
|
||||
# only registered with the Prometheus registry once.
|
||||
_queue_collector = QueueDepthCollector()
|
||||
_attempt_collector = IndexAttemptCollector()
|
||||
_connector_collector = ConnectorHealthCollector()
|
||||
_redis_health_collector = RedisHealthCollector()
|
||||
_worker_health_collector = WorkerHealthCollector()
|
||||
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
|
||||
|
||||
|
||||
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
"""Register all indexing pipeline collectors with the default registry.
|
||||
|
||||
Args:
|
||||
celery_app: The Celery application instance. Used to obtain a
|
||||
broker Redis client on each scrape for queue depth metrics.
|
||||
"""
|
||||
_queue_collector.set_celery_app(celery_app)
|
||||
_redis_health_collector.set_celery_app(celery_app)
|
||||
|
||||
# Start the heartbeat monitor daemon thread — uses a single persistent
|
||||
# connection to receive worker-heartbeat events.
|
||||
# Module-level singleton prevents duplicate threads on re-entry.
|
||||
global _heartbeat_monitor
|
||||
if _heartbeat_monitor is None:
|
||||
_heartbeat_monitor = WorkerHeartbeatMonitor(celery_app)
|
||||
_heartbeat_monitor.start()
|
||||
_worker_health_collector.set_monitor(_heartbeat_monitor)
|
||||
|
||||
_attempt_collector.configure()
|
||||
_connector_collector.configure()
|
||||
|
||||
for collector in (
|
||||
_queue_collector,
|
||||
_attempt_collector,
|
||||
_connector_collector,
|
||||
_redis_health_collector,
|
||||
_worker_health_collector,
|
||||
):
|
||||
try:
|
||||
REGISTRY.register(collector)
|
||||
except ValueError:
|
||||
logger.debug("Collector already registered: %s", type(collector).__name__)
|
||||
253
backend/onyx/server/metrics/indexing_task_metrics.py
Normal file
253
backend/onyx/server/metrics/indexing_task_metrics.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""Per-connector Prometheus metrics for indexing tasks.
|
||||
|
||||
Enriches the two primary indexing tasks (docfetching_proxy_task and
|
||||
docprocessing_task) with connector-level labels: source, tenant_id,
|
||||
and cc_pair_id.
|
||||
|
||||
Note: connector_name is intentionally excluded from push-based per-task
|
||||
counters because it is a user-defined free-form string that can create
|
||||
unbounded cardinality. The pull-based collectors on the monitoring worker
|
||||
(see indexing_pipeline.py) include connector_name since they have bounded
|
||||
cardinality (one series per connector, not per task execution).
|
||||
|
||||
Uses an in-memory cache for cc_pair_id → (source, name) lookups.
|
||||
Connectors never change source type, and names change rarely, so the
|
||||
cache is safe to hold for the worker's lifetime.
|
||||
|
||||
Usage in a worker app module:
|
||||
from onyx.server.metrics.indexing_task_metrics import (
|
||||
on_indexing_task_prerun,
|
||||
on_indexing_task_postrun,
|
||||
)
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from celery import Task
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.server.metrics.celery_task_metrics import _MAX_START_TIME_AGE_SECONDS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConnectorInfo:
|
||||
"""Cached connector metadata for metric labels."""
|
||||
|
||||
source: str
|
||||
name: str
|
||||
|
||||
|
||||
_UNKNOWN_CONNECTOR = ConnectorInfo(source="unknown", name="unknown")
|
||||
|
||||
# (tenant_id, cc_pair_id) → ConnectorInfo (populated on first encounter).
|
||||
# Keyed by tenant to avoid cross-tenant cache poisoning in multi-tenant
|
||||
# deployments where different tenants can share the same cc_pair_id value.
|
||||
_connector_cache: dict[tuple[str, int], ConnectorInfo] = {}
|
||||
|
||||
# Lock protecting _connector_cache — multiple thread-pool workers may
|
||||
# resolve connectors concurrently.
|
||||
_connector_cache_lock = threading.Lock()
|
||||
|
||||
# Only enrich these task types with per-connector labels
|
||||
_INDEXING_TASK_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
}
|
||||
)
|
||||
|
||||
# connector_name is intentionally excluded — see module docstring.
|
||||
INDEXING_TASK_STARTED = Counter(
|
||||
"onyx_indexing_task_started_total",
|
||||
"Indexing tasks started per connector",
|
||||
["task_name", "source", "tenant_id", "cc_pair_id"],
|
||||
)
|
||||
|
||||
INDEXING_TASK_COMPLETED = Counter(
|
||||
"onyx_indexing_task_completed_total",
|
||||
"Indexing tasks completed per connector",
|
||||
[
|
||||
"task_name",
|
||||
"source",
|
||||
"tenant_id",
|
||||
"cc_pair_id",
|
||||
"outcome",
|
||||
],
|
||||
)
|
||||
|
||||
INDEXING_TASK_DURATION = Histogram(
|
||||
"onyx_indexing_task_duration_seconds",
|
||||
"Indexing task duration by connector type",
|
||||
["task_name", "source", "tenant_id"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
)
|
||||
|
||||
# task_id → monotonic start time (for indexing tasks only)
|
||||
_indexing_start_times: dict[str, float] = {}
|
||||
|
||||
# Lock protecting _indexing_start_times — prerun, postrun, and eviction may
|
||||
# run concurrently on thread-pool workers.
|
||||
_indexing_start_times_lock = threading.Lock()
|
||||
|
||||
|
||||
def _evict_stale_start_times() -> None:
|
||||
"""Remove _indexing_start_times entries older than _MAX_START_TIME_AGE_SECONDS.
|
||||
|
||||
Must be called while holding _indexing_start_times_lock.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
stale_ids = [
|
||||
tid
|
||||
for tid, start in _indexing_start_times.items()
|
||||
if now - start > _MAX_START_TIME_AGE_SECONDS
|
||||
]
|
||||
for tid in stale_ids:
|
||||
_indexing_start_times.pop(tid, None)
|
||||
|
||||
|
||||
def _resolve_connector(cc_pair_id: int) -> ConnectorInfo:
|
||||
"""Resolve cc_pair_id to ConnectorInfo, using cache when possible.
|
||||
|
||||
On cache miss, does a single DB query with eager connector load.
|
||||
On any failure, returns _UNKNOWN_CONNECTOR without caching, so that
|
||||
subsequent calls can retry the lookup once the DB is available.
|
||||
|
||||
Note on tenant_id source: we read CURRENT_TENANT_ID_CONTEXTVAR for the
|
||||
cache key. The Celery tenant-aware middleware sets this contextvar before
|
||||
task execution, and it always matches kwargs["tenant_id"] (which is set
|
||||
at task dispatch time). They are guaranteed to agree for a given task
|
||||
execution context.
|
||||
"""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get("") or ""
|
||||
cache_key = (tenant_id, cc_pair_id)
|
||||
|
||||
with _connector_cache_lock:
|
||||
cached = _connector_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pair_from_id,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session,
|
||||
cc_pair_id,
|
||||
eager_load_connector=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
# DB lookup succeeded but cc_pair doesn't exist — don't cache,
|
||||
# it may appear later (race with connector creation).
|
||||
return _UNKNOWN_CONNECTOR
|
||||
|
||||
info = ConnectorInfo(
|
||||
source=cc_pair.connector.source.value,
|
||||
name=cc_pair.name,
|
||||
)
|
||||
with _connector_cache_lock:
|
||||
_connector_cache[cache_key] = info
|
||||
return info
|
||||
except Exception:
|
||||
logger.debug(
|
||||
f"Failed to resolve connector info for cc_pair_id={cc_pair_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
return _UNKNOWN_CONNECTOR
|
||||
|
||||
|
||||
def on_indexing_task_prerun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
kwargs: dict | None,
|
||||
) -> None:
|
||||
"""Record per-connector metrics at task start.
|
||||
|
||||
Only fires for tasks in _INDEXING_TASK_NAMES. Silently returns for
|
||||
all other tasks.
|
||||
"""
|
||||
if task is None or task_id is None or kwargs is None:
|
||||
return
|
||||
|
||||
task_name = task.name or ""
|
||||
if task_name not in _INDEXING_TASK_NAMES:
|
||||
return
|
||||
|
||||
try:
|
||||
cc_pair_id = kwargs.get("cc_pair_id")
|
||||
tenant_id = str(kwargs.get("tenant_id", "unknown"))
|
||||
|
||||
if cc_pair_id is None:
|
||||
return
|
||||
|
||||
info = _resolve_connector(cc_pair_id)
|
||||
|
||||
INDEXING_TASK_STARTED.labels(
|
||||
task_name=task_name,
|
||||
source=info.source,
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=str(cc_pair_id),
|
||||
).inc()
|
||||
|
||||
with _indexing_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_indexing_start_times[task_id] = time.monotonic()
|
||||
except Exception:
|
||||
logger.debug("Failed to record indexing task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_indexing_task_postrun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
kwargs: dict | None,
|
||||
state: str | None,
|
||||
) -> None:
|
||||
"""Record per-connector completion metrics.
|
||||
|
||||
Only fires for tasks in _INDEXING_TASK_NAMES.
|
||||
"""
|
||||
if task is None or task_id is None or kwargs is None:
|
||||
return
|
||||
|
||||
task_name = task.name or ""
|
||||
if task_name not in _INDEXING_TASK_NAMES:
|
||||
return
|
||||
|
||||
try:
|
||||
cc_pair_id = kwargs.get("cc_pair_id")
|
||||
tenant_id = str(kwargs.get("tenant_id", "unknown"))
|
||||
|
||||
if cc_pair_id is None:
|
||||
return
|
||||
|
||||
info = _resolve_connector(cc_pair_id)
|
||||
outcome = "success" if state == "SUCCESS" else "failure"
|
||||
|
||||
INDEXING_TASK_COMPLETED.labels(
|
||||
task_name=task_name,
|
||||
source=info.source,
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=str(cc_pair_id),
|
||||
outcome=outcome,
|
||||
).inc()
|
||||
|
||||
with _indexing_start_times_lock:
|
||||
start = _indexing_start_times.pop(task_id, None)
|
||||
if start is not None:
|
||||
INDEXING_TASK_DURATION.labels(
|
||||
task_name=task_name,
|
||||
source=info.source,
|
||||
tenant_id=tenant_id,
|
||||
).observe(time.monotonic() - start)
|
||||
except Exception:
|
||||
logger.debug("Failed to record indexing task postrun metrics", exc_info=True)
|
||||
89
backend/onyx/server/metrics/metrics_server.py
Normal file
89
backend/onyx/server/metrics/metrics_server.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Standalone Prometheus metrics HTTP server for non-API processes.
|
||||
|
||||
The FastAPI API server already exposes /metrics via prometheus-fastapi-instrumentator.
|
||||
Celery workers and other background processes use this module to expose their
|
||||
own /metrics endpoint on a configurable port.
|
||||
|
||||
Usage:
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
start_metrics_server("monitoring") # reads port from env or uses default
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Default ports for worker types that serve custom Prometheus metrics.
|
||||
# Only add entries here when a worker actually registers collectors.
|
||||
# In k8s each worker type runs in its own pod, so PROMETHEUS_METRICS_PORT
|
||||
# env var can override.
|
||||
_DEFAULT_PORTS: dict[str, int] = {
|
||||
"monitoring": 9096,
|
||||
"docfetching": 9092,
|
||||
"docprocessing": 9093,
|
||||
}
|
||||
|
||||
_server_started = False
|
||||
_server_lock = threading.Lock()
|
||||
|
||||
|
||||
def start_metrics_server(worker_type: str) -> int | None:
|
||||
"""Start a Prometheus metrics HTTP server in a background thread.
|
||||
|
||||
Returns the port if started, None if disabled or already started.
|
||||
|
||||
Port resolution order:
|
||||
1. PROMETHEUS_METRICS_PORT env var (explicit override)
|
||||
2. Default port for the worker type
|
||||
3. If worker type is unknown and no env var, skip
|
||||
|
||||
Set PROMETHEUS_METRICS_ENABLED=false to disable.
|
||||
"""
|
||||
global _server_started
|
||||
|
||||
with _server_lock:
|
||||
if _server_started:
|
||||
logger.debug(f"Metrics server already started for {worker_type}")
|
||||
return None
|
||||
|
||||
enabled = os.environ.get("PROMETHEUS_METRICS_ENABLED", "true").lower()
|
||||
if enabled in ("false", "0", "no"):
|
||||
logger.info(f"Prometheus metrics server disabled for {worker_type}")
|
||||
return None
|
||||
|
||||
port_str = os.environ.get("PROMETHEUS_METRICS_PORT")
|
||||
if port_str:
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid PROMETHEUS_METRICS_PORT '{port_str}' for {worker_type}, "
|
||||
"must be a numeric port. Skipping metrics server."
|
||||
)
|
||||
return None
|
||||
elif worker_type in _DEFAULT_PORTS:
|
||||
port = _DEFAULT_PORTS[worker_type]
|
||||
else:
|
||||
logger.info(
|
||||
f"No default metrics port for worker type '{worker_type}' "
|
||||
"and PROMETHEUS_METRICS_PORT not set. Skipping metrics server."
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
start_http_server(port)
|
||||
_server_started = True
|
||||
logger.info(
|
||||
f"Prometheus metrics server started on :{port} for {worker_type}"
|
||||
)
|
||||
return port
|
||||
except OSError as e:
|
||||
logger.warning(
|
||||
f"Failed to start metrics server on :{port} for {worker_type}: {e}"
|
||||
)
|
||||
return None
|
||||
106
backend/onyx/server/metrics/opensearch_search.py
Normal file
106
backend/onyx/server/metrics/opensearch_search.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Prometheus metrics for OpenSearch search latency and throughput.
|
||||
|
||||
Tracks client-side round-trip latency, server-side execution time (from
|
||||
OpenSearch's ``took`` field), total search count, and in-flight concurrency.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.document_index.opensearch.constants import OpenSearchSearchType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SEARCH_LATENCY_BUCKETS = (
|
||||
0.005,
|
||||
0.01,
|
||||
0.025,
|
||||
0.05,
|
||||
0.1,
|
||||
0.25,
|
||||
0.5,
|
||||
1.0,
|
||||
2.5,
|
||||
5.0,
|
||||
10.0,
|
||||
25.0,
|
||||
)
|
||||
|
||||
_client_duration = Histogram(
|
||||
"onyx_opensearch_search_client_duration_seconds",
|
||||
"Client-side end-to-end latency of OpenSearch search calls",
|
||||
["search_type"],
|
||||
buckets=_SEARCH_LATENCY_BUCKETS,
|
||||
)
|
||||
|
||||
_server_duration = Histogram(
|
||||
"onyx_opensearch_search_server_duration_seconds",
|
||||
"Server-side execution time reported by OpenSearch (took field)",
|
||||
["search_type"],
|
||||
buckets=_SEARCH_LATENCY_BUCKETS,
|
||||
)
|
||||
|
||||
_search_total = Counter(
|
||||
"onyx_opensearch_search_total",
|
||||
"Total number of search requests sent to OpenSearch",
|
||||
["search_type"],
|
||||
)
|
||||
|
||||
_searches_in_progress = Gauge(
|
||||
"onyx_opensearch_searches_in_progress",
|
||||
"Number of OpenSearch searches currently in-flight",
|
||||
["search_type"],
|
||||
)
|
||||
|
||||
|
||||
def observe_opensearch_search(
|
||||
search_type: OpenSearchSearchType,
|
||||
client_duration_s: float,
|
||||
server_took_ms: int | None,
|
||||
) -> None:
|
||||
"""Records latency and throughput metrics for a completed OpenSearch search.
|
||||
|
||||
Args:
|
||||
search_type: The type of search.
|
||||
client_duration_s: Wall-clock duration measured on the client side, in
|
||||
seconds.
|
||||
server_took_ms: The ``took`` value from the OpenSearch response, in
|
||||
milliseconds. May be ``None`` if the response did not include it.
|
||||
"""
|
||||
try:
|
||||
label = search_type.value
|
||||
_search_total.labels(search_type=label).inc()
|
||||
_client_duration.labels(search_type=label).observe(client_duration_s)
|
||||
if server_took_ms is not None:
|
||||
_server_duration.labels(search_type=label).observe(server_took_ms / 1000.0)
|
||||
except Exception:
|
||||
logger.warning("Failed to record OpenSearch search metrics.", exc_info=True)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def track_opensearch_search_in_progress(
|
||||
search_type: OpenSearchSearchType,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Context manager that tracks in-flight OpenSearch searches via a Gauge."""
|
||||
incremented = False
|
||||
label = search_type.value
|
||||
try:
|
||||
_searches_in_progress.labels(search_type=label).inc()
|
||||
incremented = True
|
||||
except Exception:
|
||||
logger.warning("Failed to increment in-progress search gauge.", exc_info=True)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if incremented:
|
||||
try:
|
||||
_searches_in_progress.labels(search_type=label).dec()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to decrement in-progress search gauge.", exc_info=True
|
||||
)
|
||||
@@ -41,6 +41,21 @@ class MessageResponseIDInfo(BaseModel):
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class ModelResponseSlot(BaseModel):
|
||||
"""Pairs a reserved assistant message ID with its model display name."""
|
||||
|
||||
message_id: int
|
||||
model_name: str
|
||||
|
||||
|
||||
class MultiModelMessageResponseIDInfo(BaseModel):
|
||||
"""Sent at the start of a multi-model streaming response.
|
||||
Contains the user message ID and one slot per model being run in parallel."""
|
||||
|
||||
user_message_id: int | None
|
||||
responses: list[ModelResponseSlot]
|
||||
|
||||
|
||||
class SourceTag(Tag):
|
||||
source: DocumentSource
|
||||
|
||||
@@ -86,6 +101,9 @@ class SendMessageRequest(BaseModel):
|
||||
message: str
|
||||
|
||||
llm_override: LLMOverride | None = None
|
||||
# For multi-model mode: up to 3 LLM overrides to run in parallel.
|
||||
# When provided with >1 entry, triggers multi-model streaming.
|
||||
llm_overrides: list[LLMOverride] | None = None
|
||||
# Test-only override for deterministic LiteLLM mock responses.
|
||||
mock_llm_response: str | None = None
|
||||
|
||||
@@ -211,6 +229,8 @@ class ChatMessageDetail(BaseModel):
|
||||
error: str | None = None
|
||||
current_feedback: str | None = None # "like" | "dislike" | null
|
||||
processing_duration_seconds: float | None = None
|
||||
preferred_response_id: int | None = None
|
||||
model_display_name: str | None = None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
@@ -218,6 +238,11 @@ class ChatMessageDetail(BaseModel):
|
||||
return initial_dict
|
||||
|
||||
|
||||
class SetPreferredResponseRequest(BaseModel):
|
||||
user_message_id: int
|
||||
preferred_response_id: int
|
||||
|
||||
|
||||
class ChatSessionDetailResponse(BaseModel):
|
||||
chat_session_id: UUID
|
||||
description: str | None
|
||||
|
||||
@@ -8,3 +8,5 @@ class Placement(BaseModel):
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -9,7 +9,9 @@ from onyx import __version__ as onyx_version
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import is_user_admin
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
@@ -17,9 +19,16 @@ from onyx.db.models import User
|
||||
from onyx.db.notification import dismiss_all_notifications
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.db.notification import update_notification_last_shown
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.features.build.utils import is_onyx_craft_enabled
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Notification
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.models import UserSettings
|
||||
@@ -40,6 +49,15 @@ basic_router = APIRouter(prefix="/settings")
|
||||
def admin_put_settings(
|
||||
settings: Settings, _: User = Depends(current_admin_user)
|
||||
) -> None:
|
||||
if (
|
||||
settings.user_file_max_upload_size_mb is not None
|
||||
and settings.user_file_max_upload_size_mb > 0
|
||||
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"File upload size limit cannot exceed {MAX_ALLOWED_UPLOAD_SIZE_MB} MB",
|
||||
)
|
||||
store_settings(settings)
|
||||
|
||||
|
||||
@@ -80,7 +98,18 @@ def fetch_settings(
|
||||
needs_reindexing=needs_reindexing,
|
||||
onyx_craft_enabled=onyx_craft_enabled_for_user,
|
||||
vector_db_enabled=not DISABLE_VECTOR_DB,
|
||||
hooks_enabled=HOOKS_AVAILABLE,
|
||||
version=onyx_version,
|
||||
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
default_user_file_max_upload_size_mb=min(
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
),
|
||||
default_file_token_count_threshold_k=(
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,12 +2,19 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.models import Notification as NotificationDBModel
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB = 200
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB = 10000
|
||||
|
||||
|
||||
class PageType(str, Enum):
|
||||
CHAT = "chat"
|
||||
@@ -78,7 +85,12 @@ class Settings(BaseModel):
|
||||
|
||||
# User Knowledge settings
|
||||
user_knowledge_enabled: bool | None = True
|
||||
user_file_max_upload_size_mb: int | None = None
|
||||
user_file_max_upload_size_mb: int | None = Field(
|
||||
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
|
||||
)
|
||||
file_token_count_threshold_k: int | None = Field(
|
||||
default=None, ge=0 # thousands of tokens; None = context-aware default
|
||||
)
|
||||
|
||||
# Connector settings
|
||||
show_extra_connectors: bool | None = True
|
||||
@@ -104,5 +116,18 @@ class UserSettings(Settings):
|
||||
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
|
||||
# document sets are unavailable.
|
||||
vector_db_enabled: bool = True
|
||||
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
hooks_enabled: bool = False
|
||||
# Application version, read from the ONYX_VERSION env var at startup.
|
||||
version: str | None = None
|
||||
# Hard ceiling for user_file_max_upload_size_mb, derived from env var.
|
||||
max_allowed_upload_size_mb: int = MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
# Factory defaults so the frontend can show a "restore default" button.
|
||||
default_user_file_max_upload_size_mb: int = DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
default_file_token_count_threshold_k: int = Field(
|
||||
default_factory=lambda: (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -51,9 +57,36 @@ def load_settings() -> Settings:
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
|
||||
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
|
||||
# Resolve context-aware defaults for token threshold.
|
||||
# None = admin hasn't set a value yet → use context-aware default.
|
||||
# 0 = admin explicitly chose "no limit" → preserve as-is.
|
||||
if settings.file_token_count_threshold_k is None:
|
||||
settings.file_token_count_threshold_k = (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
|
||||
# Upload size: 0 and None are treated as "unset" (not "no limit") →
|
||||
# fall back to min(configured default, hard ceiling).
|
||||
if not settings.user_file_max_upload_size_mb:
|
||||
settings.user_file_max_upload_size_mb = min(
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
)
|
||||
|
||||
# Clamp to env ceiling so stale KV values are capped even if the
|
||||
# operator lowered MAX_ALLOWED_UPLOAD_SIZE_MB after a higher value
|
||||
# was already saved (api.py only guards new writes).
|
||||
if (
|
||||
settings.user_file_max_upload_size_mb > 0
|
||||
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
):
|
||||
settings.user_file_max_upload_size_mb = MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
|
||||
@@ -736,7 +736,7 @@ if __name__ == "__main__":
|
||||
llm.config.model_name, llm.config.model_provider
|
||||
)
|
||||
|
||||
persona = get_default_behavior_persona(db_session)
|
||||
persona = get_default_behavior_persona(db_session, eager_load_for_tools=True)
|
||||
if persona is None:
|
||||
raise ValueError("No default persona found")
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.mcp import get_all_mcp_tools_for_server
|
||||
@@ -124,7 +125,12 @@ def construct_tools(
|
||||
) -> dict[int, list[Tool]]:
|
||||
"""Constructs tools based on persona configuration and available APIs.
|
||||
|
||||
Will simply skip tools that are not allowed/available."""
|
||||
Will simply skip tools that are not allowed/available.
|
||||
|
||||
Callers must supply a persona with ``tools``, ``document_sets``,
|
||||
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
|
||||
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
|
||||
to avoid lazy SQL queries after the session may have been flushed."""
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Log which tools are attached to the persona for debugging
|
||||
@@ -143,6 +149,28 @@ def construct_tools(
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None, db_session)
|
||||
|
||||
def _build_search_tool(tool_id: int, config: SearchToolConfig) -> SearchTool:
|
||||
persona_search_info = PersonaSearchInfo(
|
||||
document_set_names=[ds.name for ds in persona.document_sets],
|
||||
search_start_date=persona.search_start_date,
|
||||
attached_document_ids=[doc.id for doc in persona.attached_documents],
|
||||
hierarchy_node_ids=[node.id for node in persona.hierarchy_nodes],
|
||||
)
|
||||
return SearchTool(
|
||||
tool_id=tool_id,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona_search_info=persona_search_info,
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=config.user_selected_filters,
|
||||
project_id_filter=config.project_id_filter,
|
||||
persona_id_filter=config.persona_id_filter,
|
||||
bypass_acl=config.bypass_acl,
|
||||
slack_context=config.slack_context,
|
||||
enable_slack_search=config.enable_slack_search,
|
||||
)
|
||||
|
||||
added_search_tool = False
|
||||
for db_tool_model in persona.tools:
|
||||
# If allowed_tool_ids is specified, skip tools not in the allowed list
|
||||
@@ -176,22 +204,9 @@ def construct_tools(
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
search_tool = SearchTool(
|
||||
tool_id=db_tool_model.id,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
)
|
||||
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
tool_dict[db_tool_model.id] = [
|
||||
_build_search_tool(db_tool_model.id, search_tool_config)
|
||||
]
|
||||
|
||||
# Handle Image Generation Tool
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
@@ -421,26 +436,12 @@ def construct_tools(
|
||||
# Get the database tool model for SearchTool
|
||||
search_tool_db_model = get_builtin_tool(db_session, SearchTool)
|
||||
|
||||
# Use the passed-in config if available, otherwise create a new one
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
search_tool = SearchTool(
|
||||
tool_id=search_tool_db_model.id,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
)
|
||||
|
||||
tool_dict[search_tool_db_model.id] = [search_tool]
|
||||
tool_dict[search_tool_db_model.id] = [
|
||||
_build_search_tool(search_tool_db_model.id, search_tool_config)
|
||||
]
|
||||
|
||||
# Always inject MemoryTool when the user has the memory tool enabled,
|
||||
# bypassing persona tool associations and allowed_tool_ids filtering
|
||||
|
||||
@@ -51,6 +51,7 @@ from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.context.search.pipeline import merge_individual_chunks
|
||||
from onyx.context.search.pipeline import search_pipeline
|
||||
@@ -65,7 +66,6 @@ from onyx.db.federated import (
|
||||
get_federated_connector_document_set_mappings_by_document_set_names,
|
||||
)
|
||||
from onyx.db.federated import list_federated_connector_oauth_tokens
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
@@ -238,8 +238,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
emitter: Emitter,
|
||||
# Used for ACLs and federated search, anonymous users only see public docs
|
||||
user: User,
|
||||
# Used for filter settings
|
||||
persona: Persona,
|
||||
# Pre-extracted persona search configuration
|
||||
persona_search_info: PersonaSearchInfo,
|
||||
llm: LLM,
|
||||
document_index: DocumentIndex,
|
||||
# Respecting user selections
|
||||
@@ -258,7 +258,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
super().__init__(emitter=emitter)
|
||||
|
||||
self.user = user
|
||||
self.persona = persona
|
||||
self.persona_search_info = persona_search_info
|
||||
self.llm = llm
|
||||
self.document_index = document_index
|
||||
self.user_selected_filters = user_selected_filters
|
||||
@@ -289,7 +289,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
# Case 1: Slack bot context — requires a Slack federated connector
|
||||
# linked via the persona's document sets
|
||||
if self.slack_context:
|
||||
document_set_names = [ds.name for ds in self.persona.document_sets]
|
||||
document_set_names = self.persona_search_info.document_set_names
|
||||
if not document_set_names:
|
||||
logger.debug(
|
||||
"Skipping Slack federated search: no document sets on persona"
|
||||
@@ -463,7 +463,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
persona_id_filter=self.persona_id_filter,
|
||||
document_index=self.document_index,
|
||||
user=self.user,
|
||||
persona=self.persona,
|
||||
persona_search_info=self.persona_search_info,
|
||||
acl_filters=acl_filters,
|
||||
embedding_model=embedding_model,
|
||||
prefetched_federated_retrieval_infos=federated_retrieval_infos,
|
||||
@@ -587,15 +587,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
and self.user_selected_filters.source_type
|
||||
else None
|
||||
)
|
||||
persona_document_sets = (
|
||||
[ds.name for ds in self.persona.document_sets] if self.persona else None
|
||||
)
|
||||
federated_retrieval_infos = (
|
||||
get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=self.user.id if self.user else None,
|
||||
source_types=prefetch_source_types,
|
||||
document_set_names=persona_document_sets,
|
||||
document_set_names=self.persona_search_info.document_set_names,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
@@ -549,7 +549,7 @@ mypy-extensions==1.0.0
|
||||
# typing-inspect
|
||||
nest-asyncio==1.6.0
|
||||
# via onyx
|
||||
nltk==3.9.3
|
||||
nltk==3.9.4
|
||||
# via unstructured
|
||||
numpy==2.4.1
|
||||
# via
|
||||
@@ -752,7 +752,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.9.1
|
||||
pypdf==6.9.2
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
@@ -861,7 +861,7 @@ regex==2025.11.3
|
||||
# dateparser
|
||||
# nltk
|
||||
# tiktoken
|
||||
requests==2.32.5
|
||||
requests==2.33.0
|
||||
# via
|
||||
# atlassian-python-api
|
||||
# braintrust
|
||||
|
||||
@@ -263,7 +263,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.7.1
|
||||
onyx-devtools==0.7.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -410,7 +410,7 @@ release-tag==0.5.2
|
||||
# via onyx
|
||||
reorder-python-imports-black==3.14.0
|
||||
# via onyx
|
||||
requests==2.32.5
|
||||
requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
# google-genai
|
||||
|
||||
@@ -244,7 +244,7 @@ referencing==0.36.2
|
||||
# jsonschema-specifications
|
||||
regex==2025.11.3
|
||||
# via tiktoken
|
||||
requests==2.32.5
|
||||
requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
# google-genai
|
||||
|
||||
@@ -338,7 +338,7 @@ regex==2025.11.3
|
||||
# via
|
||||
# tiktoken
|
||||
# transformers
|
||||
requests==2.32.5
|
||||
requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
# google-genai
|
||||
|
||||
@@ -191,25 +191,6 @@ IGNORED_SYNCING_TENANT_LIST = (
|
||||
else None
|
||||
)
|
||||
|
||||
# Global flag to skip userfile threshold for all users/tenants
|
||||
SKIP_USERFILE_THRESHOLD = (
|
||||
os.environ.get("SKIP_USERFILE_THRESHOLD", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Comma-separated list of specific tenant IDs to skip threshold (multi-tenant only)
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_IDS = os.environ.get(
|
||||
"SKIP_USERFILE_THRESHOLD_TENANT_IDS"
|
||||
)
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_LIST = (
|
||||
[
|
||||
tenant.strip()
|
||||
for tenant in SKIP_USERFILE_THRESHOLD_TENANT_IDS.split(",")
|
||||
if tenant.strip()
|
||||
]
|
||||
if SKIP_USERFILE_THRESHOLD_TENANT_IDS
|
||||
else None
|
||||
)
|
||||
|
||||
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user