Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-30 12:02:42 +00:00)

Compare commits: 141 Commits
jamison/tr...ref-file-t
Commits (SHA1):
1d2d79127d, f44663c23c, b73d26aedd, 146b5449d2, b66991b5c5, 9cb76dc027, f66891d19e, c07c952ad5,
be7f40a28a, 26f941b5da, b9e84c42a8, 0a1df52c2f, 306b0d452f, 5fdb34ba8e, 2d066631e3, 5c84f6c61b,
899179d4b6, 80d6bafc74, 2cc325cb0e, 849385b756, 417b9c12e4, 30b37d0a77, b48be0cd3a, 127fd90424,
f9c9e55f32, 5afcf1acea, eb1244a9d7, 2433a9a4c5, 60bc8fcac6, 1ddc958a51, de37acbe07, 08cd2f2c3e,
fc29f20914, c43cb80a7a, 56f0be2ec8, 42f9ddf247, a10a85c73c, 31d8ae9718, 00a0a99842, 90040f8973,
4f5d081f26, c51a6dbd0d, 8b90ecc189, 865c893a09, ef5628bfa7, 6ffee0021e, 28dc84b831, 230f035500,
55b24d72b4, 3321a84c7d, 54bf32a5f8, 4bb6b76be6, db94562474, 582d4642c1, 3caaecdb0e, 039b69806b,
63971d4958, ffd897f380, 4745069232, 386782f188, ff009c4129, b20a5ebf69, 8645adb807, 2425bd4d8d,
333b2b19cb, 44895b3bd6, 78c2ecf99f, e3e0e04edc, a19fe03bd8, 415c05b5f8, 352fd19f0a, 41ae039bfa,
782c734287, 728cdb0715, baf6437117, f187165077, 727be3d663, 98c8f9884b, d79a068984, ba0740d15f,
86b7bed90b, aead6ab9a5, c9d4c186dd, 70aad1ec46, ca3cc16ead, 9ea1780ce5, f70e5e605e, 84b134e226,
b17c63a7d6, 76c41d1b0b, 579b86f1ce, a53cf13db1, 06f76bac0c, e59b0c0d23, dfa37cce8b, 6dce6b09e4,
0eba41c487, a426930123, 73d98c7fa5, a096cf3997, 1e01ff8f10, 3665bb23c2, 8ff0d5fc15, 645d45776a,
fa06e4ebd5, 2a61e3ce4c, d8733dd89f, c651177529, 3193fe76e4, b9025c57f6, 1b0d62c16e, 7a0c977eb7,
a9c04dca89, 8d76a95c86, 2a69561ec5, e1655426d6, 9634266bc6, 0c962d882d, 4bf3edf83b, c48a77c644,
26d70ab16b, f8a2f3ac93, 5186356a26, 7b826e2a4e, c175dc8f6a, aa11813cc0, 6235f49b49, fd6a110794,
bd42c459d6, aede532e63, 068ac543ad, 30e7a831a5, 276261c96d, 205f1410e4, a93d154c27, 1361879bd0,
c58cc320b2, 461350958a, 50dde0be1a, 199e1df453, 996b674840
@@ -6,3 +6,4 @@
3134e5f840c12c8f32613ce520101a047c89dcc2 # refactor(whitespace): rm temporary react fragments (#7161)
ed3f72bc75f3e3a9ae9e4d8cd38278f9c97e78b4 # refactor(whitespace): rm react fragment #7190
7b927e79c25f4ddfd18a067f489e122acd2c89de # chore(format): format files where `ruff` and `black` agree (#9339)
10  .github/workflows/deployment.yml  vendored
@@ -44,7 +44,7 @@ jobs:
fetch-tags: true
- name: Setup uv
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
version: "0.9.9"
enable-cache: false
@@ -165,7 +165,7 @@ jobs:
fetch-depth: 0
- name: Setup uv
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
version: "0.9.9"
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
@@ -307,7 +307,7 @@ jobs:
xdg-utils
- name: setup node
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6.2.0
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v6.3.0
with:
node-version: 24
package-manager-cache: false
@@ -615,6 +615,7 @@ jobs:
tags: |
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
@@ -1263,8 +1264,6 @@
latest=false
tags: |
type=raw,value=craft-latest
# TODO: Consider aligning craft-latest tags with regular backend builds (e.g., latest, edge, beta)
# to keep tagging strategy consistent across all backend images
- name: Create and push manifest
env:
@@ -1488,6 +1487,7 @@
tags: |
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
@@ -114,7 +114,7 @@ jobs:
ref: main
- name: Install the latest version of uv
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
2  .github/workflows/pr-desktop-build.yml  vendored
@@ -50,7 +50,7 @@ jobs:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f
with:
node-version: 24
cache: "npm" # zizmor: ignore[cache-poisoning]
@@ -7,6 +7,15 @@ on:
merge_group:
pull_request:
branches: [main]
paths:
- "backend/**"
- "pyproject.toml"
- "uv.lock"
- ".github/workflows/pr-external-dependency-unit-tests.yml"
- ".github/actions/setup-python-and-install-dependencies/**"
- ".github/actions/setup-playwright/**"
- "deployment/docker_compose/docker-compose.yml"
- "deployment/docker_compose/docker-compose.dev.yml"
push:
tags:
- "v*.*.*"
2  .github/workflows/pr-jest-tests.yml  vendored
@@ -28,7 +28,7 @@ jobs:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning] test-only workflow; no deploy artifacts
6  .github/workflows/pr-playwright-tests.yml  vendored
@@ -272,7 +272,7 @@ jobs:
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]
@@ -471,7 +471,7 @@ jobs:
- name: Install the latest version of uv
if: always()
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
@@ -614,7 +614,7 @@ jobs:
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]
@@ -7,6 +7,13 @@ on:
merge_group:
pull_request:
branches: [main]
paths:
- "backend/**"
- "pyproject.toml"
- "uv.lock"
- ".github/workflows/pr-python-connector-tests.yml"
- ".github/actions/setup-python-and-install-dependencies/**"
- ".github/actions/setup-playwright/**"
push:
tags:
- "v*.*.*"
2  .github/workflows/pr-python-model-tests.yml  vendored
@@ -73,7 +73,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f
- name: Build and load
uses: docker/bake-action@5be5f02ff8819ecd3092ea6b2e6261c31774f2b4 # ratchet:docker/bake-action@v6
uses: docker/bake-action@82490499d2e5613fcead7e128237ef0b0ea210f7 # ratchet:docker/bake-action@v7.0.0
env:
TAG: model-server-${{ github.run_id }}
with:
2  .github/workflows/pr-quality-checks.yml  vendored
@@ -30,7 +30,7 @@ jobs:
- name: Setup Terraform
uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0
- name: Setup node
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v6
with: # zizmor: ignore[cache-poisoning]
node-version: 22
cache: "npm"
2  .github/workflows/preview.yml  vendored
@@ -22,7 +22,7 @@ jobs:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm"
2  .github/workflows/release-cli.yml  vendored
@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
2  .github/workflows/release-devtools.yml  vendored
@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
2  .github/workflows/storybook-deploy.yml  vendored
@@ -32,7 +32,7 @@ jobs:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm"
2  .github/workflows/zizmor.yml  vendored
@@ -24,7 +24,7 @@ jobs:
persist-credentials: false
- name: Install the latest version of uv
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
64  .greptile/config.json  Normal file
@@ -0,0 +1,64 @@
{
"labels": [],
"comment": "",
"fixWithAI": true,
"hideFooter": false,
"strictness": 3,
"statusCheck": true,
"commentTypes": [
"logic",
"syntax",
"style"
],
"instructions": "",
"disabledLabels": [],
"excludeAuthors": [
"dependabot[bot]",
"renovate[bot]"
],
"ignoreKeywords": "",
"ignorePatterns": "",
"includeAuthors": [],
"summarySection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"excludeBranches": [],
"fileChangeLimit": 300,
"includeBranches": [],
"includeKeywords": "",
"triggerOnUpdates": true,
"updateExistingSummaryComment": true,
"updateSummaryOnly": false,
"issuesTableSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"statusCommentsEnabled": true,
"confidenceScoreSection": {
"included": true,
"collapsible": false
},
"sequenceDiagramSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"shouldUpdateDescription": false,
"rules": [
{
"scope": ["web/**"],
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
},
{
"scope": ["web/**"],
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
}
]
}
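The OnyxError rule above is the one most likely to trip up reviewers, so here is a minimal sketch of the call pattern it describes. The module paths and enum members come from the rule text itself; the surrounding function and the upstream call are hypothetical.

```python
# Hypothetical handler following the OnyxError rule above; `call_upstream_service`
# stands in for whatever client the real business code uses.
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError


def fetch_upstream_document(doc_id: str) -> dict:
    response = call_upstream_service(doc_id)  # hypothetical upstream call
    if response.status_code >= 400:
        # Do not raise HTTPException here; the global handler turns OnyxError
        # into a structured {"error_code": ..., "detail": ...} JSON response.
        raise OnyxError(
            OnyxErrorCode.BAD_GATEWAY,
            f"Upstream service failed for document {doc_id}",
            status_code_override=response.status_code,
        )
    return response.json()
```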
57  .greptile/files.json  Normal file
@@ -0,0 +1,57 @@
[
{
"scope": [],
"path": "contributing_guides/best_practices.md",
"description": "Best practices for contributing to the codebase"
},
{
"scope": ["web/**"],
"path": "web/AGENTS.md",
"description": "Frontend coding standards for the web directory"
},
{
"scope": ["web/**"],
"path": "web/tests/README.md",
"description": "Frontend testing guide and conventions"
},
{
"scope": ["web/**"],
"path": "web/CLAUDE.md",
"description": "Single source of truth for frontend coding standards"
},
{
"scope": ["web/**"],
"path": "web/lib/opal/README.md",
"description": "Opal component library usage guide"
},
{
"scope": ["backend/**"],
"path": "backend/tests/README.md",
"description": "Backend testing guide covering all 4 test types, fixtures, and conventions"
},
{
"scope": ["backend/onyx/connectors/**"],
"path": "backend/onyx/connectors/README.md",
"description": "Connector development guide covering design, interfaces, and required changes"
},
{
"scope": [],
"path": "CLAUDE.md",
"description": "Project instructions and coding standards"
},
{
"scope": [],
"path": "backend/alembic/README.md",
"description": "Migration guidance, including multi-tenant migration behavior"
},
{
"scope": [],
"path": "deployment/helm/charts/onyx/values-lite.yaml",
"description": "Lite deployment Helm values and service assumptions"
},
{
"scope": [],
"path": "deployment/docker_compose/docker-compose.onyx-lite.yml",
"description": "Lite deployment Docker Compose overlay and disabled service behavior"
}
]
39  .greptile/rules.md  Normal file
@@ -0,0 +1,39 @@
# Greptile Review Rules

## Type Annotations

Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code.

## Best Practices

Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly.

## TODOs

Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of `TODO(name): ...` or `TODO(1234): ...`

## Debugging Code

Remove temporary debugging code before merging to production, especially tenant-specific debugging logs.

## Hardcoded Booleans

When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant.

## Multi-tenant vs Single-tenant

Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.

## Nginx Routing — New Backend Routes

Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
- `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt)

Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`.

## Full vs Lite Deployments

Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.
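As a small illustration of the Hardcoded Booleans rule above (the names are made up, not taken from this diff): once a flag is pinned to a constant, delete the flag and the branch it guarded rather than leaving both behind.

```python
# Before: flag pinned to a constant, leaving a dead branch behind.
ENABLE_LEGACY_SCORING = False  # hypothetical flag


def score_before(chunk: str) -> float:
    if ENABLE_LEGACY_SCORING:
        return legacy_score(chunk)  # dead code once the flag is a constant
    return hybrid_score(chunk)


# After: the flag and the dead branch are removed entirely.
def score_after(chunk: str) -> float:
    return hybrid_score(chunk)
```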
12  .vscode/launch.json  vendored
@@ -117,7 +117,8 @@
"presentation": {
"group": "2"
},
"consoleTitle": "API Server Console"
"consoleTitle": "API Server Console",
"justMyCode": false
},
{
"name": "Slack Bot",
@@ -268,7 +269,8 @@
"presentation": {
"group": "2"
},
"consoleTitle": "Celery heavy Console"
"consoleTitle": "Celery heavy Console",
"justMyCode": false
},
{
"name": "Celery kg_processing",
@@ -355,7 +357,8 @@
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user_file_processing Console"
"consoleTitle": "Celery user_file_processing Console",
"justMyCode": false
},
{
"name": "Celery docfetching",
@@ -413,7 +416,8 @@
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docprocessing Console"
"consoleTitle": "Celery docprocessing Console",
"justMyCode": false
},
{
"name": "Celery beat",
279  AGENTS.md
@@ -167,284 +167,7 @@ web/

## Frontend Standards

### 1. Import Standards

**Always use absolute imports with the `@` prefix.**

**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.

```typescript
// ✅ Good
import { Button } from "@/components/ui/button";
import { useAuth } from "@/hooks/useAuth";
import { Text } from "@/refresh-components/texts/Text";

// ❌ Bad
import { Button } from "../../../components/ui/button";
import { useAuth } from "./hooks/useAuth";
```

### 2. React Component Functions

**Prefer regular functions over arrow functions for React components.**

**Reason:** Functions just become easier to read.

```typescript
// ✅ Good
function UserProfile({ userId }: UserProfileProps) {
return <div>User Profile</div>
}

// ❌ Bad
const UserProfile = ({ userId }: UserProfileProps) => {
return <div>User Profile</div>
}
```

### 3. Props Interface Extraction

**Extract prop types into their own interface definitions.**

**Reason:** Functions just become easier to read.

```typescript
// ✅ Good
interface UserCardProps {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}

function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
return <div>User Card</div>
}

// ❌ Bad
function UserCard({
user,
showActions = false,
onEdit
}: {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}) {
return <div>User Card</div>
}
```

### 4. Spacing Guidelines

**Prefer padding over margins for spacing.**

**Reason:** We want to consolidate usage to paddings instead of margins.

```typescript
// ✅ Good
<div className="p-4 space-y-2">
<div className="p-2">Content</div>
</div>

// ❌ Bad
<div className="m-4 space-y-2">
<div className="m-2">Content</div>
</div>
```

### 5. Tailwind Dark Mode

**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**

**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.

**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.

```typescript
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
<div className="bg-background-neutral-03 text-text-02">
Content
</div>

// ✅ Good - Logo icons with dark mode handling via createLogoIcon
export const GithubIcon = createLogoIcon(githubLightIcon, {
monochromatic: true, // Will apply dark:invert internally
});

export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
});

// ❌ Bad - Manual dark mode overrides
<div className="bg-white dark:bg-black text-black dark:text-white">
Content
</div>
```

### 6. Class Name Utilities

**Use the `cn` utility instead of raw string formatting for classNames.**

**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.

```typescript
import { cn } from '@/lib/utils'

// ✅ Good
<div className={cn(
'base-class',
isActive && 'active-class',
className
)}>
Content
</div>

// ❌ Bad
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
Content
</div>
```

### 7. Custom Hooks Organization

**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**

**Reason:** This is just a layout preference. Keeps code clean.

```typescript
// web/src/hooks/useUserData.ts
export function useUserData(userId: string) {
// hook implementation
}

// web/src/hooks/useLocalStorage.ts
export function useLocalStorage<T>(key: string, initialValue: T) {
// hook implementation
}
```

### 8. Icon Usage

**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**

**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.

```typescript
// ✅ Good
import SvgX from "@/icons/x";
import SvgMoreHorizontal from "@/icons/more-horizontal";

// ❌ Bad
import { User } from "lucide-react";
import { FiSearch } from "react-icons/fi";
```

**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
If you need help with this step, reach out to `raunak@onyx.app`.

### 9. Text Rendering

**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**

**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.

```typescript
// ✅ Good
import { Text } from '@/refresh-components/texts/Text'

function UserCard({ name }: { name: string }) {
return (
<Text
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
text03
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
mainAction
>
{name}
</Text>
)
}

// ❌ Bad
function UserCard({ name }: { name: string }) {
return (
<div>
<h2>{name}</h2>
<p>User details</p>
</div>
)
}
```

### 10. Component Usage

**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**

**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.

```typescript
// ✅ Good
import Button from '@/refresh-components/buttons/Button'
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
import SvgPlusCircle from '@/icons/plus-circle'

function ContactForm() {
return (
<form>
<InputTypeIn placeholder="Search..." />
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
</form>
)
}

// ❌ Bad
function ContactForm() {
return (
<form>
<input placeholder="Name" />
<textarea placeholder="Message" />
<button type="submit">Submit</button>
</form>
)
}
```

### 11. Colors

**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**

**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.

**Available color categories:**

- **Text:** `text-01` through `text-05`, `text-inverted-XX`
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
- **Actions:** `action-link-XX`, `action-danger-XX`
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.

```typescript
// ✅ Good - Use custom Onyx color classes
<div className="bg-background-neutral-01 border border-border-02" />
<div className="bg-background-tint-02 border border-border-01" />
<div className="bg-status-success-01" />
<div className="bg-action-link-01" />
<div className="bg-theme-primary-05" />

// ❌ Bad - Do NOT use standard Tailwind colors
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
<div className="bg-white border border-slate-200" />
<div className="bg-green-100 text-green-700" />
<div className="bg-blue-100 text-blue-600" />
<div className="bg-indigo-500" />
```

### 12. Data Fetching

**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**

**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).

Frontend standards for the `web/` and `desktop/` projects live in `web/AGENTS.md`.

## Database & Migrations
@@ -47,6 +47,8 @@ RUN apt-get update && \
gcc \
nano \
vim \
# Install procps so kubernetes exec sessions can use ps aux for debugging
procps \
libjemalloc2 \
&& \
rm -rf /var/lib/apt/lists/* && \
@@ -0,0 +1,35 @@
"""remove voice_provider deleted column

Revision ID: 1d78c0ca7853
Revises: a3f8b2c1d4e5
Create Date: 2026-03-26 11:30:53.883127

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "1d78c0ca7853"
down_revision = "a3f8b2c1d4e5"
branch_labels = None
depends_on = None

def upgrade() -> None:
# Hard-delete any soft-deleted rows before dropping the column
op.execute("DELETE FROM voice_provider WHERE deleted = true")
op.drop_column("voice_provider", "deleted")

def downgrade() -> None:
op.add_column(
"voice_provider",
sa.Column(
"deleted",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
@@ -0,0 +1,109 @@
"""group_permissions_phase1

Revision ID: 25a5501dc766
Revises: b728689f45b1
Create Date: 2026-03-23 11:41:25.557442

"""

from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa

from onyx.db.enums import AccountType
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission

# revision identifiers, used by Alembic.
revision = "25a5501dc766"
down_revision = "b728689f45b1"
branch_labels = None
depends_on = None

def upgrade() -> None:
# 1. Add account_type column to user table (nullable for now).
# TODO(subash): backfill account_type for existing rows and add NOT NULL.
op.add_column(
"user",
sa.Column(
"account_type",
sa.Enum(AccountType, native_enum=False),
nullable=True,
),
)

# 2. Add is_default column to user_group table
op.add_column(
"user_group",
sa.Column(
"is_default",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)

# 3. Create permission_grant table
op.create_table(
"permission_grant",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("group_id", sa.Integer(), nullable=False),
sa.Column(
"permission",
sa.Enum(Permission, native_enum=False),
nullable=False,
),
sa.Column(
"grant_source",
sa.Enum(GrantSource, native_enum=False),
nullable=False,
),
sa.Column(
"granted_by",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column(
"granted_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"is_deleted",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["group_id"],
["user_group.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["granted_by"],
["user.id"],
ondelete="SET NULL",
),
sa.UniqueConstraint(
"group_id", "permission", name="uq_permission_grant_group_permission"
),
)

# 4. Index on user__user_group(user_id) — existing composite PK
# has user_group_id as leading column; user-filtered queries need this
op.create_index(
"ix_user__user_group_user_id",
"user__user_group",
["user_id"],
)

def downgrade() -> None:
op.drop_index("ix_user__user_group_user_id", table_name="user__user_group")
op.drop_table("permission_grant")
op.drop_column("user_group", "is_default")
op.drop_column("user", "account_type")
@@ -0,0 +1,36 @@
"""add preferred_response_id and model_display_name to chat_message

Revision ID: a3f8b2c1d4e5
Create Date: 2026-03-22

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "a3f8b2c1d4e5"
down_revision = "25a5501dc766"
branch_labels = None
depends_on = None

def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column(
"preferred_response_id",
sa.Integer(),
sa.ForeignKey("chat_message.id", ondelete="SET NULL"),
nullable=True,
),
)
op.add_column(
"chat_message",
sa.Column("model_display_name", sa.String(), nullable=True),
)

def downgrade() -> None:
op.drop_column("chat_message", "model_display_name")
op.drop_column("chat_message", "preferred_response_id")
@@ -0,0 +1,26 @@
"""rename persona is_visible to is_listed and featured to is_featured

Revision ID: b728689f45b1
Revises: 689433b0d8de
Create Date: 2026-03-23 12:36:26.607305

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "b728689f45b1"
down_revision = "689433b0d8de"
branch_labels = None
depends_on = None

def upgrade() -> None:
op.alter_column("persona", "is_visible", new_column_name="is_listed")
op.alter_column("persona", "featured", new_column_name="is_featured")

def downgrade() -> None:
op.alter_column("persona", "is_listed", new_column_name="is_visible")
op.alter_column("persona", "is_featured", new_column_name="featured")
@@ -36,6 +36,56 @@ TABLES_WITH_USER_ID = [
]

def _dedupe_null_notifications(connection: sa.Connection) -> None:
# Multiple NULL-owned notifications can exist because the unique index treats
# NULL user_id values as distinct. Before migrating them to the anonymous
# user, collapse duplicates and remove rows that would conflict with an
# already-existing anonymous notification.
result = connection.execute(
sa.text(
"""
WITH ranked_null_notifications AS (
SELECT
id,
ROW_NUMBER() OVER (
PARTITION BY notif_type, COALESCE(additional_data, '{}'::jsonb)
ORDER BY first_shown DESC, last_shown DESC, id DESC
) AS row_num
FROM notification
WHERE user_id IS NULL
)
DELETE FROM notification
WHERE id IN (
SELECT id
FROM ranked_null_notifications
WHERE row_num > 1
)
"""
)
)
if result.rowcount > 0:
print(f"Deleted {result.rowcount} duplicate NULL-owned notifications")

result = connection.execute(
sa.text(
"""
DELETE FROM notification AS null_owned
USING notification AS anonymous_owned
WHERE null_owned.user_id IS NULL
AND anonymous_owned.user_id = :user_id
AND null_owned.notif_type = anonymous_owned.notif_type
AND COALESCE(null_owned.additional_data, '{}'::jsonb) =
COALESCE(anonymous_owned.additional_data, '{}'::jsonb)
"""
),
{"user_id": ANONYMOUS_USER_UUID},
)
if result.rowcount > 0:
print(
f"Deleted {result.rowcount} NULL-owned notifications that conflict with existing anonymous-owned notifications"
)

def upgrade() -> None:
"""
Create the anonymous user for anonymous access feature.
@@ -65,7 +115,12 @@ def upgrade() -> None:

# Migrate any remaining user_id=NULL records to anonymous user
for table in TABLES_WITH_USER_ID:
try:
# Dedup notifications outside the savepoint so deletions persist
# even if the subsequent UPDATE rolls back
if table == "notification":
_dedupe_null_notifications(connection)

with connection.begin_nested():
# Exclude public credential (id=0) which must remain user_id=NULL
# Exclude builtin tools (in_code_tool_id IS NOT NULL) which must remain user_id=NULL
# Exclude builtin personas (builtin_persona=True) which must remain user_id=NULL
@@ -80,6 +135,7 @@ def upgrade() -> None:
condition = "user_id IS NULL AND is_public = false"
else:
condition = "user_id IS NULL"

result = connection.execute(
sa.text(
f"""
@@ -92,19 +148,19 @@ def upgrade() -> None:
)
if result.rowcount > 0:
print(f"Updated {result.rowcount} rows in {table} to anonymous user")
except Exception as e:
print(f"Skipping {table}: {e}")

def downgrade() -> None:
"""
Set anonymous user's records back to NULL and delete the anonymous user.

Note: Duplicate NULL-owned notifications removed during upgrade are not restored.
"""
connection = op.get_bind()

# Set records back to NULL
for table in TABLES_WITH_USER_ID:
try:
with connection.begin_nested():
connection.execute(
sa.text(
f"""
@@ -115,8 +171,6 @@ def downgrade() -> None:
),
{"user_id": ANONYMOUS_USER_UUID},
)
except Exception:
pass

# Delete the anonymous user
connection.execute(
@@ -473,6 +473,8 @@ def connector_permission_sync_generator_task(
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_connector=True,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(
@@ -25,13 +25,13 @@ from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX

# Default number of pre-provisioned tenants to maintain
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
# Maximum tenants to provision in a single task run.
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
_MAX_TENANTS_PER_RUN = 5

# Soft time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
# Hard time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes

@shared_task(
@@ -58,7 +58,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_check: RedisLock = r.lock(
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
)

# These tasks should never overlap
@@ -74,9 +74,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
num_available_tenants = db_session.query(AvailableTenant).count()

# Get the target number of available tenants
num_minimum_available_tenants = getattr(
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
)
num_minimum_available_tenants = TARGET_AVAILABLE_TENANTS

# Calculate how many new tenants we need to provision
if num_available_tenants < num_minimum_available_tenants:
@@ -90,22 +88,46 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
f"To provision: {tenants_to_provision}"
)

# just provision one tenant each time we run this ... increase if needed.
if tenants_to_provision > 0:
pre_provision_tenant()
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
if batch_size < tenants_to_provision:
task_logger.info(
f"Capping batch to {batch_size} "
f"(need {tenants_to_provision}, will catch up next cycle)"
)

provisioned = 0
for i in range(batch_size):
task_logger.info(f"Provisioning tenant {i + 1}/{batch_size}")
try:
if pre_provision_tenant():
provisioned += 1
except Exception:
task_logger.exception(
f"Failed to provision tenant {i + 1}/{batch_size}, "
"continuing with remaining tenants"
)

task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")

except Exception:
task_logger.exception("Error in check_available_tenants task")

finally:
lock_check.release()
try:
lock_check.release()
except Exception:
task_logger.warning(
"Could not release check lock (likely expired), continuing"
)

def pre_provision_tenant() -> None:
def pre_provision_tenant() -> bool:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.
This function fully sets up the tenant with all necessary configurations,
so it's ready to be assigned to a user immediately.

Returns True if a tenant was successfully provisioned, False otherwise.
"""
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
# rather than inside this function
@@ -113,15 +135,15 @@ def pre_provision_tenant() -> None:
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_provision: RedisLock = r.lock(
OnyxRedisLocks.CLOUD_PRE_PROVISION_TENANT_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
)

# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
if not lock_provision.acquire(blocking=False):
task_logger.debug(
"Skipping pre_provision_tenant task because it is already running"
task_logger.warning(
"Skipping pre_provision_tenant — could not acquire provision lock"
)
return
return False

tenant_id: str | None = None
try:
@@ -161,6 +183,7 @@ def pre_provision_tenant() -> None:
db_session.add(new_tenant)
db_session.commit()
task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
return True
except Exception:
db_session.rollback()
task_logger.error(
@@ -184,5 +207,11 @@ def pre_provision_tenant() -> None:
asyncio.run(rollback_tenant_provisioning(tenant_id))
except Exception:
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
return False
finally:
lock_provision.release()
try:
lock_provision.release()
except Exception:
task_logger.warning(
"Could not release provision lock (likely expired), continuing"
)
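A quick sanity check on the new limits in this hunk, using the per-tenant estimate from the comments (the exact buffer is not stated in the diff, so this is only an approximation):

```python
# Worst-case batch implied by the comments above: up to 5 tenants at ~90 s each.
MAX_TENANTS_PER_RUN = 5
SECONDS_PER_TENANT = 90  # "~80s" to "~90s" per tenant (alembic migrations)

worst_case_seconds = MAX_TENANTS_PER_RUN * SECONDS_PER_TENANT  # 450 s = 7.5 min

SOFT_TIME_LIMIT = 60 * 10  # new soft limit: 10 minutes
HARD_TIME_LIMIT = 60 * 15  # new hard limit: 15 minutes

# The worst case fits under the soft limit, with the hard limit as extra headroom.
assert worst_case_seconds < SOFT_TIME_LIMIT < HARD_TIME_LIMIT
```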
@@ -115,8 +115,14 @@ def fetch_user_group_token_rate_limits_for_user(
ordered: bool = True,
get_editable: bool = True,
) -> Sequence[TokenRateLimit]:
stmt = select(TokenRateLimit)
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
stmt = (
select(TokenRateLimit)
.join(
TokenRateLimit__UserGroup,
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
)
.where(TokenRateLimit__UserGroup.user_group_id == group_id)
)
stmt = _add_user_filters(stmt, user, get_editable)

if enabled_only:
@@ -800,6 +800,33 @@ def update_user_group(
return db_user_group

def rename_user_group(
db_session: Session,
user_group_id: int,
new_name: str,
) -> UserGroup:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
if db_user_group is None:
raise ValueError(f"UserGroup with id '{user_group_id}' not found")

_check_user_group_is_modifiable(db_user_group)

db_user_group.name = new_name
db_user_group.time_last_modified_by_user = func.now()

# CC pair documents in Vespa contain the group name, so we need to
# trigger a sync to update them with the new name.
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
if not DISABLE_VECTOR_DB:
db_user_group.is_up_to_date = False

db_session.commit()
return db_user_group

def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
@@ -250,20 +250,24 @@ def _get_sharepoint_list_item_id(drive_item: DriveItem) -> str | None:
raise e

def _is_public_item(drive_item: DriveItem) -> bool:
is_public = False
def _is_public_item(
drive_item: DriveItem,
treat_sharing_link_as_public: bool = False,
) -> bool:
if not treat_sharing_link_as_public:
return False

try:
permissions = sleep_and_retry(
drive_item.permissions.get_all(page_loaded=lambda _: None), "is_public_item"
)
for permission in permissions:
if permission.link and (
permission.link.scope == "anonymous"
or permission.link.scope == "organization"
if permission.link and permission.link.scope in (
"anonymous",
"organization",
):
is_public = True
break
return is_public
return True
return False
except Exception as e:
logger.error(f"Failed to check if item {drive_item.id} is public: {e}")
return False
@@ -504,6 +508,7 @@ def get_external_access_from_sharepoint(
drive_item: DriveItem | None,
site_page: dict[str, Any] | None,
add_prefix: bool = False,
treat_sharing_link_as_public: bool = False,
) -> ExternalAccess:
"""
Get external access information from SharePoint.
@@ -563,8 +568,7 @@
)

if drive_item and drive_name:
# Here we check if the item have have any public links, if so we return early
is_public = _is_public_item(drive_item)
is_public = _is_public_item(drive_item, treat_sharing_link_as_public)
if is_public:
logger.info(f"Item {drive_item.id} is public")
return ExternalAccess(
@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import HierarchyNode
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call
@@ -105,9 +106,11 @@ def _get_slack_document_access(
slack_connector: SlackConnector,
channel_permissions: dict[str, ExternalAccess], # noqa: ARG001
callback: IndexingHeartbeatInterface | None,
indexing_start: SecondsSinceUnixEpoch | None = None,
) -> Generator[DocExternalAccess, None, None]:
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
callback=callback
callback=callback,
start=indexing_start,
)

for doc_metadata_batch in slim_doc_generator:
@@ -180,9 +183,15 @@ def slack_doc_sync(

slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.set_credentials_provider(provider)
indexing_start_ts: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)

yield from _get_slack_document_access(
slack_connector,
slack_connector=slack_connector,
channel_permissions=channel_permissions,
callback=callback,
indexing_start=indexing_start_ts,
)
@@ -6,6 +6,7 @@ from onyx.access.models import ElementExternalAccess
from onyx.access.models import ExternalAccess
from onyx.access.models import NodeExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
@@ -40,10 +41,19 @@ def generic_doc_sync(

logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")

indexing_start: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)

newly_fetched_doc_ids: set[str] = set()

logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(
start=indexing_start,
callback=callback,
):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")

if callback:
@@ -44,19 +44,21 @@ def _run_single_search(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
num_hits: int | None = None,
|
||||
hybrid_alpha: float | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Execute a single search query and return chunks."""
|
||||
chunk_search_request = ChunkSearchRequest(
|
||||
query=query,
|
||||
user_selected_filters=filters,
|
||||
limit=num_hits,
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
)
|
||||
|
||||
return search_pipeline(
|
||||
chunk_search_request=chunk_search_request,
|
||||
document_index=document_index,
|
||||
user=user,
|
||||
persona=None, # No persona for direct search
|
||||
persona_search_info=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -74,7 +76,7 @@ def stream_search_query(
|
||||
Core search function that yields streaming packets.
|
||||
Used by both streaming and non-streaming endpoints.
|
||||
"""
|
||||
# Get document index
|
||||
# Get document index.
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None, db_session)
|
||||
@@ -119,6 +121,7 @@ def stream_search_query(
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
num_hits=request.num_hits,
|
||||
hybrid_alpha=request.hybrid_alpha,
|
||||
)
|
||||
else:
# Multiple queries - run in parallel and merge with RRF
@@ -133,6 +136,7 @@ def stream_search_query(
user,
db_session,
request.num_hits,
request.hybrid_alpha,
),
)
for query in all_executed_queries

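The comment above refers to merging the per-query result lists with RRF (Reciprocal Rank Fusion). A minimal sketch of the standard RRF formula, score(d) = sum over lists of 1 / (k + rank); this is an illustration of the technique by that name, not necessarily the exact merge implementation used in the repo, and the k=60 constant and doc-id strings are assumptions.

# Minimal Reciprocal Rank Fusion sketch (illustrative only).
from collections import defaultdict

def rrf_merge(ranked_lists: list[list[str]], k: int = 60) -> list[str]:
    scores: dict[str, float] = defaultdict(float)
    for ranking in ranked_lists:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    return sorted(scores, key=lambda d: scores[d], reverse=True)

# Two query expansions returning overlapping doc ids:
print(rrf_merge([["a", "b", "c"], ["b", "d", "a"]]))  # ['b', 'a', 'd', 'c']
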
@@ -157,7 +157,11 @@ def fetch_logo_helper(db_session: Session) -> Response: # noqa: ARG001
|
||||
detail="No logo file found",
|
||||
)
|
||||
else:
|
||||
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
|
||||
return Response(
|
||||
content=onyx_file.data,
|
||||
media_type=onyx_file.mime_type,
|
||||
headers={"Cache-Control": "no-cache"},
|
||||
)
|
||||
|
||||
|
||||
def fetch_logotype_helper(db_session: Session) -> Response: # noqa: ARG001
|
||||
|
||||
@@ -27,15 +27,17 @@ class SearchFlowClassificationResponse(BaseModel):
is_search_flow: bool


# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
# NOTE: This model is used for the core flow of the Onyx application, any
# changes to it should be reviewed and approved by an experienced team member.
# It is very important to 1. avoid bloat and 2. that this remains backwards
# compatible across versions.
class SendSearchQueryRequest(BaseModel):
search_query: str
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 30

hybrid_alpha: float | None = None
include_content: bool = False
stream: bool = False


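For intuition, an illustrative request body for the /send-search-message endpoint that appears later in this diff, built from the SendSearchQueryRequest fields above. The values are placeholders, and the shape of the filters field is not shown here since BaseFilters is defined elsewhere.

payload = {
    "search_query": "quarterly onboarding checklist",
    "num_hits": 30,
    "run_query_expansion": True,
    "hybrid_alpha": 0.0,  # 0.0 => pure keyword search, 1.0 => pure vector search
    "include_content": False,
    "stream": True,
}
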
@@ -20,6 +20,7 @@ from ee.onyx.server.query_and_chat.models import SearchQueryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
@@ -67,8 +68,10 @@ def search_flow_classification(
|
||||
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
|
||||
|
||||
|
||||
# NOTE: This endpoint is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
# NOTE: This endpoint is used for the core flow of the Onyx application, any
|
||||
# changes to it should be reviewed and approved by an experienced team member.
|
||||
# It is very important to 1. avoid bloat and 2. that this remains backwards
|
||||
# compatible across versions.
|
||||
@router.post(
|
||||
"/send-search-message",
|
||||
response_model=None,
|
||||
@@ -80,13 +83,19 @@ def handle_send_search_message(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | SearchFullResponse:
|
||||
"""
|
||||
Execute a search query with optional streaming.
|
||||
Executes a search query with optional streaming.
|
||||
|
||||
When stream=True: Returns StreamingResponse with SSE
|
||||
When stream=False: Returns SearchFullResponse
|
||||
If hybrid_alpha is unset and ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
|
||||
is True, executes pure keyword search.
|
||||
|
||||
Returns:
|
||||
StreamingResponse with SSE if stream=True, otherwise SearchFullResponse.
|
||||
"""
|
||||
logger.debug(f"Received search query: {request.search_query}")
|
||||
|
||||
if request.hybrid_alpha is None and ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH:
|
||||
request.hybrid_alpha = 0.0
|
||||
|
||||
# Non-streaming path
|
||||
if not request.stream:
|
||||
try:
|
||||
|
||||
@@ -178,7 +178,7 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=persona.datetime_aware,
|
||||
featured=persona.featured,
|
||||
is_featured=persona.is_featured,
|
||||
commit=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.persona import update_persona_access
|
||||
from ee.onyx.db.user_group import add_users_to_user_group
|
||||
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
|
||||
from ee.onyx.db.user_group import fetch_user_group
|
||||
@@ -11,13 +12,16 @@ from ee.onyx.db.user_group import fetch_user_groups
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.db.user_group import insert_user_group
|
||||
from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from ee.onyx.db.user_group import rename_user_group
|
||||
from ee.onyx.db.user_group import update_user_curator_relationship
|
||||
from ee.onyx.db.user_group import update_user_group
|
||||
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
|
||||
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import UpdateGroupAgentsRequest
|
||||
from ee.onyx.server.user_group.models import UserGroup
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupRename
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
@@ -27,6 +31,9 @@ from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -87,6 +94,32 @@ def create_user_group(
|
||||
return UserGroup.from_model(db_user_group)
|
||||
|
||||
|
||||
@router.patch("/admin/user-group/rename")
|
||||
def rename_user_group_endpoint(
|
||||
rename_request: UserGroupRename,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
return UserGroup.from_model(
|
||||
rename_user_group(
|
||||
db_session=db_session,
|
||||
user_group_id=rename_request.id,
|
||||
new_name=rename_request.name,
|
||||
)
|
||||
)
|
||||
except IntegrityError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"User group with name '{rename_request.name}' already exists.",
|
||||
)
|
||||
except ValueError as e:
|
||||
msg = str(e)
|
||||
if "not found" in msg.lower():
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, msg)
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, msg)
|
||||
|
||||
|
||||
@router.patch("/admin/user-group/{user_group_id}")
|
||||
def patch_user_group(
|
||||
user_group_id: int,
|
||||
@@ -161,3 +194,38 @@ def delete_user_group(
|
||||
user_group = fetch_user_group(db_session, user_group_id)
|
||||
if user_group:
|
||||
db_delete_user_group(db_session, user_group)
|
||||
|
||||
|
||||
@router.patch("/admin/user-group/{user_group_id}/agents")
|
||||
def update_group_agents(
|
||||
user_group_id: int,
|
||||
request: UpdateGroupAgentsRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
for agent_id in request.added_agent_ids:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=agent_id, user=user, db_session=db_session
|
||||
)
|
||||
current_group_ids = [g.id for g in persona.groups]
|
||||
if user_group_id not in current_group_ids:
|
||||
update_persona_access(
|
||||
persona_id=agent_id,
|
||||
creator_user_id=user.id,
|
||||
db_session=db_session,
|
||||
group_ids=current_group_ids + [user_group_id],
|
||||
)
|
||||
|
||||
for agent_id in request.removed_agent_ids:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=agent_id, user=user, db_session=db_session
|
||||
)
|
||||
current_group_ids = [g.id for g in persona.groups]
|
||||
update_persona_access(
|
||||
persona_id=agent_id,
|
||||
creator_user_id=user.id,
|
||||
db_session=db_session,
|
||||
group_ids=[gid for gid in current_group_ids if gid != user_group_id],
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -104,6 +104,16 @@ class AddUsersToUserGroupRequest(BaseModel):
|
||||
user_ids: list[UUID]
|
||||
|
||||
|
||||
class UserGroupRename(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class SetCuratorRequest(BaseModel):
|
||||
user_id: UUID
|
||||
is_curator: bool
|
||||
|
||||
|
||||
class UpdateGroupAgentsRequest(BaseModel):
|
||||
added_agent_ids: list[int]
|
||||
removed_agent_ids: list[int]
|
||||
|
||||
@@ -80,15 +80,45 @@ def capture_and_sync_with_alternate_posthog(
logger.error(f"Error identifying cloud posthog user: {e}")


def alias_user(distinct_id: str, anonymous_id: str) -> None:
"""Link an anonymous distinct_id to an identified user, merging person profiles.

No-ops when the IDs match (e.g. returning users whose PostHog cookie
already contains their identified user ID).
"""
if not posthog or anonymous_id == distinct_id:
return

try:
posthog.alias(previous_id=anonymous_id, distinct_id=distinct_id)
posthog.flush()
except Exception as e:
logger.error(f"Error aliasing PostHog user: {e}")


def get_anon_id_from_request(request: Any) -> str | None:
"""Extract the anonymous distinct_id from the app PostHog cookie on a request."""
if not POSTHOG_API_KEY:
return None

cookie_name = f"ph_{POSTHOG_API_KEY}_posthog"
if (cookie_value := request.cookies.get(cookie_name)) and (
parsed := parse_posthog_cookie(cookie_value)
):
return parsed.get("distinct_id")

return None

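As the code above shows, the browser-side PostHog SDK stores a URL-encoded JSON blob in a cookie named ph_<project_api_key>_posthog, and the distinct_id is read out of it. A small sketch of that decode step; the project key and cookie value below are fabricated for illustration, and parse_posthog_cookie in the diff additionally checks that distinct_id is a non-empty string.

# Sketch of decoding a PostHog browser cookie (illustrative values).
import json
from urllib.parse import unquote

posthog_api_key = "phc_example"  # placeholder project key
cookie_name = f"ph_{posthog_api_key}_posthog"
raw_cookie = "%7B%22distinct_id%22%3A%22anon-123%22%7D"  # {"distinct_id":"anon-123"}

cookie_data = json.loads(unquote(raw_cookie))
distinct_id = cookie_data.get("distinct_id")
print(cookie_name, distinct_id)  # ph_phc_example_posthog anon-123
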
|
||||
def get_marketing_posthog_cookie_name() -> str | None:
|
||||
if not MARKETING_POSTHOG_API_KEY:
|
||||
return None
|
||||
return f"onyx_custom_ph_{MARKETING_POSTHOG_API_KEY}_posthog"
|
||||
|
||||
|
||||
def parse_marketing_cookie(cookie_value: str) -> dict[str, Any] | None:
|
||||
def parse_posthog_cookie(cookie_value: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Parse the URL-encoded JSON marketing cookie.
|
||||
Parse a URL-encoded JSON PostHog cookie
|
||||
|
||||
Expected format (URL-encoded):
|
||||
{"distinct_id":"...", "featureFlags":{"landing_page_variant":"..."}, ...}
|
||||
@@ -102,7 +132,7 @@ def parse_marketing_cookie(cookie_value: str) -> dict[str, Any] | None:
|
||||
cookie_data = json.loads(decoded_cookie)
|
||||
|
||||
distinct_id = cookie_data.get("distinct_id")
|
||||
if not distinct_id:
|
||||
if not distinct_id or not isinstance(distinct_id, str):
|
||||
return None
|
||||
|
||||
return cookie_data
|
||||
|
||||
@@ -135,6 +135,8 @@ from onyx.redis.redis_pool import retrieve_ws_token_data
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_alias
|
||||
from onyx.utils.telemetry import mt_cloud_get_anon_id
|
||||
from onyx.utils.telemetry import mt_cloud_identify
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
@@ -251,18 +253,12 @@ def verify_email_is_invited(email: str) -> None:
|
||||
whitelist = get_invited_users()
|
||||
|
||||
if not email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Email must be specified"},
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email must be specified")
|
||||
|
||||
try:
|
||||
email_info = validate_email(email, check_deliverability=False)
|
||||
except EmailUndeliverableError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Email is not valid"},
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email is not valid")
|
||||
|
||||
for email_whitelist in whitelist:
|
||||
try:
|
||||
@@ -279,12 +275,9 @@ def verify_email_is_invited(email: str) -> None:
|
||||
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
|
||||
return
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"code": REGISTER_INVITE_ONLY_CODE,
|
||||
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
|
||||
},
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.UNAUTHORIZED,
|
||||
"This workspace is invite-only. Please ask your admin to invite you.",
|
||||
)
|
||||
|
||||
|
||||
@@ -294,48 +287,47 @@ def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
verify_email_is_invited(email)
|
||||
|
||||
|
||||
def verify_email_domain(email: str) -> None:
|
||||
def verify_email_domain(email: str, *, is_registration: bool = False) -> None:
|
||||
if email.count("@") != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email is not valid",
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email is not valid")
|
||||
|
||||
local_part, domain = email.split("@")
|
||||
domain = domain.lower()
|
||||
local_part = local_part.lower()
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
# Normalize googlemail.com to gmail.com (they deliver to the same inbox)
|
||||
if domain == "googlemail.com":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Please use @gmail.com instead of @googlemail.com.",
|
||||
)
|
||||
|
||||
# Only block dotted Gmail on new signups — existing users must still be
|
||||
# able to sign in with the address they originally registered with.
|
||||
if is_registration and domain == "gmail.com" and "." in local_part:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Gmail addresses with '.' are not allowed. Please use your base email address.",
|
||||
)
|
||||
|
||||
if "+" in local_part and domain != "onyx.app":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
|
||||
},
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Email addresses with '+' are not allowed. Please use your base email address.",
|
||||
)
|
||||
|
||||
# Check if email uses a disposable/temporary domain
|
||||
if is_disposable_email(email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
|
||||
},
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Disposable email addresses are not allowed. Please use a permanent email address.",
|
||||
)
|
||||
|
||||
# Check domain whitelist if configured
|
||||
if VALID_EMAIL_DOMAINS:
|
||||
if domain not in VALID_EMAIL_DOMAINS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email domain is not valid",
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email domain is not valid")
|
||||
|
||||
|
||||
def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
|
||||
@@ -351,7 +343,7 @@ def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
|
||||
)(db_session, seats_needed=seats_needed)
|
||||
|
||||
if result is not None and not result.available:
|
||||
raise HTTPException(status_code=402, detail=result.error_message)
|
||||
raise OnyxError(OnyxErrorCode.SEAT_LIMIT_EXCEEDED, result.error_message)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
@@ -404,10 +396,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
captcha_token or "", expected_action="signup"
|
||||
)
|
||||
except CaptchaVerificationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": str(e)},
|
||||
)
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
# We verify the password here to make sure it's valid before we proceed
|
||||
await self.validate_password(
|
||||
@@ -417,13 +406,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Check for disposable emails BEFORE provisioning tenant
|
||||
# This prevents creating tenants for throwaway email addresses
|
||||
try:
|
||||
verify_email_domain(user_create.email)
|
||||
except HTTPException as e:
|
||||
verify_email_domain(user_create.email, is_registration=True)
|
||||
except OnyxError as e:
|
||||
# Log blocked disposable email attempts
|
||||
if (
|
||||
e.status_code == status.HTTP_400_BAD_REQUEST
|
||||
and "Disposable email" in str(e.detail)
|
||||
):
|
||||
if "Disposable email" in e.detail:
|
||||
domain = (
|
||||
user_create.email.split("@")[-1]
|
||||
if "@" in user_create.email
|
||||
@@ -567,9 +553,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
result = await db_session.execute(
|
||||
select(Persona.id)
|
||||
.where(
|
||||
Persona.featured.is_(True),
|
||||
Persona.is_featured.is_(True),
|
||||
Persona.is_public.is_(True),
|
||||
Persona.is_visible.is_(True),
|
||||
Persona.is_listed.is_(True),
|
||||
Persona.deleted.is_(False),
|
||||
)
|
||||
.order_by(
|
||||
@@ -697,6 +683,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
raise exceptions.UserNotExists()
|
||||
|
||||
except exceptions.UserNotExists:
|
||||
verify_email_domain(account_email, is_registration=True)
|
||||
|
||||
# Check seat availability before creating (single-tenant only)
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
@@ -795,6 +783,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
logger.exception("Error deleting anonymous user cookie")
|
||||
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
# Link the anonymous PostHog session to the identified user so that
|
||||
# pre-login session recordings and events merge into one person profile.
|
||||
if anon_id := mt_cloud_get_anon_id(request):
|
||||
mt_cloud_alias(distinct_id=str(user.id), anonymous_id=anon_id)
|
||||
|
||||
mt_cloud_identify(
|
||||
distinct_id=str(user.id),
|
||||
properties={"email": user.email, "tenant_id": tenant_id},
|
||||
@@ -818,6 +812,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_count = await get_user_count()
|
||||
logger.debug(f"Current tenant user count: {user_count}")
|
||||
|
||||
# Link the anonymous PostHog session to the identified user so
|
||||
# that pre-signup session recordings merge into one person profile.
|
||||
if anon_id := mt_cloud_get_anon_id(request):
|
||||
mt_cloud_alias(distinct_id=str(user.id), anonymous_id=anon_id)
|
||||
|
||||
# Ensure a PostHog person profile exists for this user.
|
||||
mt_cloud_identify(
|
||||
distinct_id=str(user.id),
|
||||
@@ -846,9 +845,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
attribute="get_marketing_posthog_cookie_name",
|
||||
noop_return_value=None,
|
||||
)
|
||||
parse_marketing_cookie = fetch_ee_implementation_or_noop(
|
||||
parse_posthog_cookie = fetch_ee_implementation_or_noop(
|
||||
module="onyx.utils.posthog_client",
|
||||
attribute="parse_marketing_cookie",
|
||||
attribute="parse_posthog_cookie",
|
||||
noop_return_value=None,
|
||||
)
|
||||
capture_and_sync_with_alternate_posthog = fetch_ee_implementation_or_noop(
|
||||
@@ -862,7 +861,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
and user_count is not None
|
||||
and (marketing_cookie_name := get_marketing_posthog_cookie_name())
|
||||
and (marketing_cookie_value := request.cookies.get(marketing_cookie_name))
|
||||
and (parsed_cookie := parse_marketing_cookie(marketing_cookie_value))
|
||||
and (parsed_cookie := parse_posthog_cookie(marketing_cookie_value))
|
||||
):
|
||||
marketing_anonymous_id = parsed_cookie["distinct_id"]
|
||||
|
||||
|
||||
@@ -13,6 +13,14 @@ from celery.signals import worker_shutdown
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -34,6 +42,8 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
on_indexing_task_prerun(task_id, task, kwargs)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -48,6 +58,36 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
on_indexing_task_postrun(task_id, task, kwargs, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_retry signal doesn't pass task_id in kwargs; get it from
|
||||
# the sender (the task instance) via sender.request.id.
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_rejected sends the Consumer as sender, not the task instance.
|
||||
# The task name must be extracted from the Celery message headers.
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -76,6 +116,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("docfetching")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,14 @@ from celery.signals import worker_shutdown
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -35,6 +43,8 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
on_indexing_task_prerun(task_id, task, kwargs)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -49,6 +59,36 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
on_indexing_task_postrun(task_id, task, kwargs, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_retry signal doesn't pass task_id in kwargs; get it from
|
||||
# the sender (the task instance) via sender.request.id.
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_rejected sends the Consumer as sender, not the task instance.
|
||||
# The task name must be extracted from the Celery message headers.
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -82,6 +122,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("docprocessing")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@@ -90,6 +131,12 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
# Note: worker_process_init only fires in prefork pool mode. Docprocessing uses
|
||||
# worker_pool="threads" (see configs/docprocessing.py), so this handler is
|
||||
# effectively a no-op in normal operation. It remains as a safety net in case
|
||||
# the pool type is ever changed to prefork. Prometheus metrics are safe in
|
||||
# thread-pool mode since all threads share the same process memory and can
|
||||
# update the same Counter/Gauge/Histogram objects directly.
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
@@ -54,8 +54,14 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
# Set by on_worker_init so on_worker_ready knows whether to start the server.
|
||||
_prometheus_collectors_ok: bool = False
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
global _prometheus_collectors_ok
|
||||
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
@@ -65,6 +71,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
_prometheus_collectors_ok = _setup_prometheus_collectors(sender)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
@@ -72,8 +80,37 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
def _setup_prometheus_collectors(sender: Any) -> bool:
|
||||
"""Register Prometheus collectors that need Redis/DB access.
|
||||
|
||||
Passes the Celery app so the queue depth collector can obtain a fresh
|
||||
broker Redis client on each scrape (rather than holding a stale reference).
|
||||
|
||||
Returns True if registration succeeded, False otherwise.
|
||||
"""
|
||||
try:
|
||||
from onyx.server.metrics.indexing_pipeline_setup import (
|
||||
setup_indexing_pipeline_metrics,
|
||||
)
|
||||
|
||||
setup_indexing_pipeline_metrics(sender.app)
|
||||
logger.info("Prometheus indexing pipeline collectors registered")
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to register Prometheus indexing pipeline collectors")
|
||||
return False
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
if _prometheus_collectors_ok:
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
|
||||
start_metrics_server("monitoring")
|
||||
else:
|
||||
logger.warning(
|
||||
"Skipping Prometheus metrics server — collector registration failed"
|
||||
)
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -9,12 +9,12 @@ from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
@@ -362,7 +362,7 @@ if not MULTI_TENANT:
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
if not MULTI_TENANT and HOOK_ENABLED:
|
||||
if HOOKS_AVAILABLE:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
|
||||
@@ -30,6 +30,8 @@ from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.utils import plaintext_file_name_for_id
|
||||
from onyx.file_store.utils import store_plaintext
|
||||
from onyx.kg.models import KGException
|
||||
from onyx.kg.setup.kg_default_entity_definitions import (
|
||||
populate_missing_default_entity_types__commit,
|
||||
@@ -289,6 +291,33 @@ def process_kg_commands(
|
||||
raise KGException("KG setup done")
|
||||
|
||||
|
||||
def _get_or_extract_plaintext(
file_id: str,
extract_fn: Callable[[], str],
) -> str:
"""Load cached plaintext for a file, or extract and store it.

Tries to read pre-stored plaintext from the file store. On a miss,
calls extract_fn to produce the text, then stores the result so
future calls skip the expensive extraction.
"""
file_store = get_default_file_store()
plaintext_key = plaintext_file_name_for_id(file_id)

# Try cached plaintext first.
try:
plaintext_io = file_store.read_file(plaintext_key, mode="b")
return plaintext_io.read().decode("utf-8")
except Exception:
logger.exception(f"Error when reading file, id={file_id}")

# Cache miss — extract and store.
content_text = extract_fn()
if content_text:
store_plaintext(file_id, content_text)
return content_text


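The helper above is a read-through cache over the plaintext file store. A minimal usage sketch mirroring the load_chat_file change that follows, where the expensive extraction is wrapped in a zero-argument callable; the cache key and extraction body here are placeholders.

# Usage sketch of the read-through plaintext cache (placeholder values).
def _extract() -> str:
    # stand-in for the expensive extract_file_text(...) call
    return "full plaintext of the uploaded file"

cache_key = "user_file_42"  # hypothetical user_file_id used as the cache key
content_text = _get_or_extract_plaintext(cache_key, _extract)
# First call extracts and stores; later calls read the stored plaintext instead.
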
@log_function_time(print_only=True)
|
||||
def load_chat_file(
|
||||
file_descriptor: FileDescriptor, db_session: Session
|
||||
@@ -303,12 +332,23 @@ def load_chat_file(
|
||||
file_type = ChatFileType(file_descriptor["type"])
|
||||
|
||||
if file_type.is_text_file():
|
||||
try:
|
||||
content_text = extract_file_text(
|
||||
file_id = file_descriptor["id"]
|
||||
|
||||
def _extract() -> str:
|
||||
return extract_file_text(
|
||||
file=file_io,
|
||||
file_name=file_descriptor.get("name") or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
|
||||
# Use the user_file_id as cache key when available (matches what
|
||||
# the celery indexing worker stores), otherwise fall back to the
|
||||
# file store id (covers code-interpreter-generated files, etc.).
|
||||
user_file_id_str = file_descriptor.get("user_file_id")
|
||||
cache_key = user_file_id_str or file_id
|
||||
|
||||
try:
|
||||
content_text = _get_or_extract_plaintext(cache_key, _extract)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"
|
||||
|
||||
@@ -8,6 +8,7 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
@@ -35,7 +36,13 @@ class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
|
||||
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
|
||||
AnswerStreamPart = (
|
||||
Packet
|
||||
| MessageResponseIDInfo
|
||||
| MultiModelMessageResponseIDInfo
|
||||
| StreamingError
|
||||
| CreateChatSessionID
|
||||
)
|
||||
|
||||
AnswerStream = Iterator[AnswerStreamPart]
|
||||
|
||||
@@ -177,8 +184,8 @@ class ExtractedContextFiles(BaseModel):
|
||||
class SearchParams(BaseModel):
|
||||
"""Resolved search filter IDs and search-tool usage for a chat turn."""
|
||||
|
||||
search_project_id: int | None
|
||||
search_persona_id: int | None
|
||||
project_id_filter: int | None
|
||||
persona_id_filter: int | None
|
||||
search_usage: SearchToolUsage
|
||||
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -68,11 +69,19 @@ from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import verify_user_files
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.hooks.points.query_processing import QueryProcessingPayload
|
||||
from onyx.hooks.points.query_processing import QueryProcessingResponse
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -399,13 +408,13 @@ def determine_search_params(
|
||||
"""
|
||||
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
|
||||
|
||||
search_project_id: int | None = None
|
||||
search_persona_id: int | None = None
|
||||
project_id_filter: int | None = None
|
||||
persona_id_filter: int | None = None
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
if is_custom_persona:
|
||||
search_persona_id = persona_id
|
||||
persona_id_filter = persona_id
|
||||
else:
|
||||
search_project_id = project_id
|
||||
project_id_filter = project_id
|
||||
|
||||
search_usage = SearchToolUsage.AUTO
|
||||
if not is_custom_persona and project_id:
|
||||
@@ -418,12 +427,34 @@ def determine_search_params(
|
||||
search_usage = SearchToolUsage.DISABLED
|
||||
|
||||
return SearchParams(
|
||||
search_project_id=search_project_id,
|
||||
search_persona_id=search_persona_id,
|
||||
project_id_filter=project_id_filter,
|
||||
persona_id_filter=persona_id_filter,
|
||||
search_usage=search_usage,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_query_processing_hook_result(
|
||||
hook_result: QueryProcessingResponse | HookSkipped | HookSoftFailed,
|
||||
message_text: str,
|
||||
) -> str:
|
||||
"""Apply the Query Processing hook result to the message text.
|
||||
|
||||
Returns the (possibly rewritten) message text, or raises OnyxError with
|
||||
QUERY_REJECTED if the hook signals rejection (query is null or empty).
|
||||
HookSkipped and HookSoftFailed are pass-throughs — the original text is
|
||||
returned unchanged.
|
||||
"""
|
||||
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
|
||||
return message_text
|
||||
if not (hook_result.query and hook_result.query.strip()):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.QUERY_REJECTED,
|
||||
hook_result.rejection_message
|
||||
or "The hook extension for query processing did not return a valid query. No rejection reason was provided.",
|
||||
)
|
||||
return hook_result.query.strip()
|
||||
|
||||
|
||||
def handle_stream_message_objects(
|
||||
new_msg_req: SendMessageRequest,
|
||||
user: User,
|
||||
@@ -474,16 +505,24 @@ def handle_stream_message_objects(
|
||||
db_session=db_session,
|
||||
)
|
||||
yield CreateChatSessionID(chat_session_id=chat_session.id)
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
else:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
|
||||
persona = chat_session.persona
|
||||
|
||||
message_text = new_msg_req.message
|
||||
|
||||
user_identity = LLMUserIdentity(
|
||||
user_id=llm_user_identifier, session_id=str(chat_session.id)
|
||||
)
|
||||
@@ -575,6 +614,28 @@ def handle_stream_message_objects(
|
||||
if parent_message.message_type == MessageType.USER:
|
||||
user_message = parent_message
|
||||
else:
|
||||
# New message — run the Query Processing hook before saving to DB.
|
||||
# Skipped on regeneration: the message already exists and was accepted previously.
|
||||
# Skip the hook for empty/whitespace-only messages — no meaningful query
|
||||
# to process, and SendMessageRequest.message has no min_length guard.
|
||||
if message_text.strip():
|
||||
hook_result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload=QueryProcessingPayload(
|
||||
query=message_text,
|
||||
# Pass None for anonymous users or authenticated users without an email
|
||||
# (e.g. some SSO flows). QueryProcessingPayload.user_email is str | None,
|
||||
# so None is accepted and serialised as null in both cases.
|
||||
user_email=None if user.is_anonymous else user.email,
|
||||
chat_session_id=str(chat_session.id),
|
||||
).model_dump(),
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
message_text = _resolve_query_processing_hook_result(
|
||||
hook_result, message_text
|
||||
)
|
||||
|
||||
user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=parent_message,
|
||||
@@ -711,8 +772,8 @@ def handle_stream_message_objects(
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id=search_params.search_project_id,
|
||||
persona_id=search_params.search_persona_id,
|
||||
project_id_filter=search_params.project_id_filter,
|
||||
persona_id_filter=search_params.persona_id_filter,
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
@@ -914,6 +975,17 @@ def handle_stream_message_objects(
|
||||
state_container=state_container,
|
||||
)
|
||||
|
||||
except OnyxError as e:
|
||||
if e.error_code is not OnyxErrorCode.QUERY_REJECTED:
|
||||
log_onyx_error(e)
|
||||
yield StreamingError(
|
||||
error=e.detail,
|
||||
error_code=e.error_code.code,
|
||||
is_retryable=e.status_code >= 500,
|
||||
)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
|
||||
@@ -44,6 +44,31 @@ SEND_USER_METADATA_TO_LLM_PROVIDER = (
|
||||
# User Facing Features Configs
|
||||
#####
|
||||
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
|
||||
|
||||
# Hard ceiling for the admin-configurable file upload size (in MB).
# Self-hosted customers can raise or lower this via the environment variable.
_raw_max_upload_size_mb = int(os.environ.get("MAX_ALLOWED_UPLOAD_SIZE_MB", "250"))
if _raw_max_upload_size_mb < 0:
logger.warning(
"MAX_ALLOWED_UPLOAD_SIZE_MB=%d is negative; falling back to 250",
_raw_max_upload_size_mb,
)
_raw_max_upload_size_mb = 250
MAX_ALLOWED_UPLOAD_SIZE_MB = _raw_max_upload_size_mb

# Default fallback for the per-user file upload size limit (in MB) when no
# admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at
# runtime so this never silently exceeds the hard ceiling.
_raw_default_upload_size_mb = int(
os.environ.get("DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", "100")
)
if _raw_default_upload_size_mb < 0:
logger.warning(
"DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=%d is negative; falling back to 100",
_raw_default_upload_size_mb,
)
_raw_default_upload_size_mb = 100
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _raw_default_upload_size_mb
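The comment above says the per-user default is clamped to the hard ceiling at runtime. A hypothetical sketch of what that clamp looks like; the helper name and the admin-settings lookup are assumptions, since the clamping code itself is not part of this hunk, and only the two constants come from the diff.

# Hypothetical runtime clamp (illustrative; not from this hunk).
def effective_user_upload_limit_mb(admin_configured_mb: int | None) -> int:
    limit = (
        admin_configured_mb
        if admin_configured_mb is not None
        else DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
    )
    # Never allow the effective limit to exceed the hard ceiling.
    return min(limit, MAX_ALLOWED_UPLOAD_SIZE_MB)

print(effective_user_upload_limit_mb(None))   # 100 with default env values
print(effective_user_upload_limit_mb(9999))   # clamped to 250
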
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
|
||||
) # 1 day
|
||||
@@ -61,17 +86,6 @@ CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# Maximum upload size for a single user file (chat/projects) in MB.
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
|
||||
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
@@ -332,6 +346,10 @@ OPENSEARCH_INDEX_NUM_REPLICAS: int | None = (
|
||||
if os.environ.get("OPENSEARCH_INDEX_NUM_REPLICAS", None) is not None
|
||||
else None
|
||||
)
|
||||
ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH = (
|
||||
os.environ.get("ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH", "").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
|
||||
@@ -24,11 +24,11 @@ CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
LLM_SOCKET_READ_TIMEOUT = int(
os.environ.get("LLM_SOCKET_READ_TIMEOUT") or "60"
) # 60 seconds
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
# Weighting factor between vector and keyword Search; 1 for completely vector
# search, 0 for keyword. Enforces a valid range of [0, 1]. A supplied value from
# the env outside of this range will be clipped to the respective end of the
# range. Defaults to 0.5.
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
HYBRID_ALPHA_KEYWORD = max(
0, min(1, float(os.environ.get("HYBRID_ALPHA_KEYWORD") or 0.4))
)
# Weighting factor between Title and Content of documents during search, 1 for completely
# Title based. Default heavily favors Content because Title is also included at the top of
# Content. This is to avoid cases where the Content is very relevant but it may not be clear

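HYBRID_ALPHA weights vector search against keyword search. A simplified illustration of what an alpha-weighted blend looks like, for intuition only; the real ranking pipeline may normalize and combine scores differently before blending.

# Simplified assumption of an alpha-weighted hybrid score (intuition only).
def hybrid_score(vector_score: float, keyword_score: float, alpha: float = 0.5) -> float:
    return alpha * vector_score + (1 - alpha) * keyword_score

print(hybrid_score(0.9, 0.2, alpha=1.0))  # 0.9  -> pure vector search
print(hybrid_score(0.9, 0.2, alpha=0.0))  # 0.2  -> pure keyword search
print(hybrid_score(0.9, 0.2, alpha=0.5))  # 0.55 -> even blend
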
0 backend/onyx/connectors/canvas/__init__.py Normal file
192 backend/onyx/connectors/canvas/client.py Normal file
@@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rl_requests,
|
||||
)
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Requests timeout in seconds.
|
||||
_CANVAS_CALL_TIMEOUT: int = 30
|
||||
_CANVAS_API_VERSION: str = "/api/v1"
|
||||
# Matches the "next" URL in a Canvas Link header, e.g.:
|
||||
# <https://canvas.example.com/api/v1/courses?page=2>; rel="next"
|
||||
# Captures the URL inside the angle brackets.
|
||||
_NEXT_LINK_PATTERN: re.Pattern[str] = re.compile(r'<([^>]+)>;\s*rel="next"')
|
||||
|
||||
|
||||
_STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
|
||||
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
|
||||
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
404: OnyxErrorCode.BAD_GATEWAY,
|
||||
429: OnyxErrorCode.RATE_LIMITED,
|
||||
}
|
||||
|
||||
|
||||
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
|
||||
"""Map an HTTP status code to the appropriate OnyxErrorCode.
|
||||
|
||||
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
|
||||
mapped to specific error codes; all other codes (unrecognised 4xx
|
||||
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
|
||||
"""
|
||||
if status_code in _STATUS_TO_ERROR_CODE:
|
||||
return _STATUS_TO_ERROR_CODE[status_code]
|
||||
return OnyxErrorCode.BAD_GATEWAY
|
||||
|
||||
|
||||
class CanvasApiClient:
|
||||
def __init__(
|
||||
self,
|
||||
bearer_token: str,
|
||||
canvas_base_url: str,
|
||||
) -> None:
|
||||
parsed_base = urlparse(canvas_base_url)
|
||||
if not parsed_base.hostname:
|
||||
raise ValueError("canvas_base_url must include a valid host")
|
||||
if parsed_base.scheme != "https":
|
||||
raise ValueError("canvas_base_url must use https")
|
||||
|
||||
self._bearer_token = bearer_token
|
||||
self.base_url = (
|
||||
canvas_base_url.rstrip("/").removesuffix(_CANVAS_API_VERSION)
|
||||
+ _CANVAS_API_VERSION
|
||||
)
|
||||
# Hostname is already validated above; reuse parsed_base instead
|
||||
# of re-parsing. Used by _parse_next_link to validate pagination URLs.
|
||||
self._expected_host: str = parsed_base.hostname
|
||||
|
||||
def get(
|
||||
self,
|
||||
endpoint: str = "",
|
||||
params: dict[str, Any] | None = None,
|
||||
full_url: str | None = None,
|
||||
) -> tuple[Any, str | None]:
|
||||
"""Make a GET request to the Canvas API.
|
||||
|
||||
Returns a tuple of (json_body, next_url).
|
||||
next_url is parsed from the Link header and is None if there are no more pages.
|
||||
If full_url is provided, it is used directly (for following pagination links).
|
||||
|
||||
Security note: full_url must only be set to values returned by
|
||||
``_parse_next_link``, which validates the host against the configured
|
||||
Canvas base URL. Passing an arbitrary URL would leak the bearer token.
|
||||
"""
|
||||
# full_url is used when following pagination (Canvas returns the
|
||||
# next-page URL in the Link header). For the first request we build
|
||||
# the URL from the endpoint name instead.
|
||||
url = full_url if full_url else self._build_url(endpoint)
|
||||
headers = self._build_headers()
|
||||
|
||||
response = rl_requests.get(
|
||||
url,
|
||||
headers=headers,
|
||||
params=params if not full_url else None,
|
||||
timeout=_CANVAS_CALL_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
response_json = response.json()
|
||||
except ValueError as e:
|
||||
if response.status_code < 300:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail=f"Invalid JSON in Canvas response: {e}",
|
||||
)
|
||||
logger.warning(
|
||||
"Failed to parse JSON from Canvas error response (status=%d): %s",
|
||||
response.status_code,
|
||||
e,
|
||||
)
|
||||
response_json = {}
|
||||
|
||||
if response.status_code >= 400:
|
||||
# Try to extract the most specific error message from the
|
||||
# Canvas response body. Canvas uses three different shapes
|
||||
# depending on the endpoint and error type:
|
||||
default_error: str = response.reason or f"HTTP {response.status_code}"
|
||||
error = default_error
|
||||
if isinstance(response_json, dict):
|
||||
# Shape 1: {"error": {"message": "Not authorized"}}
|
||||
error_field = response_json.get("error")
|
||||
if isinstance(error_field, dict):
|
||||
response_error = error_field.get("message", "")
|
||||
if response_error:
|
||||
error = response_error
|
||||
# Shape 2: {"error": "Invalid access token"}
|
||||
elif isinstance(error_field, str):
|
||||
error = error_field
|
||||
# Shape 3: {"errors": [{"message": "..."}]}
|
||||
# Used for validation errors. Only use as fallback if
|
||||
# we didn't already find a more specific message above.
|
||||
if error == default_error:
|
||||
errors_list = response_json.get("errors")
|
||||
if isinstance(errors_list, list) and errors_list:
|
||||
first_error = errors_list[0]
|
||||
if isinstance(first_error, dict):
|
||||
msg = first_error.get("message", "")
|
||||
if msg:
|
||||
error = msg
|
||||
raise OnyxError(
|
||||
_error_code_for_status(response.status_code),
|
||||
detail=error,
|
||||
status_code_override=response.status_code,
|
||||
)
|
||||
|
||||
next_url = self._parse_next_link(response.headers.get("Link", ""))
|
||||
return response_json, next_url
|
||||
|
||||
def _parse_next_link(self, link_header: str) -> str | None:
"""Extract the 'next' URL from a Canvas Link header.

Only returns URLs whose host matches the configured Canvas base URL
to prevent leaking the bearer token to arbitrary hosts.
"""
expected_host = self._expected_host
for match in _NEXT_LINK_PATTERN.finditer(link_header):
url = match.group(1)
parsed_url = urlparse(url)
if parsed_url.hostname != expected_host:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail=(
"Canvas pagination returned an unexpected host "
f"({parsed_url.hostname}); expected {expected_host}"
),
)
if parsed_url.scheme != "https":
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail=(
"Canvas pagination link must use https, "
f"got {parsed_url.scheme!r}"
),
)
return url
return None

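Canvas paginates list endpoints through the Link response header, and _NEXT_LINK_PATTERN pulls the rel="next" URL out of it. A small sketch of what the regex matches against a typical header; the host and query values are fabricated, and the pattern below simply mirrors the one defined at the top of client.py.

# Sketch of extracting the next-page URL from a Canvas Link header.
import re

_NEXT_LINK_PATTERN = re.compile(r'<([^>]+)>;\s*rel="next"')  # same pattern as client.py

link_header = (
    '<https://canvas.example.com/api/v1/courses?page=2&per_page=100>; rel="next", '
    '<https://canvas.example.com/api/v1/courses?page=5&per_page=100>; rel="last"'
)

match = _NEXT_LINK_PATTERN.search(link_header)
next_url = match.group(1) if match else None
print(next_url)  # https://canvas.example.com/api/v1/courses?page=2&per_page=100
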
def _build_headers(self) -> dict[str, str]:
|
||||
"""Return the Authorization header with the bearer token."""
|
||||
return {"Authorization": f"Bearer {self._bearer_token}"}
|
||||
|
||||
def _build_url(self, endpoint: str) -> str:
"""Build a full Canvas API URL from an endpoint path.

Assumes endpoint is non-empty (e.g. ``"courses"``, ``"announcements"``).
This is only called for the first request of a listing, where an
endpoint must be provided; the emptiness check below is kept as a guard
in case a future change makes the endpoint optional.
Leading slashes are stripped to avoid a double slash in the result.
self.base_url is already normalized with no trailing slash.
"""
final_url = self.base_url
clean_endpoint = endpoint.lstrip("/")
if clean_endpoint:
final_url += "/" + clean_endpoint
return final_url

74 backend/onyx/connectors/canvas/connector.py Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Literal
from typing import TypeAlias

from pydantic import BaseModel

from onyx.connectors.models import ConnectorCheckpoint


class CanvasCourse(BaseModel):
    id: int
    name: str
    course_code: str
    created_at: str
    workflow_state: str


class CanvasPage(BaseModel):
    page_id: int
    url: str
    title: str
    body: str | None = None
    created_at: str
    updated_at: str
    course_id: int


class CanvasAssignment(BaseModel):
    id: int
    name: str
    description: str | None = None
    html_url: str
    course_id: int
    created_at: str
    updated_at: str
    due_at: str | None = None


class CanvasAnnouncement(BaseModel):
    id: int
    title: str
    message: str | None = None
    html_url: str
    posted_at: str | None = None
    course_id: int


CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]


class CanvasConnectorCheckpoint(ConnectorCheckpoint):
    """Checkpoint state for resumable Canvas indexing.

    Fields:
        course_ids: Materialized list of course IDs to process.
        current_course_index: Index into course_ids for current course.
        stage: Which item type we're processing for the current course.
        next_url: Pagination cursor within the current stage. None means
            start from the first page; a URL means resume from that page.

    Invariant:
        If current_course_index is incremented, stage must be reset to
        "pages" and next_url must be reset to None.
    """

    course_ids: list[int] = []
    current_course_index: int = 0
    stage: CanvasStage = "pages"
    next_url: str | None = None

    def advance_course(self) -> None:
        """Move to the next course and reset within-course state."""
        self.current_course_index += 1
        self.stage = "pages"
        self.next_url = None
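A short illustration of the checkpoint invariant documented above: advancing to the next course always resets stage to "pages" and clears next_url. This standalone sketch uses plain pydantic instead of ConnectorCheckpoint, which is assumed but not reproduced here.

from typing import Literal
from pydantic import BaseModel

StageSketch = Literal["pages", "assignments", "announcements"]

class CheckpointSketch(BaseModel):
    course_ids: list[int] = []
    current_course_index: int = 0
    stage: StageSketch = "pages"
    next_url: str | None = None

    def advance_course(self) -> None:
        # Incrementing the course index must reset all within-course state.
        self.current_course_index += 1
        self.stage = "pages"
        self.next_url = None

cp = CheckpointSketch(
    course_ids=[101, 102],
    stage="announcements",
    next_url="https://canvas.example.edu/api/v1/announcements?page=3",
)
cp.advance_course()
assert (cp.current_course_index, cp.stage, cp.next_url) == (1, "pages", None)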
@@ -890,8 +890,8 @@ class ConfluenceConnector(
|
||||
|
||||
def _retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
@@ -915,8 +915,8 @@ class ConfluenceConnector(
|
||||
self.confluence_client, doc_id, restrictions, ancestors
|
||||
) or space_level_access_info.get(page_space_key)
|
||||
|
||||
# Query pages
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
# Query pages (with optional time filtering for indexing_start)
|
||||
page_query = self._construct_page_cql_query(start, end)
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
@@ -950,7 +950,9 @@ class ConfluenceConnector(
|
||||
|
||||
# Query attachments for each page
|
||||
page_hierarchy_node_yielded = False
|
||||
attachment_query = self._construct_attachment_query(_get_page_id(page))
|
||||
attachment_query = self._construct_attachment_query(
|
||||
_get_page_id(page), start, end
|
||||
)
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_query,
|
||||
expand=restrictions_expand,
|
||||
|
||||
@@ -123,7 +123,7 @@ class OnyxConfluence:
|
||||
|
||||
self.shared_base_kwargs: dict[str, str | int | bool] = {
|
||||
"api_version": "cloud" if is_cloud else "latest",
|
||||
"backoff_and_retry": True,
|
||||
"backoff_and_retry": False,
|
||||
"cloud": is_cloud,
|
||||
}
|
||||
if timeout:
|
||||
@@ -456,7 +456,7 @@ class OnyxConfluence:
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
|
||||
)
|
||||
|
||||
@@ -363,7 +363,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except requests.HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
|
||||
)
|
||||
@@ -384,7 +384,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int, max_retries: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
@@ -408,6 +408,17 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
|
||||
raise e
|
||||
|
||||
if e.response.status_code >= 500:
|
||||
if attempt >= max_retries - 1:
|
||||
raise e
|
||||
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
logger.warning(
|
||||
f"Server error {e.response.status_code}. "
|
||||
f"Retrying in {delay} seconds (attempt {attempt + 1})..."
|
||||
)
|
||||
return math.ceil(time.monotonic() + delay)
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
|
||||
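The new 5xx branch above waits STARTING_DELAY * BACKOFF**attempt seconds, capped at MAX_DELAY, before retrying. A quick sketch of how that delay grows per attempt; BACKOFF = 2 is an assumption, since the hunk does not show the constant's value.

STARTING_DELAY = 5
MAX_DELAY = 60
BACKOFF = 2  # assumed value; the constant is defined elsewhere in the real module

for attempt in range(6):
    delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
    print(f"attempt {attempt}: wait {delay}s")
# attempt 0: 5s, 1: 10s, 2: 20s, 3: 40s, then capped at 60s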
@@ -10,6 +10,7 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
from jira.resources import Issue
|
||||
@@ -239,29 +240,53 @@ def enhanced_search_ids(
|
||||
)
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO: move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
def _bulk_fetch_request(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
|
||||
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
|
||||
|
||||
# Prepare the payload according to Jira API v3 specification
|
||||
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
|
||||
|
||||
# Only restrict fields if specified, might want to explicitly do this in the future
|
||||
# to avoid reading unnecessary data
|
||||
payload["fields"] = fields.split(",") if fields else ["*all"]
|
||||
|
||||
resp = jira_client._session.post(bulk_fetch_path, json=payload)
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
try:
|
||||
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
|
||||
f"be decoded as JSON (response too large or truncated)."
|
||||
)
|
||||
raise
|
||||
|
||||
mid = len(issue_ids) // 2
|
||||
logger.warning(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise e
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
for issue in response["issues"]
|
||||
for issue in raw_issues
|
||||
]
|
||||
|
||||
|
||||
|
||||
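The bulk-fetch fallback above halves the batch whenever the response cannot be decoded as JSON and recurses on each half, bottoming out at single-issue batches. A self-contained sketch of that split-and-retry shape; the function names are illustrative rather than the connector's API, and ValueError stands in for the requests JSON decode error.

from collections.abc import Callable

def fetch_with_split(ids: list[str], fetch: Callable[[list[str]], list[dict]]) -> list[dict]:
    """Fetch a batch; on a decode failure, split it in half and recurse."""
    try:
        return fetch(ids)
    except ValueError:
        if len(ids) <= 1:
            raise  # a single issue still failing is a genuine error
        mid = len(ids) // 2
        return fetch_with_split(ids[:mid], fetch) + fetch_with_split(ids[mid:], fetch)

def flaky_fetch(ids: list[str]) -> list[dict]:
    # Simulates an endpoint whose responses get truncated above 2 issues.
    if len(ids) > 2:
        raise ValueError("truncated response")
    return [{"id": i} for i in ids]

print(fetch_with_split(["A", "B", "C", "D", "E"], flaky_fetch))
# [{'id': 'A'}, {'id': 'B'}, {'id': 'C'}, {'id': 'D'}, {'id': 'E'}]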
@@ -53,7 +53,7 @@ class NotionPage(BaseModel):
|
||||
id: str
|
||||
created_time: str
|
||||
last_edited_time: str
|
||||
archived: bool
|
||||
in_trash: bool
|
||||
properties: dict[str, Any]
|
||||
url: str
|
||||
|
||||
@@ -63,6 +63,13 @@ class NotionPage(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class NotionDataSource(BaseModel):
|
||||
"""Represents a Notion Data Source within a database."""
|
||||
|
||||
id: str
|
||||
name: str = ""
|
||||
|
||||
|
||||
class NotionBlock(BaseModel):
|
||||
"""Represents a Notion Block object"""
|
||||
|
||||
@@ -107,7 +114,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
"Notion-Version": "2026-03-11",
|
||||
}
|
||||
self.indexed_pages: set[str] = set()
|
||||
self.root_page_id = root_page_id
|
||||
@@ -127,6 +134,9 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
# Maps child page IDs to their containing page ID (discovered in _read_blocks).
|
||||
# Used to resolve block_id parent types to the actual containing page.
|
||||
self._child_page_parent_map: dict[str, str] = {}
|
||||
# Maps data_source_id -> database_id (populated in _read_pages_from_database).
|
||||
# Used to resolve data_source_id parent types back to the database.
|
||||
self._data_source_to_database_map: dict[str, str] = {}
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
@@ -227,7 +237,11 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
|
||||
"""Attempt to fetch a database as a page."""
|
||||
"""Attempt to fetch a database as a page.
|
||||
|
||||
Note: As of API 2025-09-03, database objects no longer include
|
||||
`properties` (schema moved to individual data sources).
|
||||
"""
|
||||
logger.debug(f"Fetching database for ID '{database_id}' as a page")
|
||||
database_url = f"https://api.notion.com/v1/databases/{database_id}"
|
||||
res = rl_requests.get(
|
||||
@@ -246,18 +260,52 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
database_name[0].get("text", {}).get("content") if database_name else None
|
||||
)
|
||||
|
||||
db_data.setdefault("properties", {})
|
||||
|
||||
return NotionPage(**db_data, database_name=database_name)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database(
|
||||
self, database_id: str, cursor: str | None = None
|
||||
def _fetch_data_sources_for_database(
|
||||
self, database_id: str
|
||||
) -> list[NotionDataSource]:
|
||||
"""Fetch the list of data sources for a database."""
|
||||
logger.debug(f"Fetching data sources for database '{database_id}'")
|
||||
res = rl_requests.get(
|
||||
f"https://api.notion.com/v1/databases/{database_id}",
|
||||
headers=self.headers,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
)
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except Exception as e:
|
||||
if res.status_code in (403, 404):
|
||||
logger.error(
|
||||
f"Unable to access database with ID '{database_id}'. "
|
||||
f"This is likely due to the database not being shared "
|
||||
f"with the Onyx integration. Exact exception:\n{e}"
|
||||
)
|
||||
return []
|
||||
logger.exception(f"Error fetching database - {res.json()}")
|
||||
raise e
|
||||
|
||||
db_data = res.json()
|
||||
data_sources = db_data.get("data_sources", [])
|
||||
return [
|
||||
NotionDataSource(id=ds["id"], name=ds.get("name", ""))
|
||||
for ds in data_sources
|
||||
if ds.get("id")
|
||||
]
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_data_source(
|
||||
self, data_source_id: str, cursor: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch a database from it's ID via the Notion API."""
|
||||
logger.debug(f"Fetching database for ID '{database_id}'")
|
||||
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
|
||||
"""Query a data source via POST /v1/data_sources/{id}/query."""
|
||||
logger.debug(f"Querying data source '{data_source_id}'")
|
||||
url = f"https://api.notion.com/v1/data_sources/{data_source_id}/query"
|
||||
body = None if not cursor else {"start_cursor": cursor}
|
||||
res = rl_requests.post(
|
||||
block_url,
|
||||
url,
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
@@ -265,25 +313,14 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except Exception as e:
|
||||
json_data = res.json()
|
||||
code = json_data.get("code")
|
||||
# Sep 3 2025 backend changed the error message for this case
|
||||
# TODO: it is also now possible for there to be multiple data sources per database; at present we
|
||||
# just don't handle that. We will need to upgrade the API to the current version + query the
|
||||
# new data sources endpoint to handle that case correctly.
|
||||
if code == "object_not_found" or (
|
||||
code == "validation_error"
|
||||
and "does not contain any data sources" in json_data.get("message", "")
|
||||
):
|
||||
# this happens when a database is not shared with the integration
|
||||
# in this case, we should just ignore the database
|
||||
if res.status_code in (403, 404):
|
||||
logger.error(
|
||||
f"Unable to access database with ID '{database_id}'. "
|
||||
f"This is likely due to the database not being shared "
|
||||
f"Unable to access data source with ID '{data_source_id}'. "
|
||||
f"This is likely due to it not being shared "
|
||||
f"with the Onyx integration. Exact exception:\n{e}"
|
||||
)
|
||||
return {"results": [], "next_cursor": None}
|
||||
logger.exception(f"Error fetching database - {res.json()}")
|
||||
logger.exception(f"Error querying data source - {res.json()}")
|
||||
raise e
|
||||
return res.json()
|
||||
|
||||
@@ -348,8 +385,9 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
# Fallback to workspace if we don't know the parent
|
||||
return self.workspace_id
|
||||
elif parent_type == "data_source_id":
|
||||
# Newer Notion API may use data_source_id for databases
|
||||
return parent.get("database_id") or parent.get("data_source_id")
|
||||
ds_id = parent.get("data_source_id")
|
||||
if ds_id:
|
||||
return self._data_source_to_database_map.get(ds_id, self.workspace_id)
|
||||
elif parent_type in ["page_id", "database_id"]:
|
||||
return parent.get(parent_type)
|
||||
|
||||
@@ -497,18 +535,32 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
if db_node:
|
||||
hierarchy_nodes.append(db_node)
|
||||
|
||||
cursor = None
|
||||
while True:
|
||||
data = self._fetch_database(database_id, cursor)
|
||||
# Discover all data sources under this database, then query each one.
|
||||
# Even legacy single-source databases have one entry in the array.
|
||||
data_sources = self._fetch_data_sources_for_database(database_id)
|
||||
if not data_sources:
|
||||
logger.warning(
|
||||
f"Database '{database_id}' returned zero data sources — "
|
||||
f"no pages will be indexed from this database."
|
||||
)
|
||||
for ds in data_sources:
|
||||
self._data_source_to_database_map[ds.id] = database_id
|
||||
cursor = None
|
||||
while True:
|
||||
data = self._fetch_data_source(ds.id, cursor)
|
||||
|
||||
for result in data["results"]:
|
||||
obj_id = result["id"]
|
||||
obj_type = result["object"]
|
||||
text = self._properties_to_str(result.get("properties", {}))
|
||||
if text:
|
||||
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
|
||||
for result in data["results"]:
|
||||
obj_id = result["id"]
|
||||
obj_type = result["object"]
|
||||
text = self._properties_to_str(result.get("properties", {}))
|
||||
if text:
|
||||
result_blocks.append(
|
||||
NotionBlock(id=obj_id, text=text, prefix="\n")
|
||||
)
|
||||
|
||||
if not self.recursive_index_enabled:
|
||||
continue
|
||||
|
||||
if self.recursive_index_enabled:
|
||||
if obj_type == "page":
|
||||
logger.debug(
|
||||
f"Found page with ID '{obj_id}' in database '{database_id}'"
|
||||
@@ -518,7 +570,6 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
logger.debug(
|
||||
f"Found database with ID '{obj_id}' in database '{database_id}'"
|
||||
)
|
||||
# Get nested database name from properties if available
|
||||
nested_db_title = result.get("title", [])
|
||||
nested_db_name = None
|
||||
if nested_db_title and len(nested_db_title) > 0:
|
||||
@@ -533,10 +584,10 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
result_pages.extend(nested_output.child_page_ids)
|
||||
hierarchy_nodes.extend(nested_output.hierarchy_nodes)
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
|
||||
cursor = data["next_cursor"]
|
||||
cursor = data["next_cursor"]
|
||||
|
||||
return BlockReadOutput(
|
||||
blocks=result_blocks,
|
||||
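The data-source query above is cursor paginated: each response carries next_cursor, and the next request passes it back as start_cursor until it comes back as None. A hedged sketch of that loop against the endpoint used in the diff; the token, data source ID, and API version string are placeholders.

import requests

NOTION_VERSION = "2025-09-03"  # placeholder; use whatever version the connector pins

def query_all_rows(data_source_id: str, token: str) -> list[dict]:
    url = f"https://api.notion.com/v1/data_sources/{data_source_id}/query"
    headers = {"Authorization": f"Bearer {token}", "Notion-Version": NOTION_VERSION}
    rows: list[dict] = []
    cursor: str | None = None
    while True:
        body = {"start_cursor": cursor} if cursor else None
        data = requests.post(url, headers=headers, json=body, timeout=30).json()
        rows.extend(data.get("results", []))
        cursor = data.get("next_cursor")
        if cursor is None:
            break
    return rows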
@@ -807,36 +858,55 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
def _yield_database_hierarchy_nodes(
|
||||
self,
|
||||
) -> Generator[HierarchyNode | Document, None, None]:
|
||||
"""Search for all databases and yield hierarchy nodes for each.
|
||||
"""Search for all data sources and yield hierarchy nodes for their parent databases.
|
||||
|
||||
This must be called BEFORE page indexing so that database hierarchy nodes
|
||||
exist when pages inside databases reference them as parents.
|
||||
|
||||
With the new API, search returns data source objects instead of databases.
|
||||
Multiple data sources can share the same parent database, so we use
|
||||
database_id as the hierarchy node key and deduplicate via
|
||||
_maybe_yield_hierarchy_node.
|
||||
"""
|
||||
query_dict: dict[str, Any] = {
|
||||
"filter": {"property": "object", "value": "database"},
|
||||
"filter": {"property": "object", "value": "data_source"},
|
||||
"page_size": _NOTION_PAGE_SIZE,
|
||||
}
|
||||
pages_seen = 0
|
||||
while pages_seen < _MAX_PAGES:
|
||||
db_res = self._search_notion(query_dict)
|
||||
for db in db_res.results:
|
||||
db_id = db["id"]
|
||||
# Extract title from the title array
|
||||
title_arr = db.get("title", [])
|
||||
db_name = None
|
||||
if title_arr:
|
||||
db_name = " ".join(
|
||||
t.get("plain_text", "") for t in title_arr
|
||||
).strip()
|
||||
if not db_name:
|
||||
for ds in db_res.results:
|
||||
# Extract the parent database_id from the data source's parent
|
||||
ds_parent = ds.get("parent", {})
|
||||
db_id = ds_parent.get("database_id")
|
||||
if not db_id:
|
||||
continue
|
||||
|
||||
# Populate the mapping so _get_parent_raw_id can resolve later
|
||||
ds_id = ds.get("id")
|
||||
if not ds_id:
|
||||
continue
|
||||
self._data_source_to_database_map[ds_id] = db_id
|
||||
|
||||
# Fetch the database to get its actual name and parent
|
||||
try:
|
||||
db_page = self._fetch_database_as_page(db_id)
|
||||
db_name = db_page.database_name or f"Database {db_id}"
|
||||
parent_raw_id = self._get_parent_raw_id(db_page.parent)
|
||||
db_url = (
|
||||
db_page.url or f"https://notion.so/{db_id.replace('-', '')}"
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(
|
||||
f"Could not fetch database '{db_id}', "
|
||||
f"defaulting to workspace root. Error: {e}"
|
||||
)
|
||||
db_name = f"Database {db_id}"
|
||||
parent_raw_id = self.workspace_id
|
||||
db_url = f"https://notion.so/{db_id.replace('-', '')}"
|
||||
|
||||
# Get parent using existing helper
|
||||
parent_raw_id = self._get_parent_raw_id(db.get("parent"))
|
||||
|
||||
# Notion URLs omit dashes from UUIDs
|
||||
db_url = db.get("url") or f"https://notion.so/{db_id.replace('-', '')}"
|
||||
|
||||
# _maybe_yield_hierarchy_node deduplicates by raw_node_id,
|
||||
# so multiple data sources under one database produce one node.
|
||||
node = self._maybe_yield_hierarchy_node(
|
||||
raw_node_id=db_id,
|
||||
raw_parent_id=parent_raw_id or self.workspace_id,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import copy
|
||||
import fnmatch
|
||||
import html
|
||||
import io
|
||||
import os
|
||||
@@ -84,6 +85,44 @@ SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()}
|
||||
|
||||
ASPX_EXTENSION = ".aspx"
|
||||
|
||||
|
||||
def _is_site_excluded(site_url: str, excluded_site_patterns: list[str]) -> bool:
|
||||
"""Check if a site URL matches any of the exclusion glob patterns."""
|
||||
for pattern in excluded_site_patterns:
|
||||
if fnmatch.fnmatch(site_url, pattern) or fnmatch.fnmatch(
|
||||
site_url.rstrip("/"), pattern.rstrip("/")
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_path_excluded(item_path: str, excluded_path_patterns: list[str]) -> bool:
|
||||
"""Check if a drive item path matches any of the exclusion glob patterns.
|
||||
|
||||
item_path is the relative path within a drive, e.g. "Engineering/API/report.docx".
|
||||
Matches are attempted against the full path and the filename alone so that
|
||||
patterns like "*.tmp" match files at any depth.
|
||||
"""
|
||||
filename = item_path.rsplit("/", 1)[-1] if "/" in item_path else item_path
|
||||
for pattern in excluded_path_patterns:
|
||||
if fnmatch.fnmatch(item_path, pattern) or fnmatch.fnmatch(filename, pattern):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _build_item_relative_path(parent_reference_path: str | None, item_name: str) -> str:
|
||||
"""Build the relative path of a drive item from its parentReference.path and name.
|
||||
|
||||
Example: parentReference.path="/drives/abc/root:/Eng/API", name="report.docx"
|
||||
=> "Eng/API/report.docx"
|
||||
"""
|
||||
if parent_reference_path and "root:/" in parent_reference_path:
|
||||
folder = unquote(parent_reference_path.split("root:/", 1)[1])
|
||||
if folder:
|
||||
return f"{folder}/{item_name}"
|
||||
return item_name
|
||||
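Putting two of the helpers together: build the drive-relative path from parentReference.path plus the item name, then test it against the exclusion globs. Both the full path and the bare filename are tried, so a pattern like *.tmp matches at any depth. A condensed, runnable version of the same logic:

import fnmatch
from urllib.parse import unquote

def build_item_relative_path(parent_reference_path: str | None, item_name: str) -> str:
    if parent_reference_path and "root:/" in parent_reference_path:
        folder = unquote(parent_reference_path.split("root:/", 1)[1])
        if folder:
            return f"{folder}/{item_name}"
    return item_name

def is_path_excluded(item_path: str, patterns: list[str]) -> bool:
    filename = item_path.rsplit("/", 1)[-1]
    return any(
        fnmatch.fnmatch(item_path, p) or fnmatch.fnmatch(filename, p) for p in patterns
    )

path = build_item_relative_path("/drives/abc/root:/Eng/API", "report.docx")
print(path)                                              # Eng/API/report.docx
print(is_path_excluded(path, ["*.tmp"]))                 # False
print(is_path_excluded(path, ["Eng/API/*"]))             # True
print(is_path_excluded("scratch/notes.tmp", ["*.tmp"]))  # True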
|
||||
|
||||
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
|
||||
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
|
||||
DEFAULT_SHAREPOINT_DOMAIN_SUFFIX = "sharepoint.com"
|
||||
@@ -478,6 +517,7 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
include_permissions: bool = False,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
access_token: str | None = None,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
|
||||
if not driveitem.name or not driveitem.id:
|
||||
@@ -610,6 +650,7 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
drive_item=sdk_item,
|
||||
drive_name=drive_name,
|
||||
add_prefix=True,
|
||||
treat_sharing_link_as_public=treat_sharing_link_as_public,
|
||||
)
|
||||
else:
|
||||
external_access = ExternalAccess.empty()
|
||||
@@ -644,6 +685,7 @@ def _convert_sitepage_to_document(
|
||||
graph_client: GraphClient,
|
||||
include_permissions: bool = False,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> Document:
|
||||
"""Convert a SharePoint site page to a Document object."""
|
||||
# Extract text content from the site page
|
||||
@@ -773,6 +815,7 @@ def _convert_sitepage_to_document(
|
||||
graph_client=graph_client,
|
||||
site_page=site_page,
|
||||
add_prefix=True,
|
||||
treat_sharing_link_as_public=treat_sharing_link_as_public,
|
||||
)
|
||||
else:
|
||||
external_access = ExternalAccess.empty()
|
||||
@@ -803,6 +846,7 @@ def _convert_driveitem_to_slim_document(
|
||||
ctx: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> SlimDocument:
|
||||
if driveitem.id is None:
|
||||
raise ValueError("DriveItem ID is required")
|
||||
@@ -813,6 +857,7 @@ def _convert_driveitem_to_slim_document(
|
||||
graph_client=graph_client,
|
||||
drive_item=sdk_item,
|
||||
drive_name=drive_name,
|
||||
treat_sharing_link_as_public=treat_sharing_link_as_public,
|
||||
)
|
||||
|
||||
return SlimDocument(
|
||||
@@ -827,6 +872,7 @@ def _convert_sitepage_to_slim_document(
|
||||
ctx: ClientContext | None,
|
||||
graph_client: GraphClient,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> SlimDocument:
|
||||
"""Convert a SharePoint site page to a SlimDocument object."""
|
||||
if site_page.get("id") is None:
|
||||
@@ -836,6 +882,7 @@ def _convert_sitepage_to_slim_document(
|
||||
ctx=ctx,
|
||||
graph_client=graph_client,
|
||||
site_page=site_page,
|
||||
treat_sharing_link_as_public=treat_sharing_link_as_public,
|
||||
)
|
||||
id = site_page.get("id")
|
||||
if id is None:
|
||||
@@ -855,14 +902,20 @@ class SharepointConnector(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
sites: list[str] = [],
|
||||
excluded_sites: list[str] = [],
|
||||
excluded_paths: list[str] = [],
|
||||
include_site_pages: bool = True,
|
||||
include_site_documents: bool = True,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
authority_host: str = DEFAULT_AUTHORITY_HOST,
|
||||
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
|
||||
sharepoint_domain_suffix: str = DEFAULT_SHAREPOINT_DOMAIN_SUFFIX,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.sites = list(sites)
|
||||
self.excluded_sites = [s for p in excluded_sites if (s := p.strip())]
|
||||
self.excluded_paths = [s for p in excluded_paths if (s := p.strip())]
|
||||
self.treat_sharing_link_as_public = treat_sharing_link_as_public
|
||||
self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
|
||||
sites
|
||||
)
|
||||
@@ -1233,6 +1286,29 @@ class SharepointConnector(
|
||||
break
|
||||
sites = sites._get_next().execute_query()
|
||||
|
||||
def _is_driveitem_excluded(self, driveitem: DriveItemData) -> bool:
|
||||
"""Check if a drive item should be excluded based on excluded_paths patterns."""
|
||||
if not self.excluded_paths:
|
||||
return False
|
||||
relative_path = _build_item_relative_path(
|
||||
driveitem.parent_reference_path, driveitem.name
|
||||
)
|
||||
return _is_path_excluded(relative_path, self.excluded_paths)
|
||||
|
||||
def _filter_excluded_sites(
|
||||
self, site_descriptors: list[SiteDescriptor]
|
||||
) -> list[SiteDescriptor]:
|
||||
"""Remove sites matching any excluded_sites glob pattern."""
|
||||
if not self.excluded_sites:
|
||||
return site_descriptors
|
||||
result = []
|
||||
for sd in site_descriptors:
|
||||
if _is_site_excluded(sd.url, self.excluded_sites):
|
||||
logger.info(f"Excluding site by denylist: {sd.url}")
|
||||
continue
|
||||
result.append(sd)
|
||||
return result
|
||||
|
||||
def fetch_sites(self) -> list[SiteDescriptor]:
|
||||
sites = self.graph_client.sites.get_all_sites().execute_query()
|
||||
|
||||
@@ -1249,7 +1325,7 @@ class SharepointConnector(
|
||||
for site in self._handle_paginated_sites(sites)
|
||||
if "-my.sharepoint" not in site.web_url
|
||||
]
|
||||
return site_descriptors
|
||||
return self._filter_excluded_sites(site_descriptors)
|
||||
|
||||
def _fetch_site_pages(
|
||||
self,
|
||||
@@ -1689,8 +1765,14 @@ class SharepointConnector(
|
||||
checkpoint.current_drive_delta_next_link = None
|
||||
checkpoint.seen_document_ids.clear()
|
||||
|
||||
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self.site_descriptors or self.fetch_sites()
|
||||
def _fetch_slim_documents_from_sharepoint(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self._filter_excluded_sites(
|
||||
self.site_descriptors or self.fetch_sites()
|
||||
)
|
||||
|
||||
# Create a temporary checkpoint for hierarchy node tracking
|
||||
temp_checkpoint = SharepointConnectorCheckpoint(has_more=True)
|
||||
@@ -1708,8 +1790,14 @@ class SharepointConnector(
|
||||
# Process site documents if flag is True
|
||||
if self.include_site_documents:
|
||||
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
|
||||
site_descriptor=site_descriptor
|
||||
site_descriptor=site_descriptor,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
if self._is_driveitem_excluded(driveitem):
|
||||
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
|
||||
continue
|
||||
|
||||
if drive_web_url:
|
||||
doc_batch.extend(
|
||||
self._yield_drive_hierarchy_node(
|
||||
@@ -1747,6 +1835,7 @@ class SharepointConnector(
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_url,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1758,7 +1847,9 @@ class SharepointConnector(
|
||||
|
||||
# Process site pages if flag is True
|
||||
if self.include_site_pages:
|
||||
site_pages = self._fetch_site_pages(site_descriptor)
|
||||
site_pages = self._fetch_site_pages(
|
||||
site_descriptor, start=start, end=end
|
||||
)
|
||||
for site_page in site_pages:
|
||||
logger.debug(
|
||||
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
|
||||
@@ -1770,6 +1861,7 @@ class SharepointConnector(
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=site_descriptor.url,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
)
|
||||
if len(doc_batch) >= SLIM_BATCH_SIZE:
|
||||
@@ -2043,7 +2135,9 @@ class SharepointConnector(
|
||||
and not checkpoint.process_site_pages
|
||||
):
|
||||
logger.info("Initializing SharePoint sites for processing")
|
||||
site_descs = self.site_descriptors or self.fetch_sites()
|
||||
site_descs = self._filter_excluded_sites(
|
||||
self.site_descriptors or self.fetch_sites()
|
||||
)
|
||||
checkpoint.cached_site_descriptors = deque(site_descs)
|
||||
|
||||
if not checkpoint.cached_site_descriptors:
|
||||
@@ -2264,6 +2358,10 @@ class SharepointConnector(
|
||||
for driveitem in driveitems:
|
||||
item_count += 1
|
||||
|
||||
if self._is_driveitem_excluded(driveitem):
|
||||
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
|
||||
continue
|
||||
|
||||
if driveitem.id and driveitem.id in checkpoint.seen_document_ids:
|
||||
logger.debug(
|
||||
f"Skipping duplicate document {driveitem.id} ({driveitem.name})"
|
||||
@@ -2318,6 +2416,7 @@ class SharepointConnector(
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_url,
|
||||
graph_api_base=self.graph_api_base,
|
||||
access_token=access_token,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
|
||||
if isinstance(doc_or_failure, Document):
|
||||
@@ -2398,6 +2497,7 @@ class SharepointConnector(
|
||||
include_permissions=include_permissions,
|
||||
# Site pages have the site as their parent
|
||||
parent_hierarchy_raw_node_id=site_descriptor.url,
|
||||
treat_sharing_link_as_public=self.treat_sharing_link_as_public,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
@@ -2473,12 +2573,22 @@ class SharepointConnector(
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
|
||||
yield from self._fetch_slim_documents_from_sharepoint()
|
||||
start_dt = (
|
||||
datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
if start is not None
|
||||
else None
|
||||
)
|
||||
end_dt = (
|
||||
datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
|
||||
)
|
||||
yield from self._fetch_slim_documents_from_sharepoint(
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -17,6 +17,7 @@ def get_sharepoint_external_access(
|
||||
drive_name: str | None = None,
|
||||
site_page: dict[str, Any] | None = None,
|
||||
add_prefix: bool = False,
|
||||
treat_sharing_link_as_public: bool = False,
|
||||
) -> ExternalAccess:
|
||||
if drive_item and drive_item.id is None:
|
||||
raise ValueError("DriveItem ID is required")
|
||||
@@ -34,7 +35,13 @@ def get_sharepoint_external_access(
|
||||
)
|
||||
|
||||
external_access = get_external_access_func(
|
||||
ctx, graph_client, drive_name, drive_item, site_page, add_prefix
|
||||
ctx,
|
||||
graph_client,
|
||||
drive_name,
|
||||
drive_item,
|
||||
site_page,
|
||||
add_prefix,
|
||||
treat_sharing_link_as_public,
|
||||
)
|
||||
|
||||
return external_access
|
||||
|
||||
@@ -516,6 +516,8 @@ def _get_all_doc_ids(
|
||||
] = default_msg_filter,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
workspace_url: str | None = None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
@@ -546,6 +548,8 @@ def _get_all_doc_ids(
|
||||
client=client,
|
||||
channel=channel,
|
||||
callback=callback,
|
||||
oldest=str(start) if start else None, # 0.0 -> None intentionally
|
||||
latest=str(end) if end is not None else None,
|
||||
)
|
||||
|
||||
for message_batch in channel_message_batches:
|
||||
@@ -847,8 +851,8 @@ class SlackConnector(
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.client is None:
|
||||
@@ -861,6 +865,8 @@ class SlackConnector(
|
||||
msg_filter_func=self.msg_filter_func,
|
||||
callback=callback,
|
||||
workspace_url=self._workspace_url,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _load_from_checkpoint(
|
||||
|
||||
@@ -88,8 +88,9 @@ WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
|
||||
IFRAME_TEXT_LENGTH_THRESHOLD = 700
|
||||
# Message indicating JavaScript is disabled, which often appears when scraping fails
|
||||
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
|
||||
# Grace period after page navigation to allow bot-detection challenges to complete
|
||||
BOT_DETECTION_GRACE_PERIOD_MS = 5000
|
||||
# Grace period after page navigation to allow bot-detection challenges
|
||||
# and SPA content rendering to complete
|
||||
PAGE_RENDER_TIMEOUT_MS = 5000
|
||||
|
||||
# Define common headers that mimic a real browser
|
||||
DEFAULT_USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"
|
||||
@@ -547,7 +548,15 @@ class WebConnector(LoadConnector):
|
||||
)
|
||||
# Give the page a moment to start rendering after navigation commits.
|
||||
# Allows CloudFlare and other bot-detection challenges to complete.
|
||||
page.wait_for_timeout(BOT_DETECTION_GRACE_PERIOD_MS)
|
||||
page.wait_for_timeout(PAGE_RENDER_TIMEOUT_MS)
|
||||
|
||||
# Wait for network activity to settle so SPAs that fetch content
|
||||
# asynchronously after the initial JS bundle have time to render.
|
||||
try:
|
||||
# A bit of extra time to account for long-polling, websockets, etc.
|
||||
page.wait_for_load_state("networkidle", timeout=PAGE_RENDER_TIMEOUT_MS)
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
last_modified = (
|
||||
page_response.header_value("Last-Modified") if page_response else None
|
||||
@@ -576,7 +585,7 @@ class WebConnector(LoadConnector):
|
||||
# (e.g., CloudFlare protection keeps making requests)
|
||||
try:
|
||||
page.wait_for_load_state(
|
||||
"networkidle", timeout=BOT_DETECTION_GRACE_PERIOD_MS
|
||||
"networkidle", timeout=PAGE_RENDER_TIMEOUT_MS
|
||||
)
|
||||
except TimeoutError:
|
||||
# If networkidle times out, just give it a moment for content to render
|
||||
|
||||
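In isolation, the render-wait introduced above is a fixed grace period followed by a best-effort networkidle wait that is allowed to time out (long-polling pages never go idle). A sketch of the pattern, assuming the synchronous Playwright API; the target URL is a placeholder.

from playwright.sync_api import TimeoutError as PlaywrightTimeoutError
from playwright.sync_api import sync_playwright

PAGE_RENDER_TIMEOUT_MS = 5000

with sync_playwright() as p:
    browser = p.chromium.launch()
    page = browser.new_page()
    page.goto("https://example.com")

    # Fixed grace period for bot-detection challenges and initial JS execution.
    page.wait_for_timeout(PAGE_RENDER_TIMEOUT_MS)

    # Best-effort wait for async fetches to settle; a timeout here is expected
    # and non-fatal for pages that keep the network busy.
    try:
        page.wait_for_load_state("networkidle", timeout=PAGE_RENDER_TIMEOUT_MS)
    except PlaywrightTimeoutError:
        pass

    html = page.content()
    browser.close()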
@@ -2,7 +2,6 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -70,9 +69,13 @@ class BaseFilters(BaseModel):
|
||||
|
||||
|
||||
class UserFileFilters(BaseModel):
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
# Scopes search to user files tagged with a given project/persona in Vespa.
|
||||
# These are NOT simply the IDs of the current project or persona — they are
|
||||
# only set when the persona's/project's user files overflowed the LLM
|
||||
# context window and must be searched via vector DB instead of being loaded
|
||||
# directly into the prompt.
|
||||
project_id_filter: int | None = None
|
||||
persona_id_filter: int | None = None
|
||||
|
||||
|
||||
class AssistantKnowledgeFilters(BaseModel):
|
||||
@@ -398,3 +401,16 @@ class SavedSearchDocWithContent(SavedSearchDoc):
|
||||
section in addition to the match_highlights."""
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
class PersonaSearchInfo(BaseModel):
|
||||
"""Snapshot of persona data needed by the search pipeline.
|
||||
|
||||
Extracted from the ORM Persona before the DB session is released so that
|
||||
SearchTool and search_pipeline never lazy-load relationships post-commit.
|
||||
"""
|
||||
|
||||
document_set_names: list[str]
|
||||
search_start_date: datetime | None
|
||||
attached_document_ids: list[str]
|
||||
hierarchy_node_ids: list[int]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -10,12 +9,12 @@ from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.context.search.preprocessing.access_filters import (
|
||||
build_access_filters_for_user,
|
||||
)
|
||||
from onyx.context.search.retrieval.search_runner import search_chunks
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
@@ -39,9 +38,8 @@ logger = setup_logger()
|
||||
def _build_index_filters(
|
||||
user_provided_filters: BaseFilters | None,
|
||||
user: User, # Used for ACLs, anonymous users only see public docs
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
user_file_ids: list[UUID] | None,
|
||||
project_id_filter: int | None,
|
||||
persona_id_filter: int | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
db_session: Session | None = None,
|
||||
@@ -97,16 +95,6 @@ def _build_index_filters(
|
||||
if not source_filter and detected_source_filter:
|
||||
source_filter = detected_source_filter
|
||||
|
||||
# CRITICAL FIX: If user_file_ids are present, we must ensure "user_file"
|
||||
# source type is included in the filter, otherwise user files will be excluded!
|
||||
if user_file_ids and source_filter:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# Add user_file to the source filter if not already present
|
||||
if DocumentSource.USER_FILE not in source_filter:
|
||||
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
|
||||
logger.debug("Added USER_FILE to source_filter for user knowledge search")
|
||||
|
||||
if bypass_acl:
|
||||
user_acl_filters = None
|
||||
elif acl_filters is not None:
|
||||
@@ -117,9 +105,8 @@ def _build_index_filters(
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
project_id_filter=project_id_filter,
|
||||
persona_id_filter=persona_id_filter,
|
||||
source_type=source_filter,
|
||||
document_set=document_set_filter,
|
||||
time_cutoff=time_filter,
|
||||
@@ -260,51 +247,41 @@ def search_pipeline(
|
||||
document_index: DocumentIndex,
|
||||
# Used for ACLs and federated search, anonymous users only see public docs
|
||||
user: User,
|
||||
# Used for default filters and settings
|
||||
persona: Persona | None,
|
||||
# Pre-extracted persona search configuration (None when no persona)
|
||||
persona_search_info: PersonaSearchInfo | None,
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# If a persona_id is provided, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
# Vespa metadata filters for overflowing user files. NOT the raw IDs
|
||||
# of the current project/persona — only set when user files couldn't fit
|
||||
# in the LLM context and need to be searched via vector DB.
|
||||
project_id_filter: int | None = None,
|
||||
persona_id_filter: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
user_uploaded_persona_files: list[UUID] | None = (
|
||||
[user_file.id for user_file in persona.user_files] if persona else None
|
||||
)
|
||||
|
||||
persona_document_sets: list[str] | None = (
|
||||
[persona_document_set.name for persona_document_set in persona.document_sets]
|
||||
if persona
|
||||
else None
|
||||
persona_search_info.document_set_names if persona_search_info else None
|
||||
)
|
||||
persona_time_cutoff: datetime | None = (
|
||||
persona.search_start_date if persona else None
|
||||
persona_search_info.search_start_date if persona_search_info else None
|
||||
)
|
||||
|
||||
# Extract assistant knowledge filters from persona
|
||||
attached_document_ids: list[str] | None = (
|
||||
[doc.id for doc in persona.attached_documents]
|
||||
if persona and persona.attached_documents
|
||||
persona_search_info.attached_document_ids or None
|
||||
if persona_search_info
|
||||
else None
|
||||
)
|
||||
hierarchy_node_ids: list[int] | None = (
|
||||
[node.id for node in persona.hierarchy_nodes]
|
||||
if persona and persona.hierarchy_nodes
|
||||
else None
|
||||
persona_search_info.hierarchy_node_ids or None if persona_search_info else None
|
||||
)
|
||||
|
||||
filters = _build_index_filters(
|
||||
user_provided_filters=chunk_search_request.user_selected_filters,
|
||||
user=user,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
user_file_ids=user_uploaded_persona_files,
|
||||
project_id_filter=project_id_filter,
|
||||
persona_id_filter=persona_id_filter,
|
||||
persona_document_sets=persona_document_sets,
|
||||
persona_time_cutoff=persona_time_cutoff,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -14,6 +14,10 @@ from onyx.context.search.utils import get_query_embedding
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.interfaces_new import DocumentIndex as NewDocumentIndex
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.federated_connectors.federated_retrieval import (
|
||||
get_federated_retrieval_functions,
|
||||
@@ -49,7 +53,7 @@ def combine_retrieval_results(
|
||||
return sorted_chunks
|
||||
|
||||
|
||||
def _embed_and_search(
|
||||
def _embed_and_hybrid_search(
|
||||
query_request: ChunkIndexRequest,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session | None = None,
|
||||
@@ -81,6 +85,17 @@ def _embed_and_search(
|
||||
return top_chunks
|
||||
|
||||
|
||||
def _keyword_search(
|
||||
query_request: ChunkIndexRequest,
|
||||
document_index: NewDocumentIndex,
|
||||
) -> list[InferenceChunk]:
|
||||
return document_index.keyword_retrieval(
|
||||
query=query_request.query,
|
||||
filters=query_request.filters,
|
||||
num_to_retrieve=query_request.limit or NUM_RETURNED_HITS,
|
||||
)
|
||||
|
||||
|
||||
def search_chunks(
|
||||
query_request: ChunkIndexRequest,
|
||||
user_id: UUID | None,
|
||||
@@ -110,7 +125,6 @@ def search_chunks(
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
|
||||
federated_sources = set(
|
||||
@@ -129,21 +143,38 @@ def search_chunks(
|
||||
)
|
||||
|
||||
if normal_search_enabled:
|
||||
run_queries.append(
|
||||
(
|
||||
_embed_and_search,
|
||||
(query_request, document_index, db_session, embedding_model),
|
||||
if (
|
||||
query_request.hybrid_alpha is not None
|
||||
and query_request.hybrid_alpha == 0.0
|
||||
and isinstance(document_index, OpenSearchOldDocumentIndex)
|
||||
):
|
||||
# If hybrid alpha is explicitly set to keyword only, do pure keyword
|
||||
# search without generating an embedding. This is currently only
|
||||
# supported with OpenSearchDocumentIndex.
|
||||
opensearch_new_document_index: NewDocumentIndex = document_index._real_index
|
||||
run_queries.append(
|
||||
(
|
||||
lambda: _keyword_search(
|
||||
query_request, opensearch_new_document_index
|
||||
),
|
||||
(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
run_queries.append(
|
||||
(
|
||||
_embed_and_hybrid_search,
|
||||
(query_request, document_index, db_session, embedding_model),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
top_chunks = combine_retrieval_results(parallel_search_results)
|
||||
|
||||
if not top_chunks:
|
||||
logger.debug(
|
||||
f"Hybrid search returned no results for query: {query_request.query}with filters: {query_request.filters}"
|
||||
f"Search returned no results for query: {query_request.query} with filters: {query_request.filters}."
|
||||
)
|
||||
return []
|
||||
|
||||
return top_chunks
|
||||
|
||||
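For context on the branch above: per the inline comment, an explicit hybrid_alpha of 0.0 means keyword-only retrieval, so the query embedding can be skipped when the index supports it; otherwise the hybrid path runs. A toy sketch of that dispatch; the callables are placeholders for the two retrieval paths.

from collections.abc import Callable

def pick_retrieval(
    hybrid_alpha: float | None,
    keyword_only_supported: bool,
    keyword_search: Callable[[], list],
    hybrid_search: Callable[[], list],
) -> list:
    # alpha explicitly 0.0 means pure keyword search; None means "use the default".
    if hybrid_alpha == 0.0 and keyword_only_supported:
        return keyword_search()
    return hybrid_search()

print(pick_retrieval(0.0, True, lambda: ["keyword"], lambda: ["hybrid"]))   # ['keyword']
print(pick_retrieval(None, True, lambda: ["keyword"], lambda: ["hybrid"]))  # ['hybrid']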
|
||||
@@ -16,6 +16,7 @@ from sqlalchemy import Row
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -28,6 +29,7 @@ from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchDoc as DBSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
@@ -53,9 +55,22 @@ def get_chat_session_by_id(
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
is_shared: bool = False,
|
||||
eager_load_persona: bool = False,
|
||||
) -> ChatSession:
|
||||
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
|
||||
|
||||
if eager_load_persona:
|
||||
stmt = stmt.options(
|
||||
joinedload(ChatSession.persona).options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.attached_documents),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
),
|
||||
joinedload(ChatSession.project),
|
||||
)
|
||||
|
||||
if is_shared:
|
||||
stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC)
|
||||
else:
|
||||
|
||||
@@ -750,3 +750,31 @@ def resync_cc_pair(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────


def get_connector_health_for_metrics(
    db_session: Session,
) -> list:  # Returns list of Row tuples
    """Return connector health data for Prometheus metrics.

    Each row is (cc_pair_id, status, in_repeated_error_state,
    last_successful_index_time, name, source).
    """
    return (
        db_session.query(
            ConnectorCredentialPair.id,
            ConnectorCredentialPair.status,
            ConnectorCredentialPair.in_repeated_error_state,
            ConnectorCredentialPair.last_successful_index_time,
            ConnectorCredentialPair.name,
            Connector.source,
        )
        .join(
            Connector,
            ConnectorCredentialPair.connector_id == Connector.id,
        )
        .all()
    )
|
||||
|
||||
@@ -1,4 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum as PyEnum
|
||||
from typing import ClassVar
|
||||
|
||||
|
||||
class AccountType(str, PyEnum):
|
||||
"""
|
||||
What kind of account this is — determines whether the user
|
||||
enters the group-based permission system.
|
||||
|
||||
STANDARD + SERVICE_ACCOUNT → participate in group system
|
||||
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
|
||||
"""
|
||||
|
||||
STANDARD = "standard"
|
||||
BOT = "bot"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
SERVICE_ACCOUNT = "service_account"
|
||||
ANONYMOUS = "anonymous"
|
||||
|
||||
|
||||
class GrantSource(str, PyEnum):
|
||||
"""How a permission grant was created."""
|
||||
|
||||
USER = "user"
|
||||
SCIM = "scim"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
@@ -314,3 +341,54 @@ class HookPoint(str, PyEnum):
|
||||
class HookFailStrategy(str, PyEnum):
    HARD = "hard"  # exception propagates, pipeline aborts
    SOFT = "soft"  # log error, return original input, pipeline continues


class Permission(str, PyEnum):
    """
    Permission tokens for group-based authorization.
    19 tokens total. full_admin_panel_access is an override —
    if present, any permission check passes.
    """

    # Basic (auto-granted to every new group)
    BASIC_ACCESS = "basic"

    # Read tokens — implied only, never granted directly
    READ_CONNECTORS = "read:connectors"
    READ_DOCUMENT_SETS = "read:document_sets"
    READ_AGENTS = "read:agents"
    READ_USERS = "read:users"

    # Add / Manage pairs
    ADD_AGENTS = "add:agents"
    MANAGE_AGENTS = "manage:agents"
    MANAGE_DOCUMENT_SETS = "manage:document_sets"
    ADD_CONNECTORS = "add:connectors"
    MANAGE_CONNECTORS = "manage:connectors"
    MANAGE_LLMS = "manage:llms"

    # Toggle tokens
    READ_AGENT_ANALYTICS = "read:agent_analytics"
    MANAGE_ACTIONS = "manage:actions"
    READ_QUERY_HISTORY = "read:query_history"
    MANAGE_USER_GROUPS = "manage:user_groups"
    CREATE_USER_API_KEYS = "create:user_api_keys"
    CREATE_SERVICE_ACCOUNT_API_KEYS = "create:service_account_api_keys"
    CREATE_SLACK_DISCORD_BOTS = "create:slack_discord_bots"

    # Override — any permission check passes
    FULL_ADMIN_PANEL_ACCESS = "admin"

    # Permissions that are implied by other grants and must never be stored
    # directly in the permission_grant table.
    IMPLIED: ClassVar[frozenset[Permission]]


Permission.IMPLIED = frozenset(
    {
        Permission.READ_CONNECTORS,
        Permission.READ_DOCUMENT_SETS,
        Permission.READ_AGENTS,
        Permission.READ_USERS,
    }
)
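A minimal sketch of how these tokens could be consumed, following the docstrings above: FULL_ADMIN_PANEL_ACCESS short-circuits any check, and IMPLIED read tokens are never granted directly. The helper names are illustrative; only the Permission enum itself comes from this diff.

from onyx.db.enums import Permission

def has_permission(granted: set[Permission], required: Permission) -> bool:
    """Illustrative check: the admin override passes every permission test."""
    if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
        return True
    return required in granted

def validate_direct_grant(token: Permission) -> None:
    """Implied read tokens must never be written to the grant table directly."""
    if token in Permission.IMPLIED:
        raise ValueError(f"{token.value} is implied-only; grant a managing token instead")

assert has_permission({Permission.FULL_ADMIN_PANEL_ACCESS}, Permission.MANAGE_LLMS)
assert not has_permission({Permission.BASIC_ACCESS}, Permission.MANAGE_LLMS)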
|
||||
@@ -75,6 +75,7 @@ def create_hook__no_commit(
|
||||
fail_strategy: HookFailStrategy,
|
||||
timeout_seconds: float,
|
||||
is_active: bool = False,
|
||||
is_reachable: bool | None = None,
|
||||
creator_id: UUID | None = None,
|
||||
) -> Hook:
|
||||
"""Create a new hook for the given hook point.
|
||||
@@ -100,6 +101,7 @@ def create_hook__no_commit(
|
||||
fail_strategy=fail_strategy,
|
||||
timeout_seconds=timeout_seconds,
|
||||
is_active=is_active,
|
||||
is_reachable=is_reachable,
|
||||
creator_id=creator_id,
|
||||
)
|
||||
# Use a savepoint so that a failed insert only rolls back this operation,
|
||||
|
||||
@@ -2,6 +2,8 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import NamedTuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -28,6 +30,9 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# from sqlalchemy.sql.selectable import Select
|
||||
|
||||
# Comment out unused imports that cause mypy errors
|
||||
@@ -583,6 +588,67 @@ def get_latest_index_attempt_for_cc_pair_id(
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_latest_successful_index_attempt_for_cc_pair_id(
|
||||
db_session: Session,
|
||||
connector_credential_pair_id: int,
|
||||
secondary_index: bool = False,
|
||||
) -> IndexAttempt | None:
|
||||
"""Returns the most recent successful index attempt for the given cc pair,
|
||||
filtered to the current (or future) search settings.
|
||||
Uses MAX(id) semantics to match get_latest_index_attempts_by_status."""
|
||||
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.where(
|
||||
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
|
||||
IndexAttempt.status.in_(
|
||||
[IndexingStatus.SUCCESS, IndexingStatus.COMPLETED_WITH_ERRORS]
|
||||
),
|
||||
)
|
||||
.join(SearchSettings)
|
||||
.where(SearchSettings.status == status)
|
||||
.order_by(desc(IndexAttempt.id))
|
||||
.limit(1)
|
||||
)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_latest_successful_index_attempts_parallel(
|
||||
secondary_index: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
"""Batch version: returns the latest successful index attempt per cc pair.
|
||||
Covers both SUCCESS and COMPLETED_WITH_ERRORS (matching is_successful())."""
|
||||
model_status = (
|
||||
IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
latest_ids = (
|
||||
select(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.max(IndexAttempt.id).label("max_id"),
|
||||
)
|
||||
.join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
|
||||
.where(
|
||||
SearchSettings.status == model_status,
|
||||
IndexAttempt.status.in_(
|
||||
[IndexingStatus.SUCCESS, IndexingStatus.COMPLETED_WITH_ERRORS]
|
||||
),
|
||||
)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
stmt = select(IndexAttempt).join(
|
||||
latest_ids,
|
||||
(
|
||||
IndexAttempt.connector_credential_pair_id
|
||||
== latest_ids.c.connector_credential_pair_id
|
||||
)
|
||||
& (IndexAttempt.id == latest_ids.c.max_id),
|
||||
)
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def count_index_attempts_for_cc_pair(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
@@ -911,3 +977,106 @@ def get_index_attempt_errors_for_cc_pair(
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class ActiveIndexAttemptMetric(NamedTuple):
|
||||
"""Row returned by get_active_index_attempts_for_metrics."""
|
||||
|
||||
status: IndexingStatus
|
||||
source: "DocumentSource"
|
||||
cc_pair_id: int
|
||||
cc_pair_name: str | None
|
||||
attempt_count: int
|
||||
|
||||
|
||||
def get_active_index_attempts_for_metrics(
|
||||
db_session: Session,
|
||||
) -> list[ActiveIndexAttemptMetric]:
|
||||
"""Return non-terminal index attempts grouped by status, source, and connector.
|
||||
|
||||
Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count).
|
||||
"""
|
||||
from onyx.db.models import Connector
|
||||
|
||||
terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
func.count(),
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.filter(IndexAttempt.status.notin_(terminal_statuses))
|
||||
.group_by(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return [ActiveIndexAttemptMetric(*row) for row in rows]
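# Hedged usage sketch (not part of the diff): exporting the grouped rows as
# Prometheus gauge samples. `active_attempts_gauge` is a hypothetical
# prometheus_client Gauge labeled by (status, source, cc_pair_id); it is not
# defined anywhere in this diff.
#
#     for row in get_active_index_attempts_for_metrics(db_session):
#         active_attempts_gauge.labels(
#             status=row.status.value,
#             source=row.source.value,
#             cc_pair_id=str(row.cc_pair_id),
#         ).set(row.attempt_count)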
|
||||
|
||||
|
||||
def get_failed_attempt_counts_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: failed_attempt_count} for all connectors.
|
||||
|
||||
When ``since`` is provided, only attempts created after that timestamp
|
||||
are counted. Defaults to the last 90 days to avoid unbounded historical
|
||||
aggregation.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.count(),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.FAILED)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
.all()
|
||||
)
|
||||
return {cc_id: count for cc_id, count in rows}
|
||||
|
||||
|
||||
def get_docs_indexed_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: total_new_docs_indexed} across successful attempts.
|
||||
|
||||
Only counts attempts with status SUCCESS to avoid inflating counts with
|
||||
partial results from failed attempts. When ``since`` is provided, only
|
||||
attempts created after that timestamp are included.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
|
||||
query = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.SUCCESS)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
)
|
||||
rows = query.all()
|
||||
return {cc_id: int(total or 0) for cc_id, total in rows}
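# Hedged usage sketch (not part of the diff): combining the two aggregation
# helpers above into a per-connector health summary. Names are illustrative only.
#
#     failed = get_failed_attempt_counts_by_cc_pair(db_session)
#     indexed = get_docs_indexed_by_cc_pair(db_session)
#     summary = {
#         cc_pair_id: {
#             "failed_attempts": failed.get(cc_pair_id, 0),
#             "new_docs_indexed": indexed.get(cc_pair_id, 0),
#         }
#         for cc_pair_id in set(failed) | set(indexed)
#     }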
|
||||
|
||||
@@ -48,6 +48,7 @@ from sqlalchemy.types import LargeBinary
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy import PrimaryKeyConstraint
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import (
|
||||
ANONYMOUS_USER_UUID,
|
||||
@@ -78,6 +79,8 @@ from onyx.db.enums import (
|
||||
MCPAuthenticationPerformer,
|
||||
MCPTransport,
|
||||
MCPServerStatus,
|
||||
Permission,
|
||||
GrantSource,
|
||||
LLMModelFlowType,
|
||||
ThemePreference,
|
||||
DefaultAppMode,
|
||||
@@ -302,6 +305,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType | None] = mapped_column(
|
||||
Enum(AccountType, native_enum=False), nullable=True
|
||||
)
|
||||
|
||||
"""
|
||||
Preferences probably should be in a separate table at some point, but for now
|
||||
@@ -2645,6 +2651,15 @@ class ChatMessage(Base):
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# For multi-model turns: the user message points to which assistant response
|
||||
# was selected as the preferred one to continue the conversation with.
|
||||
preferred_response_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
# The display name of the model that generated this assistant message
|
||||
model_display_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# What does this message contain
|
||||
reasoning_tokens: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
message: Mapped[str] = mapped_column(Text)
|
||||
@@ -2712,6 +2727,12 @@ class ChatMessage(Base):
|
||||
remote_side="ChatMessage.id",
|
||||
)
|
||||
|
||||
preferred_response: Mapped["ChatMessage | None"] = relationship(
|
||||
"ChatMessage",
|
||||
foreign_keys=[preferred_response_id],
|
||||
remote_side="ChatMessage.id",
|
||||
)
|
||||
|
||||
# Chat messages only need to know their immediate tool call children
|
||||
# If there are nested tool calls, they are stored in the tool_call_children relationship.
|
||||
tool_calls: Mapped[list["ToolCall"] | None] = relationship(
|
||||
@@ -3114,8 +3135,6 @@ class VoiceProvider(Base):
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
@@ -3467,9 +3486,9 @@ class Persona(Base):
|
||||
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Featured personas are highlighted in the UI
|
||||
featured: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# controls whether the persona is available to be selected by users
|
||||
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
is_featured: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# controls whether the persona is listed in user-facing agent lists
|
||||
is_listed: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# controls the ordering of personas in the UI
|
||||
# higher priority personas are displayed first, ties are resolved by the ID,
|
||||
# where lower value IDs (e.g. created earlier) are displayed first
|
||||
@@ -3971,6 +3990,8 @@ class SamlAccount(Base):
|
||||
class User__UserGroup(Base):
|
||||
__tablename__ = "user__user_group"
|
||||
|
||||
__table_args__ = (Index("ix_user__user_group_user_id", "user_id"),)
|
||||
|
||||
is_curator: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
@@ -3981,6 +4002,48 @@ class User__UserGroup(Base):
|
||||
)
|
||||
|
||||
|
||||
class PermissionGrant(Base):
|
||||
__tablename__ = "permission_grant"
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"group_id", "permission", name="uq_permission_grant_group_permission"
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
permission: Mapped[Permission] = mapped_column(
|
||||
Enum(Permission, native_enum=False), nullable=False
|
||||
)
|
||||
grant_source: Mapped[GrantSource] = mapped_column(
|
||||
Enum(GrantSource, native_enum=False), nullable=False
|
||||
)
|
||||
granted_by: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
granted_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
is_deleted: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False, server_default=text("false")
|
||||
)
|
||||
|
||||
group: Mapped["UserGroup"] = relationship(
|
||||
"UserGroup", back_populates="permission_grants"
|
||||
)
|
||||
|
||||
@validates("permission")
|
||||
def _validate_permission(self, _key: str, value: Permission) -> Permission:
|
||||
if value in Permission.IMPLIED:
|
||||
raise ValueError(
|
||||
f"{value!r} is an implied permission and cannot be granted directly"
|
||||
)
|
||||
return value
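# Hedged sketch (not part of the diff): what the validator above rejects. The
# enum members used below are hypothetical examples, not real values from
# Permission or GrantSource.
#
#     grant = PermissionGrant(
#         group_id=group.id,
#         permission=Permission.SOME_IMPLIED_PERMISSION,  # hypothetical member
#         grant_source=GrantSource.SOME_SOURCE,           # hypothetical member
#     )
#     # raises ValueError: "... is an implied permission and cannot be granted directly"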
|
||||
|
||||
|
||||
class UserGroup__ConnectorCredentialPair(Base):
|
||||
__tablename__ = "user_group__connector_credential_pair"
|
||||
|
||||
@@ -4075,6 +4138,8 @@ class UserGroup(Base):
|
||||
is_up_for_deletion: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
# whether this is a default group (e.g. "Basic", "Admins") that cannot be deleted
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Last time a user updated this user group
|
||||
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -4118,6 +4183,9 @@ class UserGroup(Base):
|
||||
accessible_mcp_servers: Mapped[list["MCPServer"]] = relationship(
|
||||
"MCPServer", secondary="mcp_server__user_group", back_populates="user_groups"
|
||||
)
|
||||
permission_grants: Mapped[list["PermissionGrant"]] = relationship(
|
||||
"PermissionGrant", back_populates="group", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
"""Tables related to Token Rate Limiting
|
||||
|
||||
@@ -50,8 +50,18 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_default_behavior_persona(db_session: Session) -> Persona | None:
|
||||
def get_default_behavior_persona(
|
||||
db_session: Session,
|
||||
eager_load_for_tools: bool = False,
|
||||
) -> Persona | None:
|
||||
stmt = select(Persona).where(Persona.id == DEFAULT_PERSONA_ID)
|
||||
if eager_load_for_tools:
|
||||
stmt = stmt.options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.attached_documents),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
)
|
||||
return db_session.scalars(stmt).first()
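# Hedged usage sketch (not part of the diff): with eager_load_for_tools=True the
# persona's tool/document collections are loaded up front via selectinload, so
# iterating them afterwards issues no additional lazy-load queries. Attribute
# access below is illustrative.
#
#     persona = get_default_behavior_persona(db_session, eager_load_for_tools=True)
#     if persona is not None:
#         attached_tools = list(persona.tools)               # already loaded
#         attached_docs = list(persona.attached_documents)   # already loaded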
|
||||
|
||||
|
||||
@@ -126,7 +136,7 @@ def _add_user_filters(
|
||||
else:
|
||||
# Group the public persona conditions
|
||||
public_condition = (Persona.is_public == True) & ( # noqa: E712
|
||||
Persona.is_visible == True # noqa: E712
|
||||
Persona.is_listed == True # noqa: E712
|
||||
)
|
||||
|
||||
where_clause |= public_condition
|
||||
@@ -260,7 +270,7 @@ def create_update_persona(
|
||||
|
||||
try:
|
||||
# Featured persona validation
|
||||
if create_persona_request.featured:
|
||||
if create_persona_request.is_featured:
|
||||
# Curators can edit featured personas, but not make them
|
||||
# TODO this will be reworked soon with RBAC permissions feature
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
@@ -300,7 +310,7 @@ def create_update_persona(
|
||||
remove_image=create_persona_request.remove_image,
|
||||
search_start_date=create_persona_request.search_start_date,
|
||||
label_ids=create_persona_request.label_ids,
|
||||
featured=create_persona_request.featured,
|
||||
is_featured=create_persona_request.is_featured,
|
||||
user_file_ids=converted_user_file_ids,
|
||||
commit=False,
|
||||
hierarchy_node_ids=create_persona_request.hierarchy_node_ids,
|
||||
@@ -910,11 +920,11 @@ def upsert_persona(
|
||||
uploaded_image_id: str | None = None,
|
||||
icon_name: str | None = None,
|
||||
display_priority: int | None = None,
|
||||
is_visible: bool = True,
|
||||
is_listed: bool = True,
|
||||
remove_image: bool | None = None,
|
||||
search_start_date: datetime | None = None,
|
||||
builtin_persona: bool = False,
|
||||
featured: bool | None = None,
|
||||
is_featured: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
@@ -1037,13 +1047,13 @@ def upsert_persona(
|
||||
if remove_image or uploaded_image_id:
|
||||
existing_persona.uploaded_image_id = uploaded_image_id
|
||||
existing_persona.icon_name = icon_name
|
||||
existing_persona.is_visible = is_visible
|
||||
existing_persona.is_listed = is_listed
|
||||
existing_persona.search_start_date = search_start_date
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.featured = (
|
||||
featured if featured is not None else existing_persona.featured
|
||||
existing_persona.is_featured = (
|
||||
is_featured if is_featured is not None else existing_persona.is_featured
|
||||
)
|
||||
# Update embedded prompt fields if provided
|
||||
if system_prompt is not None:
|
||||
@@ -1109,9 +1119,9 @@ def upsert_persona(
|
||||
uploaded_image_id=uploaded_image_id,
|
||||
icon_name=icon_name,
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
is_listed=is_listed,
|
||||
search_start_date=search_start_date,
|
||||
featured=(featured if featured is not None else False),
|
||||
is_featured=(is_featured if is_featured is not None else False),
|
||||
user_files=user_files or [],
|
||||
labels=labels or [],
|
||||
hierarchy_nodes=hierarchy_nodes or [],
|
||||
@@ -1158,7 +1168,7 @@ def delete_old_default_personas(
|
||||
|
||||
def update_persona_featured(
|
||||
persona_id: int,
|
||||
featured: bool,
|
||||
is_featured: bool,
|
||||
db_session: Session,
|
||||
user: User,
|
||||
) -> None:
|
||||
@@ -1166,13 +1176,13 @@ def update_persona_featured(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
persona.featured = featured
|
||||
persona.is_featured = is_featured
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_visibility(
|
||||
persona_id: int,
|
||||
is_visible: bool,
|
||||
is_listed: bool,
|
||||
db_session: Session,
|
||||
user: User,
|
||||
) -> None:
|
||||
@@ -1180,7 +1190,7 @@ def update_persona_visibility(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
persona.is_visible = is_visible
|
||||
persona.is_listed = is_listed
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
@@ -34,9 +35,19 @@ class CategorizedFilesResult(BaseModel):
|
||||
user_files: list[UserFile]
|
||||
rejected_files: list[RejectedFile]
|
||||
id_to_temp_id: dict[str, str]
|
||||
# Filenames that should be stored but not indexed.
|
||||
skip_indexing_filenames: set[str] = Field(default_factory=set)
|
||||
# Allow SQLAlchemy ORM models inside this result container
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@property
|
||||
def indexable_files(self) -> list[UserFile]:
|
||||
return [
|
||||
uf
|
||||
for uf in self.user_files
|
||||
if (uf.name or "") not in self.skip_indexing_filenames
|
||||
]
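# Hedged usage sketch (not part of the diff): files listed in
# skip_indexing_filenames stay in user_files (they are stored), but they are
# excluded from indexable_files and therefore never queued for indexing.
# Values below are illustrative only.
#
#     result = CategorizedFilesResult(
#         user_files=user_files,
#         rejected_files=[],
#         id_to_temp_id={},
#         skip_indexing_filenames={"huge_archive.zip"},
#     )
#     to_index = result.indexable_files  # everything except "huge_archive.zip"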
|
||||
|
||||
|
||||
def build_hashed_file_key(file: UploadFile) -> str:
|
||||
name_prefix = (file.filename or "")[:50]
|
||||
@@ -98,6 +109,7 @@ def create_user_files(
|
||||
user_files=user_files,
|
||||
rejected_files=rejected_files,
|
||||
id_to_temp_id=id_to_temp_id,
|
||||
skip_indexing_filenames=categorized_files.skip_indexing,
|
||||
)
|
||||
|
||||
|
||||
@@ -123,6 +135,7 @@ def upload_files_to_user_files_with_indexing(
|
||||
user_files = categorized_files_result.user_files
|
||||
rejected_files = categorized_files_result.rejected_files
|
||||
id_to_temp_id = categorized_files_result.id_to_temp_id
|
||||
indexable_files = categorized_files_result.indexable_files
|
||||
# Trigger per-file processing immediately for the current tenant
|
||||
tenant_id = get_current_tenant_id()
|
||||
for rejected_file in rejected_files:
|
||||
@@ -134,12 +147,12 @@ def upload_files_to_user_files_with_indexing(
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
|
||||
background_tasks.add_task(drain_processing_loop, tenant_id)
|
||||
for user_file in user_files:
|
||||
for user_file in indexable_files:
|
||||
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
for user_file in user_files:
|
||||
for user_file in indexable_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
@@ -155,6 +168,7 @@ def upload_files_to_user_files_with_indexing(
|
||||
user_files=user_files,
|
||||
rejected_files=rejected_files,
|
||||
id_to_temp_id=id_to_temp_id,
|
||||
skip_indexing_filenames=categorized_files_result.skip_indexing_filenames,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ def create_slack_channel_persona(
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
is_public=True,
|
||||
featured=False,
|
||||
is_featured=False,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
@@ -149,6 +150,9 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
Returns None if search settings did not change, or the old search settings if they
|
||||
did change.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return None
|
||||
|
||||
# Default CC-pair created for Ingestion API unused here
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
|
||||
|
||||
@@ -17,39 +17,30 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
|
||||
"""Fetch all voice providers."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
.order_by(VoiceProvider.name)
|
||||
).all()
|
||||
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
|
||||
db_session: Session, provider_id: int, include_deleted: bool = False
|
||||
db_session: Session, provider_id: int
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by ID."""
|
||||
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(VoiceProvider.deleted.is_(False))
|
||||
return db_session.scalar(stmt)
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default STT provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_stt.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default TTS provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_tts.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
|
||||
)
|
||||
|
||||
|
||||
@@ -58,9 +49,7 @@ def fetch_voice_provider_by_type(
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by type."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.provider_type == provider_type)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
|
||||
)
|
||||
|
||||
|
||||
@@ -119,10 +108,10 @@ def upsert_voice_provider(
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
|
||||
"""Soft-delete a voice provider by ID."""
|
||||
"""Delete a voice provider by ID."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider:
|
||||
provider.deleted = True
|
||||
db_session.delete(provider)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
|
||||
@@ -10,8 +10,18 @@ How `IndexFilters` fields combine into the final query filter. Applies to both V
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
| **ACL** | `access_control_list` | OR within, AND with rest |
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
| **Knowledge scope** | `document_set`, `attached_document_ids`, `hierarchy_node_ids`, `persona_id_filter` | OR within group, AND with rest |
| **Additive scope** | `project_id_filter` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |

## How filters combine

@@ -31,12 +31,22 @@ AND time >= cutoff -- if set

The knowledge scope filter controls **what knowledge an assistant can access**.

### Primary vs additive triggers

- **`persona_id_filter`** is a **primary** trigger. A persona with user files IS explicit
  knowledge, so `persona_id_filter` alone can start a knowledge scope. Note: this is
  NOT the raw ID of the persona being used — it is only set when the persona's
  user files overflowed the LLM context window.
- **`project_id_filter`** is **additive**. It widens an existing scope to include project
  files but never restricts on its own — a chat inside a project should still search
  team knowledge when no other knowledge is attached.
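The two trigger kinds can be summarized in a short sketch. This is illustrative only; the field names mirror `IndexFilters`, but the helper functions themselves are hypothetical and do not exist in the codebase:

```
# Hypothetical helpers mirroring the trigger rules above.
def has_knowledge_scope(f) -> bool:
    # Primary triggers: any one of these alone creates a knowledge scope.
    return bool(
        f.document_set
        or f.attached_document_ids
        or f.hierarchy_node_ids
        or f.persona_id_filter is not None
    )


def knowledge_scope_terms(f) -> list[str]:
    if not has_knowledge_scope(f):
        return []  # no scope filter at all; project_id_filter is ignored

    terms = [f"document_sets contains {ds!r}" for ds in (f.document_set or [])]
    terms += [f"document_id = {doc_id!r}" for doc_id in (f.attached_document_ids or [])]
    terms += [f"ancestor_hierarchy_node_ids contains {n}" for n in (f.hierarchy_node_ids or [])]
    if f.persona_id_filter is not None:
        terms.append(f"personas contains {f.persona_id_filter}")
    # Additive trigger: only widens a scope that already exists.
    if f.project_id_filter is not None:
        terms.append(f"user_project contains {f.project_id_filter}")
    return terms  # OR'd together, then AND'd with the rest of the filters
```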

### No explicit knowledge attached

When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
When `document_set`, `attached_document_ids`, `hierarchy_node_ids`, and `persona_id_filter` are all empty/None:

- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
- `project_id` and `persona_id` are ignored — they never restrict on their own.
- `project_id_filter` is ignored — it never restricts on its own.

### One explicit knowledge type

@@ -44,39 +54,40 @@ When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_no
-- Only document sets
AND (document_sets contains "Engineering" OR document_sets contains "Legal")

-- Only user files
AND (document_id = "uuid-1" OR document_id = "uuid-2")
-- Only persona user files (overflowed context)
AND (personas contains 42)
```

### Multiple explicit knowledge types (OR'd)

```
-- Document sets + user files
AND (
  document_sets contains "Engineering"
  OR document_id = "uuid-1"
)
```

### Explicit knowledge + overflowing user files

When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:

```
-- Document sets + persona user files overflowed
-- Document sets + persona user files
AND (
  document_sets contains "Engineering"
  OR personas contains 42
)
```

-- User files + project files overflowed
### Explicit knowledge + overflowing project files

When an explicit knowledge restriction is in effect **and** `project_id_filter` is set (project files overflowed the LLM context window), `project_id_filter` widens the filter:

```
-- Document sets + project files overflowed
AND (
  document_id = "uuid-1"
  document_sets contains "Engineering"
  OR user_project contains 7
)

-- Persona user files + project files (won't happen in practice;
-- custom personas ignore project files per the precedence rule)
AND (
  personas contains 42
  OR user_project contains 7
)
```

### Only project_id or persona_id (no explicit knowledge)
### Only project_id_filter (no explicit knowledge)

No knowledge scope filter. The assistant searches everything.

@@ -91,11 +102,10 @@ AND (acl contains ...)
| Filter field | Vespa field | Vespa type | Purpose |
|---|---|---|---|
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
| `persona_id_filter` | `personas` | `array<int>` | Persona tag for overflowing user files (**primary** trigger) |
| `project_id_filter` | `user_project` | `array<int>` | Project tag for overflowing project files (**additive** only) |
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |

@@ -381,6 +381,47 @@ class HybridCapable(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Runs keyword-only search and returns a list of inference chunks.
|
||||
|
||||
Args:
|
||||
query: User query.
|
||||
filters: Filters for things like permissions, source type, time,
|
||||
etc.
|
||||
num_to_retrieve: Number of highest matching chunks to return.
|
||||
|
||||
Returns:
|
||||
Score-ranked (highest first) list of highest matching chunks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query_embedding: Embedding,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Runs semantic-only search and returns a list of inference chunks.
|
||||
|
||||
Args:
|
||||
query_embedding: Vector representation of the query. Must be of the
|
||||
correct dimensionality for the primary index.
|
||||
filters: Filters for things like permissions, source type, time,
|
||||
etc.
|
||||
num_to_retrieve: Number of highest matching chunks to return.
|
||||
|
||||
Returns:
|
||||
Score-ranked (highest first) list of highest matching chunks.
|
||||
"""
|
||||
raise NotImplementedError
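# Hedged usage sketch (not part of the diff): a caller choosing between the
# retrieval flavors on an index implementing HybridCapable. `index`, `query`,
# `query_embedding`, and `filters` are assumed to exist in the caller's scope.
#
#     if keyword_only:
#         chunks = index.keyword_retrieval(query, filters, num_to_retrieve=50)
#     elif semantic_only:
#         chunks = index.semantic_retrieval(query_embedding, filters, num_to_retrieve=50)
#     else:
#         chunks = index.hybrid_retrieval(...)  # hypothetical name; the hybrid entry point is defined above this hunk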
|
||||
|
||||
|
||||
class RandomCapable(abc.ABC):
|
||||
"""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from contextlib import AbstractContextManager
|
||||
@@ -17,10 +18,13 @@ from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME
|
||||
from onyx.configs.app_configs import OPENSEARCH_HOST
|
||||
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import OpenSearchSearchType
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.search import DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW
|
||||
from onyx.server.metrics.opensearch_search import observe_opensearch_search
|
||||
from onyx.server.metrics.opensearch_search import track_opensearch_search_in_progress
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
@@ -255,7 +259,6 @@ class OpenSearchClient(AbstractContextManager):
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
@@ -303,6 +306,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
emit_metrics: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
host=host,
|
||||
@@ -314,6 +318,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
timeout=timeout,
|
||||
)
|
||||
self._index_name = index_name
|
||||
self._emit_metrics = emit_metrics
|
||||
logger.debug(
|
||||
f"OpenSearch client created successfully for index {self._index_name}."
|
||||
)
|
||||
@@ -833,7 +838,10 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search(
|
||||
self, body: dict[str, Any], search_pipeline_id: str | None
|
||||
self,
|
||||
body: dict[str, Any],
|
||||
search_pipeline_id: str | None,
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
|
||||
) -> list[SearchHit[DocumentChunkWithoutVectors]]:
|
||||
"""Searches the index.
|
||||
|
||||
@@ -851,6 +859,8 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
documentation for more information on search request bodies.
|
||||
search_pipeline_id: The ID of the search pipeline to use. If None,
|
||||
the default search pipeline will be used.
|
||||
search_type: Label for Prometheus metrics. Does not affect search
|
||||
behavior.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error searching the index.
|
||||
@@ -863,21 +873,27 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
)
|
||||
result: dict[str, Any]
|
||||
params = {"phase_took": "true"}
|
||||
if search_pipeline_id:
|
||||
result = self._client.search(
|
||||
index=self._index_name,
|
||||
search_pipeline=search_pipeline_id,
|
||||
body=body,
|
||||
params=params,
|
||||
)
|
||||
else:
|
||||
result = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
ctx = self._get_emit_metrics_context_manager(search_type)
|
||||
t0 = time.perf_counter()
|
||||
with ctx:
|
||||
if search_pipeline_id:
|
||||
result = self._client.search(
|
||||
index=self._index_name,
|
||||
search_pipeline=search_pipeline_id,
|
||||
body=body,
|
||||
params=params,
|
||||
)
|
||||
else:
|
||||
result = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
client_duration_s = time.perf_counter() - t0
|
||||
|
||||
hits, time_took, timed_out, phase_took, profile = (
|
||||
self._get_hits_and_profile_from_search_result(result)
|
||||
)
|
||||
if self._emit_metrics:
|
||||
observe_opensearch_search(search_type, client_duration_s, time_took)
|
||||
self._log_search_result_perf(
|
||||
time_took=time_took,
|
||||
timed_out=timed_out,
|
||||
@@ -913,7 +929,11 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
return search_hits
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search_for_document_ids(self, body: dict[str, Any]) -> list[str]:
|
||||
def search_for_document_ids(
|
||||
self,
|
||||
body: dict[str, Any],
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
|
||||
) -> list[str]:
|
||||
"""Searches the index and returns only document chunk IDs.
|
||||
|
||||
In order to take advantage of the performance benefits of only returning
|
||||
@@ -930,6 +950,8 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
documentation for more information on search request bodies.
|
||||
TODO(andrei): Make this a more deep interface; callers shouldn't
|
||||
need to know to set _source: False for example.
|
||||
search_type: Label for Prometheus metrics. Does not affect search
|
||||
behavior.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error searching the index.
|
||||
@@ -947,13 +969,19 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
)
|
||||
|
||||
params = {"phase_took": "true"}
|
||||
result: dict[str, Any] = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
ctx = self._get_emit_metrics_context_manager(search_type)
|
||||
t0 = time.perf_counter()
|
||||
with ctx:
|
||||
result: dict[str, Any] = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
client_duration_s = time.perf_counter() - t0
|
||||
|
||||
hits, time_took, timed_out, phase_took, profile = (
|
||||
self._get_hits_and_profile_from_search_result(result)
|
||||
)
|
||||
if self._emit_metrics:
|
||||
observe_opensearch_search(search_type, client_duration_s, time_took)
|
||||
self._log_search_result_perf(
|
||||
time_took=time_took,
|
||||
timed_out=timed_out,
|
||||
@@ -1062,7 +1090,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
f"Body: {get_new_body_without_vectors(body)}\n"
|
||||
f"Search pipeline ID: {search_pipeline_id}\n"
|
||||
f"Phase took: {phase_took}\n"
|
||||
f"Profile: {profile}\n"
|
||||
f"Profile: {json.dumps(profile, indent=2)}\n"
|
||||
)
|
||||
if timed_out:
|
||||
error_str = f"OpenSearch client error: Search timed out for index {self._index_name}."
|
||||
@@ -1070,6 +1098,20 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
if raise_on_timeout:
|
||||
raise RuntimeError(error_str)
|
||||
|
||||
def _get_emit_metrics_context_manager(
|
||||
self, search_type: OpenSearchSearchType
|
||||
) -> AbstractContextManager[None]:
|
||||
"""
|
||||
Returns a context manager that tracks in-flight OpenSearch searches via
|
||||
a Gauge if emit_metrics is True, otherwise returns a null context
|
||||
manager.
|
||||
"""
|
||||
return (
|
||||
track_opensearch_search_in_progress(search_type)
|
||||
if self._emit_metrics
|
||||
else nullcontext()
|
||||
)
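# Hedged sketch (not part of the diff): the null-object pattern used above in
# miniature. nullcontext() satisfies the same context-manager protocol as the
# tracking context, so call sites can always use `with ctx:` without branching
# on whether metrics are enabled. `emit` and `client` below are illustrative.
#
#     from contextlib import nullcontext
#
#     ctx = track_opensearch_search_in_progress(search_type) if emit else nullcontext()
#     with ctx:
#         result = client.search(index=index_name, body=body)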
|
||||
|
||||
|
||||
def wait_for_opensearch_with_timeout(
|
||||
wait_interval_s: int = 5,
|
||||
|
||||
@@ -53,6 +53,18 @@ DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
|
||||
EF_SEARCH = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
|
||||
|
||||
|
||||
class OpenSearchSearchType(str, Enum):
|
||||
"""Search type label used for Prometheus metrics."""
|
||||
|
||||
HYBRID = "hybrid"
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
RANDOM = "random"
|
||||
ID_RETRIEVAL = "id_retrieval"
|
||||
DOCUMENT_IDS = "document_ids"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class HybridSearchSubqueryConfiguration(Enum):
|
||||
TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 1
|
||||
# Current default.
|
||||
|
||||
@@ -43,6 +43,7 @@ from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import SearchHit
|
||||
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
|
||||
from onyx.document_index.opensearch.constants import OpenSearchSearchType
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
|
||||
@@ -900,6 +901,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.ID_RETRIEVAL,
|
||||
)
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
@@ -923,6 +925,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
# TODO(andrei): There is some duplicated logic in this function with
|
||||
# others in this file.
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
@@ -948,9 +952,95 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=normalization_pipeline_name,
|
||||
search_type=OpenSearchSearchType.HYBRID,
|
||||
)
|
||||
|
||||
# Good place for a breakpoint to inspect the search hits if you have
|
||||
# "explain" enabled.
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
)
|
||||
for search_hit in search_hits
|
||||
]
|
||||
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
|
||||
inference_chunks_uncleaned
|
||||
)
|
||||
|
||||
return inference_chunks
|
||||
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
# TODO(andrei): There is some duplicated logic in this function with
|
||||
# others in this file.
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_keyword_search_query(
|
||||
query_text=query,
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
# NOTE: Index filters includes metadata tags which were filtered
|
||||
# for invalid unicode at indexing time. In theory it would be
|
||||
# ideal to do filtering here as well, in practice we never did
|
||||
# that in the Vespa codepath and have not seen issues in
|
||||
# production, so we deliberately conform to the existing logic
|
||||
# in order to not unknowingly introduce a possible bug.
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.KEYWORD,
|
||||
)
|
||||
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
)
|
||||
for search_hit in search_hits
|
||||
]
|
||||
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
|
||||
inference_chunks_uncleaned
|
||||
)
|
||||
|
||||
return inference_chunks
|
||||
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query_embedding: Embedding,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
# TODO(andrei): There is some duplicated logic in this function with
|
||||
# others in this file.
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_semantic_search_query(
|
||||
query_embedding=query_embedding,
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
# NOTE: Index filters includes metadata tags which were filtered
|
||||
# for invalid unicode at indexing time. In theory it would be
|
||||
# ideal to do filtering here as well, in practice we never did
|
||||
# that in the Vespa codepath and have not seen issues in
|
||||
# production, so we deliberately conform to the existing logic
|
||||
# in order to not unknowingly introduce a possible bug.
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.SEMANTIC,
|
||||
)
|
||||
|
||||
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
@@ -980,6 +1070,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.RANDOM,
|
||||
)
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
|
||||
@@ -3,7 +3,8 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
from typing import TypeAlias
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
|
||||
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
|
||||
@@ -49,13 +50,21 @@ from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
|
||||
# Normalization pipelines combine document scores from multiple query clauses.
|
||||
# The number and ordering of weights should match the query clauses. The values
|
||||
# of the weights should sum to 1.
|
||||
# See https://docs.opensearch.org/latest/query-dsl/term/terms/.
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY = 65_536
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
TermsQuery: TypeAlias = dict[str, dict[str, list[_T]]]
|
||||
TermQuery: TypeAlias = dict[str, dict[str, dict[str, _T]]]
|
||||
|
||||
|
||||
# TODO(andrei): Turn all magic dictionaries to pydantic models.
|
||||
|
||||
|
||||
# Normalization pipelines combine document scores from multiple query clauses.
|
||||
# The number and ordering of weights should match the query clauses. The values
|
||||
# of the weights should sum to 1.
|
||||
def _get_hybrid_search_normalization_weights() -> list[float]:
|
||||
if (
|
||||
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
|
||||
@@ -219,9 +228,8 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
@@ -286,9 +294,8 @@ class DocumentQuery:
|
||||
source_types=[],
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
project_id_filter=None,
|
||||
persona_id_filter=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -319,6 +326,9 @@ class DocumentQuery:
|
||||
it MUST be supplied in addition to a search pipeline. The results from
|
||||
hybrid search are not meaningful without that step.
|
||||
|
||||
TODO(andrei): There is some duplicated logic in this function with
|
||||
others in this file.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
query_vector: The vector embedding of the text to query for.
|
||||
@@ -356,9 +366,8 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -404,12 +413,174 @@ class DocumentQuery:
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# Explain is for scoring breakdowns.
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_hybrid_search_body["explain"] = True
|
||||
|
||||
return final_hybrid_search_body
|
||||
|
||||
@staticmethod
|
||||
def get_keyword_search_query(
|
||||
query_text: str,
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final keyword search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
TODO(andrei): There is some duplicated logic in this function with
|
||||
others in this file.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the keyword search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final keyword search query.
|
||||
"""
|
||||
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
|
||||
raise ValueError(
|
||||
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
|
||||
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
|
||||
)
|
||||
|
||||
keyword_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
attached_document_ids=index_filters.attached_document_ids,
|
||||
hierarchy_node_ids=index_filters.hierarchy_node_ids,
|
||||
)
|
||||
|
||||
keyword_search_query = (
|
||||
DocumentQuery._get_title_content_combined_keyword_search_query(
|
||||
query_text, search_filters=keyword_search_filters
|
||||
)
|
||||
)
|
||||
|
||||
final_keyword_search_query: dict[str, Any] = {
|
||||
"query": keyword_search_query,
|
||||
"size": num_hits,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
|
||||
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
|
||||
final_keyword_search_query["highlight"] = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_keyword_search_query["profile"] = True
|
||||
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_keyword_search_query["explain"] = True
|
||||
|
||||
return final_keyword_search_query
|
||||
|
||||
@staticmethod
|
||||
def get_semantic_search_query(
|
||||
query_embedding: list[float],
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final semantic search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
TODO(andrei): There is some duplicated logic in this function with
|
||||
others in this file.
|
||||
|
||||
Args:
|
||||
query_embedding: The vector embedding of the text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the semantic search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final semantic search query.
|
||||
"""
|
||||
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
|
||||
raise ValueError(
|
||||
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
|
||||
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
|
||||
)
|
||||
|
||||
semantic_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
attached_document_ids=index_filters.attached_document_ids,
|
||||
hierarchy_node_ids=index_filters.hierarchy_node_ids,
|
||||
)
|
||||
|
||||
semantic_search_query = (
|
||||
DocumentQuery._get_content_vector_similarity_search_query(
|
||||
query_embedding,
|
||||
vector_candidates=num_hits,
|
||||
search_filters=semantic_search_filters,
|
||||
)
|
||||
)
|
||||
|
||||
final_semantic_search_query: dict[str, Any] = {
|
||||
"query": semantic_search_query,
|
||||
"size": num_hits,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_semantic_search_query["profile"] = True
|
||||
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_semantic_search_query["explain"] = True
|
||||
|
||||
return final_semantic_search_query
|
||||
|
||||
@staticmethod
|
||||
def get_random_search_query(
|
||||
tenant_state: TenantState,
|
||||
@@ -433,9 +604,8 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -581,8 +751,9 @@ class DocumentQuery:
|
||||
def _get_content_vector_similarity_search_query(
|
||||
query_vector: list[float],
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
|
||||
search_filters: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
query = {
|
||||
"knn": {
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
@@ -591,11 +762,19 @@ class DocumentQuery:
|
||||
}
|
||||
}
|
||||
|
||||
if search_filters is not None:
|
||||
query["knn"][CONTENT_VECTOR_FIELD_NAME]["filter"] = {
|
||||
"bool": {"filter": search_filters}
|
||||
}
|
||||
|
||||
return query
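# Hedged sketch (not part of the diff): approximate shape of the body produced
# above for a short example vector with filters attached. Values are
# illustrative, and knn parameters elided by this hunk are omitted.
#
#     {
#         "knn": {
#             CONTENT_VECTOR_FIELD_NAME: {
#                 "vector": [0.12, -0.03, ...],  # query_vector
#                 # ...other knn parameters not shown in this hunk...
#                 "filter": {"bool": {"filter": [...]}},  # only when search_filters is not None
#             }
#         }
#     }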
|
||||
|
||||
@staticmethod
|
||||
def _get_title_content_combined_keyword_search_query(
|
||||
query_text: str,
|
||||
search_filters: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
query = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{
|
||||
@@ -603,8 +782,9 @@ class DocumentQuery:
|
||||
TITLE_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
"operator": "or",
|
||||
# The title fields are strongly discounted as they are included in the content.
|
||||
# It just acts as a minor boost
|
||||
# The title fields are strongly discounted as
|
||||
# they are included in the content. This just
|
||||
# acts as a minor boost.
|
||||
"boost": 0.1,
|
||||
}
|
||||
}
|
||||
@@ -619,6 +799,9 @@ class DocumentQuery:
|
||||
}
|
||||
},
|
||||
{
|
||||
# Analyzes the query and returns results which match any
|
||||
# of the query's terms. More matches result in higher
|
||||
# scores.
|
||||
"match": {
|
||||
CONTENT_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
@@ -628,18 +811,30 @@ class DocumentQuery:
|
||||
}
|
||||
},
|
||||
{
|
||||
# Matches an exact phrase in a specified order.
|
||||
"match_phrase": {
|
||||
CONTENT_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
# The number of words permitted between words of
|
||||
# a query phrase and still result in a match.
|
||||
"slop": 1,
|
||||
"boost": 1.5,
|
||||
}
|
||||
}
|
||||
},
|
||||
]
|
||||
],
|
||||
# Ensures at least one match subquery from the query is present
|
||||
# in the document. This defaults to 1, unless a filter or must
|
||||
# clause is supplied, in which case it defaults to 0.
|
||||
"minimum_should_match": 1,
|
||||
}
|
||||
}
|
||||
|
||||
if search_filters is not None:
|
||||
query["bool"]["filter"] = search_filters
|
||||
|
||||
return query
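# Hedged note (not part of the diff): the explicit minimum_should_match of 1
# above matters precisely because a "filter" key may be attached right here.
# Without it, OpenSearch would treat the should clauses as optional once a
# filter is present, so documents matching only the filters could be returned.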
|
||||
|
||||
@staticmethod
|
||||
def _get_search_filters(
|
||||
tenant_state: TenantState,
|
||||
@@ -648,9 +843,8 @@ class DocumentQuery:
|
||||
source_types: list[DocumentSource],
|
||||
tags: list[Tag],
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
project_id_filter: int | None,
|
||||
persona_id_filter: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -665,7 +859,14 @@ class DocumentQuery:
|
||||
The "filter" key applies a logical AND operator to its elements, so
|
||||
every subfilter must evaluate to true in order for the document to be
|
||||
retrieved. This function returns a list of such subfilters.
|
||||
See https://docs.opensearch.org/latest/query-dsl/compound/bool/
|
||||
See https://docs.opensearch.org/latest/query-dsl/compound/bool/.
|
||||
|
||||
TODO(ENG-3874): The terms queries returned by this function can be made
|
||||
more performant for large cardinality sets by sorting the values by
|
||||
their UTF-8 byte order.
|
||||
|
||||
TODO(ENG-3875): This function can take even better advantage of filter
|
||||
caching by grouping "static" filters together into one sub-clause.
|
||||
|
||||
Args:
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
@@ -681,12 +882,12 @@ class DocumentQuery:
|
||||
list corresponding to a tag will be retrieved.
|
||||
document_sets: If supplied, only documents with at least one
|
||||
document set ID from this list will be retrieved.
|
||||
user_file_ids: If supplied, only document IDs in this list will be
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
persona_id: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved.
|
||||
project_id_filter: If not None, only documents with this project ID
|
||||
in user projects will be retrieved. Additive — only applied
|
||||
when a knowledge scope already exists.
|
||||
persona_id_filter: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved. Primary — creates
|
||||
a knowledge scope on its own.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
@@ -703,10 +904,6 @@ class DocumentQuery:
|
||||
NOTE: See DocumentChunk.max_chunk_size.
|
||||
document_id: The document ID to retrieve. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
WARNING: This filters on the same property as user_file_ids.
|
||||
Although it would never make sense to supply both, note that if
|
||||
user_file_ids is supplied and does not contain document_id, no
|
||||
matches will be retrieved.
|
||||
attached_document_ids: Document IDs explicitly attached to the
|
||||
assistant. If provided along with hierarchy_node_ids, documents
|
||||
matching EITHER criteria will be retrieved (OR logic).
|
||||
@@ -714,6 +911,14 @@ class DocumentQuery:
|
||||
the assistant. Matches chunks where ancestor_hierarchy_node_ids
|
||||
contains any of these values.
|
||||
|
||||
Raises:
|
||||
ValueError: document_id and attached_document_ids were supplied
|
||||
together. This is not allowed because they operate on the same
|
||||
schema field, and it does not semantically make sense to use
|
||||
them together.
|
||||
ValueError: Too many of one of the collection arguments was
|
||||
supplied.
|
||||
|
||||
Returns:
|
||||
A list of filters to be passed into the "filter" key of a search
|
||||
query.
|
||||
@@ -721,70 +926,156 @@ class DocumentQuery:
|
||||
|
||||
def _get_acl_visibility_filter(
|
||||
access_control_list: list[str],
|
||||
) -> dict[str, Any]:
|
||||
) -> dict[str, dict[str, list[TermQuery[bool] | TermsQuery[str]] | int]]:
|
||||
"""Returns a filter for the access control list.
|
||||
|
||||
Since this returns an isolated bool should clause, it can be cached
|
||||
in OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
access_control_list: The access control list to restrict
|
||||
documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of access control list entries is greater
|
||||
than MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
|
||||
Returns:
|
||||
A filter for the access control list.
|
||||
"""
|
||||
# Logical OR operator on its elements.
|
||||
acl_visibility_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
acl_visibility_filter["bool"]["should"].append(
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}}
|
||||
)
|
||||
for acl in access_control_list:
|
||||
acl_subclause: dict[str, Any] = {
|
||||
"term": {ACCESS_CONTROL_LIST_FIELD_NAME: {"value": acl}}
|
||||
acl_visibility_filter: dict[str, dict[str, Any]] = {
|
||||
"bool": {
|
||||
"should": [{"term": {PUBLIC_FIELD_NAME: {"value": True}}}],
|
||||
"minimum_should_match": 1,
|
||||
}
|
||||
}
|
||||
if access_control_list:
|
||||
if len(access_control_list) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many access control list entries: {len(access_control_list)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Use terms instead of a list of term within a should clause
|
||||
# because Lucene will optimize the filtering for large sets of
|
||||
# terms. Small sets of terms are not expected to perform any
|
||||
# differently than individual term clauses.
|
||||
acl_subclause: TermsQuery[str] = {
|
||||
"terms": {ACCESS_CONTROL_LIST_FIELD_NAME: list(access_control_list)}
|
||||
}
|
||||
acl_visibility_filter["bool"]["should"].append(acl_subclause)
|
||||
return acl_visibility_filter
|
||||
|
||||
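The clause _get_acl_visibility_filter returns looks roughly like the dict below. This is a sketch: the literal field names and ACL strings are placeholders for PUBLIC_FIELD_NAME, ACCESS_CONTROL_LIST_FIELD_NAME, and real ACL entries.

# Visible if the chunk is public OR carries any of the caller's ACL entries.
# The explicit minimum_should_match=1 preserves OR semantics when this bool
# clause is nested inside a parent "filter" list.
acl_visibility_filter = {
    "bool": {
        "should": [
            {"term": {"public": {"value": True}}},
            {"terms": {"access_control_list": ["user:alice@example.com", "group:eng"]}},
        ],
        "minimum_should_match": 1,
    }
}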
def _get_source_type_filter(
|
||||
source_types: list[DocumentSource],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
source_type_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for source_type in source_types:
|
||||
source_type_filter["bool"]["should"].append(
|
||||
{"term": {SOURCE_TYPE_FIELD_NAME: {"value": source_type.value}}}
|
||||
) -> TermsQuery[str]:
|
||||
"""Returns a filter for the source types.
|
||||
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
source_types: The source types to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of source types is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the source types.
|
||||
"""
|
||||
if not source_types:
|
||||
raise ValueError(
|
||||
"source_types cannot be empty if trying to create a source type filter."
|
||||
)
|
||||
return source_type_filter
|
||||
|
||||
def _get_tag_filter(tags: list[Tag]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
tag_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for tag in tags:
|
||||
# Kind of an abstraction leak, see
|
||||
# convert_metadata_dict_to_list_of_strings for why metadata list
|
||||
# entries are expected to look this way.
|
||||
tag_str = f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}"
|
||||
tag_filter["bool"]["should"].append(
|
||||
{"term": {METADATA_LIST_FIELD_NAME: {"value": tag_str}}}
|
||||
if len(source_types) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many source types: {len(source_types)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
return tag_filter
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {
|
||||
"terms": {
|
||||
SOURCE_TYPE_FIELD_NAME: [
|
||||
source_type.value for source_type in source_types
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
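To make the "terms instead of a list of term" comment concrete, here are the two equivalent shapes side by side; the field name and values are illustrative.

# Equivalent filters; the single "terms" clause is what the new code emits.
# Lucene optimizes large term sets in a "terms" query, and the isolated
# clause caches better in OpenSearch; small sets perform about the same.
as_individual_term_clauses = {
    "bool": {
        "should": [
            {"term": {"source_type": {"value": "slack"}}},
            {"term": {"source_type": {"value": "confluence"}}},
        ],
        "minimum_should_match": 1,
    }
}
as_single_terms_clause = {"terms": {"source_type": ["slack", "confluence"]}}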
def _get_document_set_filter(document_sets: list[str]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
document_set_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for document_set in document_sets:
|
||||
document_set_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_SETS_FIELD_NAME: {"value": document_set}}}
|
||||
def _get_tag_filter(tags: list[Tag]) -> TermsQuery[str]:
|
||||
"""Returns a filter for the tags.
|
||||
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
tags: The tags to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of tags is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the tags.
|
||||
"""
|
||||
if not tags:
|
||||
raise ValueError(
|
||||
"tags cannot be empty if trying to create a tag filter."
|
||||
)
|
||||
return document_set_filter
|
||||
|
||||
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for user_file_id in user_file_ids:
|
||||
user_file_id_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
|
||||
if len(tags) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many tags: {len(tags)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
return user_file_id_filter
|
||||
# Kind of an abstraction leak, see
|
||||
# convert_metadata_dict_to_list_of_strings for why metadata list
|
||||
# entries are expected to look this way.
|
||||
tag_str_list = [
|
||||
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in tags
|
||||
]
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {"terms": {METADATA_LIST_FIELD_NAME: tag_str_list}}
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
user_project_filter["bool"]["should"].append(
|
||||
{"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
|
||||
)
|
||||
return user_project_filter
|
||||
def _get_document_set_filter(document_sets: list[str]) -> TermsQuery[str]:
|
||||
"""Returns a filter for the document sets.
|
||||
|
||||
def _get_persona_filter(persona_id: int) -> dict[str, Any]:
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
document_sets: The document sets to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of document sets is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the document sets.
|
||||
"""
|
||||
if not document_sets:
|
||||
raise ValueError(
|
||||
"document_sets cannot be empty if trying to create a document set filter."
|
||||
)
|
||||
if len(document_sets) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many document sets: {len(document_sets)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {"terms": {DOCUMENT_SETS_FIELD_NAME: list(document_sets)}}
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> TermQuery[int]:
|
||||
return {"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
|
||||
|
||||
def _get_persona_filter(persona_id: int) -> TermQuery[int]:
|
||||
return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}}
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
@@ -792,7 +1083,9 @@ class DocumentQuery:
|
||||
# document data.
|
||||
time_cutoff = set_or_convert_timezone_to_utc(time_cutoff)
|
||||
# Logical OR operator on its elements.
|
||||
time_cutoff_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
time_cutoff_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"range": {
|
||||
@@ -827,25 +1120,77 @@ class DocumentQuery:
|
||||
|
||||
def _get_attached_document_id_filter(
|
||||
doc_ids: list[str],
|
||||
) -> dict[str, Any]:
|
||||
"""Filter for documents explicitly attached to an assistant."""
|
||||
# Logical OR operator on its elements.
|
||||
doc_id_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for doc_id in doc_ids:
|
||||
doc_id_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": doc_id}}}
|
||||
) -> TermsQuery[str]:
|
||||
"""
|
||||
Returns a filter for documents explicitly attached to an assistant.
|
||||
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
doc_ids: The document IDs to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of document IDs is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the document IDs.
|
||||
"""
|
||||
if not doc_ids:
|
||||
raise ValueError(
|
||||
"doc_ids cannot be empty if trying to create a document ID filter."
|
||||
)
|
||||
return doc_id_filter
|
||||
if len(doc_ids) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many document IDs: {len(doc_ids)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {"terms": {DOCUMENT_ID_FIELD_NAME: list(doc_ids)}}
|
||||
|
||||
def _get_hierarchy_node_filter(
|
||||
node_ids: list[int],
|
||||
) -> dict[str, Any]:
|
||||
"""Filter for chunks whose ancestors include any of the given hierarchy nodes.
|
||||
|
||||
Uses a terms query to check if ancestor_hierarchy_node_ids contains
|
||||
any of the specified node IDs.
|
||||
) -> TermsQuery[int]:
|
||||
"""
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
|
||||
Returns a filter for chunks whose ancestors include any of the given
|
||||
hierarchy nodes.
|
||||
|
||||
Since this returns an isolated terms clause, it can be cached in
|
||||
OpenSearch independently of other clauses in _get_search_filters.
|
||||
|
||||
Args:
|
||||
node_ids: The hierarchy node IDs to restrict documents to.
|
||||
|
||||
Raises:
|
||||
ValueError: The number of hierarchy node IDs is greater than
|
||||
MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
|
||||
ValueError: An empty list was supplied.
|
||||
|
||||
Returns:
|
||||
A filter for the hierarchy node IDs.
|
||||
"""
|
||||
if not node_ids:
|
||||
raise ValueError(
|
||||
"node_ids cannot be empty if trying to create a hierarchy node ID filter."
|
||||
)
|
||||
if len(node_ids) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY:
|
||||
raise ValueError(
|
||||
f"Too many hierarchy node IDs: {len(node_ids)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
|
||||
)
|
||||
# Use terms instead of a list of term within a should clause because
|
||||
# Lucene will optimize the filtering for large sets of terms. Small
|
||||
# sets of terms are not expected to perform any differently than
|
||||
# individual term clauses.
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: list(node_ids)}}
|
||||
|
||||
if document_id is not None and attached_document_ids is not None:
|
||||
raise ValueError(
|
||||
"document_id and attached_document_ids cannot be used together."
|
||||
)
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
@@ -876,17 +1221,23 @@ class DocumentQuery:
|
||||
# assistant can see. When none are set the assistant searches
|
||||
# everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user files
|
||||
# findable but must NOT trigger the restriction on their own (an agent
|
||||
# with no explicit knowledge should search everything).
|
||||
# persona_id_filter is a primary trigger — a persona with user files IS
|
||||
# explicit knowledge, so it can start a knowledge scope on its own.
|
||||
#
|
||||
# project_id_filter is additive — it widens the scope to also cover
|
||||
# overflowing project files but never restricts on its own (a chat
|
||||
# inside a project should still search team knowledge).
|
||||
has_knowledge_scope = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
or user_file_ids
|
||||
or document_sets
|
||||
or persona_id_filter is not None
|
||||
)
|
||||
|
||||
if has_knowledge_scope:
|
||||
# Since this returns an isolated bool should clause, it can be
|
||||
# cached in OpenSearch independently of other clauses in
|
||||
# _get_search_filters.
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
@@ -898,23 +1249,17 @@ class DocumentQuery:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(hierarchy_node_ids)
|
||||
)
|
||||
if user_file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(user_file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user files, but
|
||||
# only when an explicit restriction is already in effect.
|
||||
if project_id is not None:
|
||||
if persona_id_filter is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
_get_persona_filter(persona_id_filter)
|
||||
)
|
||||
if persona_id is not None:
|
||||
if project_id_filter is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_persona_filter(persona_id)
|
||||
_get_user_project_filter(project_id_filter)
|
||||
)
|
||||
filter_clauses.append(knowledge_filter)
|
||||
|
||||
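A condensed sketch of the trigger/additive distinction described above. The field names are placeholders for the module constants, and the attached-document and hierarchy-node branches are omitted for brevity.

def build_knowledge_scope(
    user_file_ids: list[str],
    document_sets: list[str],
    persona_id_filter: int | None,
    project_id_filter: int | None,
) -> dict | None:
    # Primary triggers: any of these opens a knowledge scope.
    has_scope = bool(user_file_ids or document_sets or persona_id_filter is not None)
    if not has_scope:
        # No explicit knowledge attached: the assistant searches everything.
        return None
    should: list[dict] = []
    if user_file_ids:
        should.append({"terms": {"document_id": list(user_file_ids)}})
    if document_sets:
        should.append({"terms": {"document_sets": list(document_sets)}})
    if persona_id_filter is not None:
        should.append({"term": {"personas": {"value": persona_id_filter}}})
    if project_id_filter is not None:
        # Additive only: reached when a scope already exists, never opens one.
        should.append({"term": {"user_projects": {"value": project_id_filter}}})
    return {"bool": {"should": should, "minimum_should_match": 1}}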
@@ -932,8 +1277,6 @@ class DocumentQuery:
|
||||
)
|
||||
|
||||
if document_id is not None:
|
||||
# WARNING: If user_file_ids has elements and if none of them are
|
||||
# document_id, no matches will be retrieved.
|
||||
filter_clauses.append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
)
|
||||
|
||||
@@ -199,31 +199,29 @@ def build_vespa_filters(
|
||||
]
|
||||
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
|
||||
|
||||
# Knowledge scope: explicit knowledge attachments (document_sets,
|
||||
# user_file_ids) restrict what an assistant can see. When none are
|
||||
# set, the assistant can see everything.
|
||||
# Knowledge scope: explicit knowledge attachments restrict what an
|
||||
# assistant can see. When none are set, the assistant can see
|
||||
# everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user
|
||||
# files findable in Vespa but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
# persona_id_filter is a primary trigger — a persona with user files IS
|
||||
# explicit knowledge, so it can start a knowledge scope on its own.
|
||||
#
|
||||
# project_id_filter is additive — it widens the scope to also cover
|
||||
# overflowing project files but never restricts on its own (a chat
|
||||
# inside a project should still search team knowledge).
|
||||
knowledge_scope_parts: list[str] = []
|
||||
|
||||
_append(
|
||||
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
)
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id_filter))
|
||||
|
||||
user_file_ids_str = (
|
||||
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
|
||||
)
|
||||
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
|
||||
|
||||
# Only include project/persona scopes when an explicit knowledge
|
||||
# restriction is already in effect — they widen the scope to also
|
||||
# cover overflowing user files but never restrict on their own.
|
||||
# project_id_filter only widens an existing scope.
|
||||
if knowledge_scope_parts:
|
||||
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
|
||||
_append(
|
||||
knowledge_scope_parts,
|
||||
_build_user_project_filter(filters.project_id_filter),
|
||||
)
|
||||
|
||||
if len(knowledge_scope_parts) > 1:
|
||||
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
|
||||
|
||||
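Illustratively, if the document-set and persona filters render to the strings below (the exact YQL emitted by _build_or_filters and _build_persona_filter is not shown in this hunk, so these are assumed shapes), the grouped clause appended to filter_parts is simply their OR:

knowledge_scope_parts = [
    '(document_sets contains "eng-wiki")',
    '(personas contains "7")',
]
combined = "(" + " or ".join(knowledge_scope_parts) + ")"
# -> '((document_sets contains "eng-wiki") or (personas contains "7"))'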
@@ -610,6 +610,22 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
return cleanup_content_for_chunks(query_vespa(params))
|
||||
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
raise NotImplementedError
|
||||
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query_embedding: Embedding,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
raise NotImplementedError
|
||||
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters,
|
||||
|
||||
@@ -44,6 +44,7 @@ class OnyxErrorCode(Enum):
|
||||
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
|
||||
INVALID_INPUT = ("INVALID_INPUT", 400)
|
||||
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
|
||||
QUERY_REJECTED = ("QUERY_REJECTED", 400)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Not Found (404)
|
||||
@@ -88,6 +89,7 @@ class OnyxErrorCode(Enum):
|
||||
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
|
||||
BAD_GATEWAY = ("BAD_GATEWAY", 502)
|
||||
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
|
||||
HOOK_EXECUTION_FAILED = ("HOOK_EXECUTION_FAILED", 502)
|
||||
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
|
||||
|
||||
def __init__(self, code: str, status_code: int) -> None:
|
||||
|
||||
@@ -38,17 +38,7 @@ def get_federated_retrieval_functions(
|
||||
source_types: list[DocumentSource] | None,
|
||||
document_set_names: list[str] | None,
|
||||
slack_context: SlackContext | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
) -> list[FederatedRetrievalInfo]:
|
||||
# When User Knowledge (user files) is the only knowledge source enabled,
|
||||
# skip federated connectors entirely. User Knowledge mode means the agent
|
||||
# should ONLY use uploaded files, not team connectors like Slack.
|
||||
if user_file_ids and not document_set_names:
|
||||
logger.debug(
|
||||
"Skipping all federated connectors: User Knowledge mode enabled "
|
||||
f"with {len(user_file_ids)} user files and no document sets"
|
||||
)
|
||||
return []
|
||||
|
||||
# Check for Slack bot context first (regardless of user_id)
|
||||
if slack_context:
|
||||
|
||||
@@ -15,6 +15,7 @@ PLAIN_TEXT_MIME_TYPE = "text/plain"
|
||||
class OnyxMimeTypes:
|
||||
IMAGE_MIME_TYPES = {"image/jpg", "image/jpeg", "image/png", "image/webp"}
|
||||
CSV_MIME_TYPES = {"text/csv"}
|
||||
TABULAR_MIME_TYPES = CSV_MIME_TYPES | {SPREADSHEET_MIME_TYPE}
|
||||
TEXT_MIME_TYPES = {
|
||||
PLAIN_TEXT_MIME_TYPE,
|
||||
"text/markdown",
|
||||
@@ -34,13 +35,12 @@ class OnyxMimeTypes:
|
||||
PDF_MIME_TYPE,
|
||||
WORD_PROCESSING_MIME_TYPE,
|
||||
PRESENTATION_MIME_TYPE,
|
||||
SPREADSHEET_MIME_TYPE,
|
||||
"message/rfc822",
|
||||
"application/epub+zip",
|
||||
}
|
||||
|
||||
ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
|
||||
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
|
||||
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, TABULAR_MIME_TYPES
|
||||
)
|
||||
|
||||
EXCLUDED_IMAGE_TYPES = {
|
||||
|
||||
@@ -13,13 +13,14 @@ class ChatFileType(str, Enum):
|
||||
DOC = "document"
|
||||
# Plain text files only contain the text
|
||||
PLAIN_TEXT = "plain_text"
|
||||
CSV = "csv"
|
||||
# Tabular data files (CSV, TSV, XLSX) — metadata-only injection
|
||||
TABULAR = "tabular"
|
||||
|
||||
def is_text_file(self) -> bool:
|
||||
return self in (
|
||||
ChatFileType.PLAIN_TEXT,
|
||||
ChatFileType.DOC,
|
||||
ChatFileType.CSV,
|
||||
ChatFileType.TABULAR,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -23,45 +23,55 @@ from onyx.utils.timing import log_function_time
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return f"plaintext_{user_file_id}"
|
||||
def plaintext_file_name_for_id(file_id: str) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a file."""
|
||||
return f"plaintext_{file_id}"
|
||||
|
||||
|
||||
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
|
||||
def store_plaintext(file_id: str, plaintext_content: str) -> bool:
|
||||
"""
|
||||
Store plaintext content for a user file in the file store.
|
||||
Store plaintext content for a file in the file store.
|
||||
|
||||
Args:
|
||||
user_file_id: The ID of the user file
|
||||
file_id: The ID of the file (user_file or artifact_file)
|
||||
plaintext_content: The plaintext content to store
|
||||
|
||||
Returns:
|
||||
bool: True if storage was successful, False otherwise
|
||||
"""
|
||||
# Skip empty content
|
||||
if not plaintext_content:
|
||||
return False
|
||||
|
||||
# Get plaintext file name
|
||||
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
|
||||
|
||||
plaintext_file_name = plaintext_file_name_for_id(file_id)
|
||||
try:
|
||||
file_store = get_default_file_store()
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
file_store.save_file(
|
||||
content=file_content,
|
||||
display_name=f"Plaintext for user file {user_file_id}",
|
||||
display_name=f"Plaintext for {file_id}",
|
||||
file_origin=FileOrigin.PLAINTEXT_CACHE,
|
||||
file_type="text/plain",
|
||||
file_id=plaintext_file_name,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
|
||||
logger.warning(f"Failed to store plaintext for {file_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# --- Convenience wrappers for callers that use user-file UUIDs ---
|
||||
|
||||
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return plaintext_file_name_for_id(str(user_file_id))
|
||||
|
||||
|
||||
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
|
||||
"""Store plaintext content for a user file (delegates to :func:`store_plaintext`)."""
|
||||
return store_plaintext(str(user_file_id), plaintext_content)
|
||||
|
||||
|
||||
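Usage stays mechanical. A short sketch, assuming the helpers above are in scope; the IDs and text are illustrative.

from uuid import UUID

# New id-based helper works for any file kind (user_file or artifact_file).
store_plaintext("artifact_file_123", "extracted text ...")
# Legacy UUID wrapper delegates to the same path.
store_user_file_plaintext(
    UUID("550e8400-e29b-41d4-a716-446655440000"), "extracted text ..."
)
# Both land under the shared naming scheme:
plaintext_file_name_for_id("artifact_file_123")  # -> "plaintext_artifact_file_123"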
def load_chat_file_by_id(file_id: str) -> InMemoryChatFile:
|
||||
"""Load a file directly from the file store using its file_record ID.
|
||||
|
||||
|
||||
backend/onyx/hooks/executor.py (new file, 391 lines)
@@ -0,0 +1,391 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.

Usage (Celery tasks and FastAPI handlers):
    result = execute_hook(
        db_session=db_session,
        hook_point=HookPoint.QUERY_PROCESSING,
        payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
        response_type=QueryProcessingResponse,
    )

    if isinstance(result, HookSkipped):
        # no active hook configured — continue with original behavior
        ...
    elif isinstance(result, HookSoftFailed):
        # hook failed but fail strategy is SOFT — continue with original behavior
        ...
    else:
        # result is a validated Pydantic model instance (response_type)
        ...

is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:

    NetworkError (DNS, connection refused) → False (cannot reach the server)
    HTTP 401 / 403 → False (api_key revoked or invalid)
    TimeoutException → None (server may be slow, skip write)
    Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
    Unknown exception → None (no signal, skip write)
    Non-JSON / non-dict response → None (server responded, skip write)
    Success (2xx, valid dict) → True (confirmed reachable)

None means "leave the current value unchanged" — no DB round-trip is made.

DB session design
-----------------
The executor uses three sessions:

1. Caller's session (db_session) — used only for the hook lookup read. All
   needed fields are extracted from the Hook object before the HTTP call, so
   the caller's session is not held open during the external HTTP request.

2. Log session — a separate short-lived session opened after the HTTP call
   completes to write the HookExecutionLog row on failure. Success runs are
   not recorded. Committed independently of everything else.

3. Reachable session — a second short-lived session to update is_reachable on
   the Hook. Kept separate from the log session so a concurrent hook deletion
   (which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
   prevent the execution log from being written. This update is best-effort.
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class HookSkipped:
|
||||
"""No active hook configured for this hook point."""
|
||||
|
||||
|
||||
class HookSoftFailed:
|
||||
"""Hook was called but failed with SOFT fail strategy — continuing."""
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if not HOOKS_AVAILABLE:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
*,
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _execute_hook_inner(
|
||||
hook: Hook,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSoftFailed:
|
||||
"""Make the HTTP call, validate the response, and return a typed model.
|
||||
|
||||
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
|
||||
"""
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(
|
||||
timeout=timeout, follow_redirects=False
|
||||
) as client: # SSRF guard: never follow redirects
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
|
||||
# Validate the response payload against response_type.
|
||||
# A validation failure downgrades the outcome to a failure so it is logged,
|
||||
# is_reachable is left unchanged (server responded — just a bad payload),
|
||||
# and fail_strategy is respected below.
|
||||
validated_model: T | None = None
|
||||
if outcome.is_success and outcome.response_payload is not None:
|
||||
try:
|
||||
validated_model = response_type.model_validate(outcome.response_payload)
|
||||
except ValidationError as e:
|
||||
msg = (
|
||||
f"Hook response failed validation against {response_type.__name__}: {e}"
|
||||
)
|
||||
outcome = _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=outcome.status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
|
||||
if validated_model is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"validated_model is None for successful hook call (hook_id={hook_id})",
|
||||
)
|
||||
return validated_model
|
||||
|
||||
|
||||
def execute_hook(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point synchronously.
|
||||
|
||||
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
|
||||
hook failed with SOFT fail strategy, or a validated response model on success.
|
||||
Raises OnyxError on HARD failure or if the hook is misconfigured.
|
||||
"""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
fail_strategy = hook.fail_strategy
|
||||
hook_id = hook.id
|
||||
|
||||
try:
|
||||
return _execute_hook_inner(hook, payload, response_type)
|
||||
except Exception:
|
||||
if fail_strategy == HookFailStrategy.SOFT:
|
||||
logger.exception(
|
||||
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
raise
|
||||
@@ -42,12 +42,8 @@ class HookUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
endpoint_url: str | None = None
|
||||
api_key: NonEmptySecretStr | None = None
|
||||
fail_strategy: HookFailStrategy | None = (
|
||||
None # if None in model_fields_set, reset to spec default
|
||||
)
|
||||
timeout_seconds: float | None = Field(
|
||||
default=None, gt=0
|
||||
) # if None in model_fields_set, reset to spec default
|
||||
fail_strategy: HookFailStrategy | None = None
|
||||
timeout_seconds: float | None = Field(default=None, gt=0)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_at_least_one_field(self) -> "HookUpdateRequest":
|
||||
@@ -60,6 +56,14 @@ class HookUpdateRequest(BaseModel):
|
||||
and not (self.endpoint_url or "").strip()
|
||||
):
|
||||
raise ValueError("endpoint_url cannot be cleared.")
|
||||
if "fail_strategy" in self.model_fields_set and self.fail_strategy is None:
|
||||
raise ValueError(
|
||||
"fail_strategy cannot be null; omit the field to leave it unchanged."
|
||||
)
|
||||
if "timeout_seconds" in self.model_fields_set and self.timeout_seconds is None:
|
||||
raise ValueError(
|
||||
"timeout_seconds cannot be null; omit the field to leave it unchanged."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
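In PATCH terms the validator enforces "omit to keep, explicit null is rejected". A small sketch, assuming HookUpdateRequest is imported as defined above:

HookUpdateRequest(name="New name")        # only name changes
HookUpdateRequest(timeout_seconds=15.0)   # timeout updated, everything else kept
try:
    HookUpdateRequest(fail_strategy=None)  # explicit null -> rejected
except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
    print(exc)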
@@ -87,41 +91,33 @@ class HookResponse(BaseModel):
|
||||
# Nullable to match the DB column — endpoint_url is required on creation but
|
||||
# future hook point types may not use an external endpoint (e.g. built-in handlers).
|
||||
endpoint_url: str | None
|
||||
# Partially-masked API key (e.g. "abcd••••••••wxyz"), or None if no key is set.
|
||||
api_key_masked: str | None
|
||||
fail_strategy: HookFailStrategy
|
||||
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
|
||||
is_active: bool
|
||||
is_reachable: bool | None
|
||||
creator_email: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class HookValidateStatus(str, Enum):
|
||||
passed = "passed" # server responded (any status except 401/403)
|
||||
auth_failed = "auth_failed" # server responded with 401 or 403
|
||||
timeout = (
|
||||
"timeout" # TCP connected, but read/write timed out (server exists but slow)
|
||||
)
|
||||
cannot_connect = "cannot_connect" # could not connect to the server
|
||||
|
||||
|
||||
class HookValidateResponse(BaseModel):
|
||||
success: bool
|
||||
status: HookValidateStatus
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookHealthStatus(str, Enum):
|
||||
healthy = "healthy" # green — reachable, no failures in last 1h
|
||||
degraded = "degraded" # yellow — reachable, failures in last 1h
|
||||
unreachable = "unreachable" # red — is_reachable=false or null
|
||||
|
||||
|
||||
class HookFailureRecord(BaseModel):
|
||||
class HookExecutionRecord(BaseModel):
|
||||
error_message: str | None = None
|
||||
status_code: int | None = None
|
||||
duration_ms: int | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class HookHealthResponse(BaseModel):
|
||||
status: HookHealthStatus
|
||||
recent_failures: list[HookFailureRecord] = Field(
|
||||
default_factory=list,
|
||||
description="Last 10 failures, newest first",
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
@@ -13,22 +14,25 @@ _REQUIRED_ATTRS = (
|
||||
"default_timeout_seconds",
|
||||
"fail_hard_description",
|
||||
"default_fail_strategy",
|
||||
"payload_model",
|
||||
"response_model",
|
||||
)
|
||||
|
||||
|
||||
class HookPointSpec(ABC):
|
||||
class HookPointSpec:
|
||||
"""Static metadata and contract for a pipeline hook point.
|
||||
|
||||
This is NOT a regular class meant for direct instantiation by callers.
|
||||
Each concrete subclass represents exactly one hook point and is instantiated
|
||||
once at startup, registered in onyx.hooks.registry._REGISTRY. No caller
|
||||
should ever create instances directly — use get_hook_point_spec() or
|
||||
get_all_specs() from the registry instead.
|
||||
once at startup, registered in onyx.hooks.registry._REGISTRY. Prefer
|
||||
get_hook_point_spec() or get_all_specs() from the registry over direct
|
||||
instantiation.
|
||||
|
||||
Each hook point is a concrete subclass of this class. Onyx engineers
|
||||
own these definitions — customers never touch this code.
|
||||
|
||||
Subclasses must define all attributes as class-level constants.
|
||||
payload_model and response_model must be Pydantic BaseModel subclasses;
|
||||
input_schema and output_schema are derived from them automatically.
|
||||
"""
|
||||
|
||||
hook_point: HookPoint
|
||||
@@ -39,21 +43,32 @@ class HookPointSpec(ABC):
|
||||
default_fail_strategy: HookFailStrategy
|
||||
docs_url: str | None = None
|
||||
|
||||
payload_model: ClassVar[type[BaseModel]]
|
||||
response_model: ClassVar[type[BaseModel]]
|
||||
|
||||
# Computed once at class definition time from payload_model / response_model.
|
||||
input_schema: ClassVar[dict[str, Any]]
|
||||
output_schema: ClassVar[dict[str, Any]]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
"""Enforce that every subclass declares all required class attributes.
|
||||
|
||||
Called automatically by Python whenever a class inherits from HookPointSpec.
|
||||
Raises TypeError at import time if any required attribute is missing or if
|
||||
payload_model / response_model are not Pydantic BaseModel subclasses.
|
||||
input_schema and output_schema are derived automatically from the models.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Skip intermediate abstract subclasses — they may still be partially defined.
|
||||
if getattr(cls, "__abstractmethods__", None):
|
||||
return
|
||||
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
|
||||
if missing:
|
||||
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the request payload sent to the customer's endpoint."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the expected response from the customer's endpoint."""
|
||||
for attr in ("payload_model", "response_model"):
|
||||
val = getattr(cls, attr, None)
|
||||
if val is None or not (
|
||||
isinstance(val, type) and issubclass(val, BaseModel)
|
||||
):
|
||||
raise TypeError(
|
||||
f"{cls.__name__}.{attr} must be a Pydantic BaseModel subclass, got {val!r}"
|
||||
)
|
||||
cls.input_schema = cls.payload_model.model_json_schema()
|
||||
cls.output_schema = cls.response_model.model_json_schema()
|
||||
|
||||
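A self-contained toy version of the mechanism (deliberately not the onyx classes) shows what the subclass check and the automatic schema derivation buy:

from pydantic import BaseModel


class MiniSpec:
    def __init_subclass__(cls, **kwargs: object) -> None:
        super().__init_subclass__(**kwargs)
        # Reject subclasses whose models are missing or not Pydantic models,
        # then derive the JSON schemas once, at class-definition time.
        for attr in ("payload_model", "response_model"):
            val = getattr(cls, attr, None)
            if not (isinstance(val, type) and issubclass(val, BaseModel)):
                raise TypeError(f"{cls.__name__}.{attr} must be a BaseModel subclass")
        cls.input_schema = cls.payload_model.model_json_schema()
        cls.output_schema = cls.response_model.model_json_schema()


class Payload(BaseModel):
    query: str


class Response(BaseModel):
    query: str | None = None


class ExampleSpec(MiniSpec):
    payload_model = Payload
    response_model = Response


assert ExampleSpec.input_schema["required"] == ["query"]
assert "query" in ExampleSpec.output_schema["properties"]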
@@ -1,10 +1,19 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
# TODO(@Bo-Onyx): define payload and response fields
|
||||
class DocumentIngestionPayload(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionResponse(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionSpec(HookPointSpec):
|
||||
"""Hook point that runs during document ingestion.
|
||||
|
||||
@@ -17,13 +26,8 @@ class DocumentIngestionSpec(HookPointSpec):
|
||||
default_timeout_seconds = 30.0
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
# TODO(Bo-Onyx): update later
|
||||
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.ue263ual5vdi"
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define input schema
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define output schema
|
||||
return {"type": "object", "properties": {}}
|
||||
payload_model = DocumentIngestionPayload
|
||||
response_model = DocumentIngestionResponse
|
||||
|
||||
@@ -1,10 +1,39 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
class QueryProcessingPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
query: str = Field(description="The raw query string exactly as the user typed it.")
|
||||
user_email: str | None = Field(
|
||||
description="Email of the user submitting the query, or null if unauthenticated."
|
||||
)
|
||||
chat_session_id: str = Field(
|
||||
description="UUID of the chat session, formatted as a hyphenated lowercase string (e.g. '550e8400-e29b-41d4-a716-446655440000'). Always present — the session is guaranteed to exist by the time this hook fires."
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessingResponse(BaseModel):
|
||||
# Intentionally permissive — customer endpoints may return extra fields.
|
||||
query: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The query to use in the pipeline. "
|
||||
"Null, empty string, whitespace-only, or absent = reject the query."
|
||||
),
|
||||
)
|
||||
rejection_message: str | None = Field(
|
||||
default=None,
|
||||
description="Message shown to the user when the query is rejected. Falls back to a generic message if not provided.",
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessingSpec(HookPointSpec):
|
||||
"""Hook point that runs on every user query before it enters the pipeline.
|
||||
|
||||
@@ -36,48 +65,8 @@ class QueryProcessingSpec(HookPointSpec):
|
||||
"The query will be blocked and the user will see an error message."
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
# TODO(Bo-Onyx): update later
|
||||
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.g2r1a1699u87"
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The raw query string exactly as the user typed it.",
|
||||
},
|
||||
"user_email": {
|
||||
"type": ["string", "null"],
|
||||
"description": "Email of the user submitting the query, or null if unauthenticated.",
|
||||
},
|
||||
"chat_session_id": {
|
||||
"type": "string",
|
||||
"description": "UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires.",
|
||||
},
|
||||
},
|
||||
"required": ["query", "user_email", "chat_session_id"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"The (optionally modified) query to use. "
|
||||
"Set to null to reject the query."
|
||||
),
|
||||
},
|
||||
"rejection_message": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"Message shown to the user when query is null. "
|
||||
"Falls back to a generic message if not provided."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
payload_model = QueryProcessingPayload
|
||||
response_model = QueryProcessingResponse
|
||||
|
||||
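For reference, a hypothetical customer endpoint satisfying this contract could look like the FastAPI sketch below. The framework, route path, and rejection rule are illustrative; only the field names come from the payload and response models above.

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class QueryProcessingIn(BaseModel):
    query: str
    user_email: str | None
    chat_session_id: str


class QueryProcessingOut(BaseModel):
    query: str | None = None
    rejection_message: str | None = None


@app.post("/onyx/query-processing")
def query_processing_hook(req: QueryProcessingIn) -> QueryProcessingOut:
    if "ssn" in req.query.lower():
        # Omitting query (i.e. returning null) rejects the query.
        return QueryProcessingOut(rejection_message="Queries about SSNs are not allowed.")
    # Otherwise pass the (optionally modified) query through.
    return QueryProcessingOut(query=req.query)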
backend/onyx/hooks/utils.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from onyx.configs.app_configs import HOOK_ENABLED
from shared_configs.configs import MULTI_TENANT

# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT
@@ -29,6 +29,7 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -173,8 +174,10 @@ class UserFileIndexingAdapter:
|
||||
[chunk.content for chunk in user_file_chunks]
|
||||
)
|
||||
user_file_id_to_raw_text[str(user_file_id)] = combined_content
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
|
||||
token_count: int = (
|
||||
count_tokens(combined_content, llm_tokenizer)
|
||||
if llm_tokenizer
|
||||
else 0
|
||||
)
|
||||
user_file_id_to_token_count[str(user_file_id)] = token_count
|
||||
else:
|
||||
|
||||
@@ -25,6 +25,7 @@ class LlmProviderNames(str, Enum):
|
||||
LM_STUDIO = "lm_studio"
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
BIFROST = "bifrost"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Needed so things like:
|
||||
@@ -44,6 +45,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
]
|
||||
|
||||
|
||||
@@ -61,6 +63,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: "Ollama",
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
LlmProviderNames.BIFROST: "Bifrost",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -112,6 +115,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.VERTEX_AI,
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -290,6 +290,17 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
|
||||
|
||||
# Bifrost: OpenAI-compatible proxy that expects model names in
|
||||
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
|
||||
# We route through LiteLLM's openai provider with the Bifrost base URL,
|
||||
# and ensure /v1 is appended.
|
||||
if model_provider == LlmProviderNames.BIFROST:
|
||||
self._custom_llm_provider = "openai"
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
model_kwargs["api_base"] = self._api_base
|
||||
|
||||
# This is needed for Ollama to do proper function calling
|
||||
if model_provider == LlmProviderNames.OLLAMA_CHAT and api_base is not None:
|
||||
model_kwargs["api_base"] = api_base
|
||||
@@ -401,14 +412,20 @@ class LitellmLLM(LLM):
|
||||
optional_kwargs: dict[str, Any] = {}
|
||||
|
||||
# Model name
|
||||
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
|
||||
model_provider = (
|
||||
f"{self.config.model_provider}/responses"
|
||||
if is_openai_model # Uses litellm's completions -> responses bridge
|
||||
else self.config.model_provider
|
||||
)
|
||||
model = (
|
||||
f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
)
|
||||
if is_bifrost:
|
||||
# Bifrost expects model names in provider/model format
|
||||
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
|
||||
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
|
||||
# so LiteLLM doesn't try to route based on the provider prefix.
|
||||
model = self.config.deployment_name or self.config.model_name
|
||||
else:
|
||||
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
|
||||
# Tool choice
|
||||
if is_claude_model and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
@@ -483,10 +500,11 @@ class LitellmLLM(LLM):
|
||||
if structured_response_format:
|
||||
optional_kwargs["response_format"] = structured_response_format
|
||||
|
||||
if not (is_claude_model or is_ollama or is_mistral):
|
||||
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
|
||||
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
|
||||
# However, this param breaks Anthropic and Mistral models,
|
||||
# so it must be conditionally included.
|
||||
# so it must be conditionally included unless the request is
|
||||
# routed through Bifrost's OpenAI-compatible endpoint.
|
||||
# Additionally, tool_choice is not supported by Ollama and causes warnings if included.
|
||||
# See also, https://github.com/ollama/ollama/issues/11171
|
||||
optional_kwargs["allowed_openai_params"] = ["tool_choice"]
|
||||
|
||||
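Put together, the Bifrost branch resolves to a LiteLLM call roughly like the one below; the host, port, and model name are illustrative.

import litellm

response = litellm.completion(
    model="anthropic/claude-sonnet-4-6",         # sent as-is so Bifrost does the routing
    custom_llm_provider="openai",                # stop LiteLLM from re-routing on the prefix
    api_base="http://bifrost.internal:8080/v1",  # "/v1" is appended if missing
    api_key="dummy",                             # auth is typically handled by the proxy
    messages=[{"role": "user", "content": "hello"}],
)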
@@ -11,6 +11,7 @@ class LLMOverride(BaseModel):
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
display_name: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.