Compare commits

...

123 Commits

Author SHA1 Message Date
dependabot[bot]
43ed3e2a03 chore(deps-dev): bump types-retry in /backend
Bumps [types-retry](https://github.com/python/typeshed) from 0.9.9.3 to 0.9.9.20250322.
- [Commits](https://github.com/python/typeshed/commits)

---
updated-dependencies:
- dependency-name: types-retry
  dependency-version: 0.9.9.20250322
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-02-16 18:28:13 +00:00
SubashMohan
ddb14ec762 fix(ui): fix a few common UI bugs (#8425) 2026-02-16 11:22:43 +00:00
Justin Tahara
f31f589860 chore(llm): Adding more FE Unit Tests (#8457) 2026-02-16 03:02:19 +00:00
Evan Lohn
63b9a869af fix: CheckpointedConnector pruning only processes first checkpoint step (mirror of #8464) (#8468)
Co-authored-by: Yves Grolet <yves@grolet.com>
2026-02-16 00:45:43 +00:00
Yuhong Sun
6aea36b573 chore: Context summarization update (#8467) 2026-02-15 23:39:47 +00:00
Wenxi
3d8e8d0846 refactor: connector config refresh elements/cleanup (#8428) 2026-02-15 20:12:51 +00:00
Yuhong Sun
dea5be2185 chore: License update (No change, just touchup) (#8460) 2026-02-14 02:44:38 +00:00
Wenxi
d083973d4f chore: disable auto craft animation with feature flag (#8459) 2026-02-14 02:29:37 +00:00
Wenxi
df956888bf fix: bake public recaptcha key in cloud image (#8458) 2026-02-14 02:12:43 +00:00
dependabot[bot]
7c6062e7d5 chore(deps): bump qs from 6.14.1 to 6.14.2 in /web (#8451)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-14 02:04:30 +00:00
Yuhong Sun
89d2759021 chore: Remove end-of-lifed backend routes (#8453) 2026-02-14 01:57:06 +00:00
Justin Tahara
d9feaf43a7 chore(playwright): Adding new LLM Runtime tests (#8447) 2026-02-14 01:38:23 +00:00
Nikolas Garza
5bfffefa2f feat(scim): add SCIM filter expression parser with unit tests (#8421) 2026-02-14 01:17:48 +00:00
Nikolas Garza
4d0b7e14d4 feat(scim): add SCIM PATCH operation handler with unit tests (#8422) 2026-02-14 01:12:46 +00:00
Jamison Lahman
36c55d9e59 chore(gha): de-duplicate integration test logic (#8450) 2026-02-14 00:31:31 +00:00
Wenxi
9f652108f9 fix: don't pass captcha token to db (#8449) 2026-02-14 00:20:36 +00:00
victoria reese
d4e4c6b40e feat: add setting to configure mcp host (#8439) 2026-02-13 23:49:18 +00:00
Jamison Lahman
9c8deb5d0c chore(playwright): mask non-deterministic email element (#8448) 2026-02-13 23:37:24 +00:00
Danelegend
58f57c43aa feat(contextual-llm): Populate and set w/ llm flow (#8398) 2026-02-13 23:32:26 +00:00
Evan Lohn
62106df753 fix: sharepoint cred refresh2 (#8445) 2026-02-13 23:05:15 +00:00
Jamison Lahman
45b3a5e945 chore(playwright): include option to hide element in screenshots (#8446) 2026-02-13 22:45:46 +00:00
Jamison Lahman
e19a6b6789 chore(playwright): create new user tests (#8429) 2026-02-13 22:17:18 +00:00
Jamison Lahman
2de7df4839 chore(playwright): login page screenshots (#8427) 2026-02-13 22:01:32 +00:00
victoria reese
bd054bbad9 fix: remove default idleReplicaCount (#8434) 2026-02-13 13:37:19 -08:00
Justin Tahara
313e709d41 fix(celery): Respecting Limits for Celery Heavy Tasks (#8407) 2026-02-13 21:27:04 +00:00
Nikolas Garza
aeb1d6edac feat(scim): add SCIM 2.0 Pydantic schemas (#8420) 2026-02-13 21:21:05 +00:00
Wenxi
49a35f8aaa fix: remove user file indexing from launch, add init imports for all celery tasks, bump sandbox memory limits (#8443) 2026-02-13 21:15:30 +00:00
Danelegend
049e8ef0e2 feat(llm): Populate env w/ custom config (#8328) 2026-02-13 21:11:49 +00:00
Jamison Lahman
3b61b495a3 chore(playwright): tag appearance_theme tests exclusive (#8441) 2026-02-13 21:07:57 +00:00
Wenxi
5c5c9f0e1d feat(airtable): index all and hierarchy for craft (#8414) 2026-02-13 21:03:53 +00:00
Nikolas Garza
f20d5c33b7 feat(scim): add SCIM database models and migration (#8419) 2026-02-13 20:54:56 +00:00
Jamison Lahman
e898407f7b chore(tests): skip yet another test_web_search_api test (#8442) 2026-02-13 12:50:04 -08:00
Jamison Lahman
f802ff09a7 chore(tests): skip additional web_search test (#8440) 2026-02-13 12:29:36 -08:00
Jamison Lahman
69ad712e09 chore(tests): temporarily disable exa tests (#8431) 2026-02-13 11:06:25 -08:00
Jamison Lahman
98b69c0f2c chore(playwright): welcome_page tests & per-element screenshots (#8426) 2026-02-13 10:07:27 -08:00
Raunak Bhagat
1e5c87896f refactor(web): migrate from usePopup/setPopup to global toast system (#8411)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 17:21:14 +00:00
Raunak Bhagat
b6cc97a8c3 fix(web): icon button and timeline header UI fixes (#8416) 2026-02-13 17:20:37 +00:00
Yuhong Sun
032fbf1058 chore: reminder prompt to be moveable (#8417) 2026-02-13 07:39:12 +00:00
SubashMohan
fc32a9f92a fix(memory): memory tool UI and prompt injection issues (#8377) 2026-02-13 04:29:51 +00:00
Jamison Lahman
9be13bbf63 chore(playwright): make screenshots deterministic (#8412) 2026-02-12 19:53:11 -08:00
Yuhong Sun
9e7176eb82 chore: Tiny intro message change (#8415) 2026-02-12 19:44:34 -08:00
Yuhong Sun
c7faf8ce52 chore: Project instructions would get ignored (#8409) 2026-02-13 02:51:13 +00:00
Jessica Singh
6230e36a63 chore(bulk invite): free trial limit (#8378) 2026-02-13 02:03:38 +00:00
Jamison Lahman
7595b54f6b chore(playwright): upload baselines with merge_group jobs (#8410) 2026-02-13 01:41:14 +00:00
Evan Lohn
dc1bb426ee fix: sharepoint cred refresh (#8406)
Co-authored-by: justin-tahara <justintahara@gmail.com>
2026-02-13 01:38:07 +00:00
acaprau
e9a0506183 chore(opensearch): Add profiling information (#8404)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-13 01:18:16 +00:00
Nikolas Garza
4747c43889 feat(auth): enforce seat limits on all user creation paths (#8401) 2026-02-13 00:18:29 +00:00
Jamison Lahman
27e676c48f chore(devtools): ods screenshot-diff for visual regression testing (#8386)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-02-13 00:04:22 +00:00
Justin Tahara
6749f63f09 fix(email): Making sure Email Links go to Default Mail Service (#8395) 2026-02-12 23:43:40 +00:00
Yuhong Sun
e404ffd443 fix: SearXNG works now (#8403) 2026-02-12 23:34:02 +00:00
Justin Tahara
c5b89b86c3 fix(ollama): Passing Context Window through (#8385) 2026-02-12 23:30:33 +00:00
Evan Lohn
84bb3867b2 chore: hide file reader (#8402) 2026-02-12 23:12:42 +00:00
Evan Lohn
92cc1d83b5 fix: flaky no vectordb test (#8400) 2026-02-12 23:12:08 +00:00
Justin Tahara
e92d4a342f fix(ollama): Fixing Content Skipping (#8092) 2026-02-12 22:59:56 +00:00
Wenxi
b4d596c957 fix: remove log error when authtype is not set (#8399) 2026-02-12 22:57:13 +00:00
Nikolas Garza
d76d32003b fix(billing): exclude inactive users from seat counts and allow users page when gated (#8397) 2026-02-12 22:51:24 +00:00
Wenxi
007d2d109f feat(craft): pdf preview and refresh output panel (#8392) 2026-02-12 22:41:11 +00:00
Yuhong Sun
08891b5242 fix: Reminders polluting the query expansion (#8391) 2026-02-12 22:30:35 +00:00
Justin Tahara
846672a843 chore(llm): Additional Model Selection Test (#8389) 2026-02-12 21:53:03 +00:00
roshan
0f362457be fix(craft): craft connector FE nits (#8387) 2026-02-12 21:39:33 +00:00
Wenxi
283e8f4d3f feat(craft): pptx generation, editing, preview (#8383)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-02-12 21:34:22 +00:00
Evan Lohn
fdf19d74bd refactor: github connector (#8384) 2026-02-12 20:33:27 +00:00
roshan
7c702f8932 feat(craft): local file connector (#8304) 2026-02-12 20:23:23 +00:00
Danelegend
3fb06f6e8e feat(search-settings): Add tests + contextual llm validation (#8376) 2026-02-12 20:12:27 +00:00
Jamison Lahman
9fcd999076 chore(devtools): Recommend @playwright/mcp in Cursor (#8380)
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-02-12 11:15:29 -08:00
Wenxi
c937da65c4 chore: make chatbackgrounds local assets for air-gapped envs (#8381) 2026-02-12 18:53:27 +00:00
Raunak Bhagat
abdbe89dd4 fix: Search submission buttons layouts (#8382) 2026-02-12 18:39:43 +00:00
Raunak Bhagat
54f9c67522 feat: Unified Search and Chat (#8106)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-12 17:53:55 +00:00
Raunak Bhagat
31bcdc69ca refactor(opal): migrate IconButton usages to opal Button (#8333)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 17:19:23 +00:00
Justin Tahara
b748e08029 chore(llm): Adding Tool Enforcement Tests (#8371) 2026-02-12 14:34:27 +00:00
SubashMohan
11b279ad31 feat(memory): enable memory tool to add or update the memory (#8331) 2026-02-12 09:09:00 +00:00
Yuhong Sun
782082f818 chore: Opensearch tuning (#8374)
Co-authored-by: acaprau <48705707+acaprau@users.noreply.github.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-02-12 09:03:07 +00:00
acaprau
c01b559bc6 feat(opensearch): Admin configuration 2 - Make the retrieval toggle actually do something (#8370) 2026-02-12 09:01:07 +00:00
acaprau
3101a53855 feat(opensearch): Admin configuration 1 - FE migration tab in the admin sidebar, gated by env var (#8365) 2026-02-12 08:32:36 +00:00
acaprau
ce6c210de1 fix(opensearch): Make chunk migration not stop on an exception; also ACL does not raise (#8375) 2026-02-12 08:29:19 +00:00
acaprau
15b372fea9 feat(opensearch): Admin configuration 0 - REST APIs for migration stuff (#8364) 2026-02-12 06:16:23 +00:00
Nikolas Garza
cf523cb467 feat(ee): gate access only when legacy EE flag is set and no license exists (#8368) 2026-02-12 03:36:27 +00:00
Raunak Bhagat
344625b7e0 fix(opal): add padding to Interactive.Container and smooth foldable transitions (#8367) 2026-02-12 02:05:34 +00:00
Justin Tahara
9bf8400cf8 chore(playwright): Setup LLM Provider (#8362) 2026-02-12 02:03:08 +00:00
Evan Lohn
09e86c2fda fix: no vector db tests (#8369) 2026-02-12 01:37:28 +00:00
Justin Tahara
204328d52a chore(llm): Backend Fallback Logic Tests (#8363) 2026-02-12 01:15:15 +00:00
Nikolas Garza
3ce58c8450 fix(ee): follow HTTP→HTTPS redirects in forward_to_control_plane (#8360) 2026-02-12 00:27:49 +00:00
Evan Lohn
67b5df255a feat: minimal deployment mode (#8293) 2026-02-11 23:56:50 +00:00
Raunak Bhagat
33fa29e19f refactor(opal): rename subvariant to prominence, add internal, remove static (#8348)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:52:44 +00:00
acaprau
787f25a7c8 chore(opensearch): Tuning - Reduce k from 1000 to 50 (#8359) 2026-02-11 23:40:55 +00:00
Justin Tahara
f10b994a27 fix(bedrock): Fixing toolConfig call (#8342) 2026-02-11 23:29:28 +00:00
Danelegend
d4089b1785 chore(search-settings): Remove unused kv search-setting key (#8356) 2026-02-11 23:13:14 +00:00
Jamison Lahman
e122959854 chore(devtools): upgrade ods: 0.5.2->0.5.3 (#8358) 2026-02-11 15:09:25 -08:00
Jamison Lahman
93afb154ee chore(devtools): update ods compose defaults (#8357) 2026-02-11 15:04:16 -08:00
Jamison Lahman
e9be078268 chore(devtools): upgrade ods: 0.5.1->0.5.2 (#8355) 2026-02-11 22:56:32 +00:00
Jamison Lahman
61502751e8 chore(devtools): address missed cubic review (#8353) 2026-02-11 14:40:58 -08:00
Jamison Lahman
cd26893b87 chore(devtools): ods compose defaults ee version (#8351) 2026-02-11 14:35:22 -08:00
Yuhong Sun
90dc6b16fa fix: Metadata file for larger zips (#8327) 2026-02-11 22:16:12 +00:00
Raunak Bhagat
34b48763f4 refactor(opal): update Container height variants, remove paddingVariant (#8350)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 21:53:39 +00:00
Jamison Lahman
094d7a2d02 chore(playwright): remove unnecessary global auth checks (#8341) 2026-02-11 21:45:12 +00:00
victoria reese
faa97e92e8 fix: idleReplicaCount should be optional for ScaledObjects (#8344) 2026-02-11 21:09:45 +00:00
Wenxi
358dc32fd2 fix: upgrade plan page nits (#8346) 2026-02-11 21:06:39 +00:00
Justin Tahara
f06465bfb2 chore(admin): Improve Playwright test speeds (#8326) 2026-02-11 20:12:15 +00:00
Raunak Bhagat
8a51b00050 feat(backend): add default_app_mode field to User table (#8291)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 19:58:53 +00:00
Justin Tahara
33de6dcd6a fix(anthropic): Model Selection in Multi-Tenant (#8308) 2026-02-11 19:34:49 +00:00
acaprau
fe52f4e6d3 chore(opensearch): Add migration queue to helm chart and launch json (#8336) 2026-02-11 19:16:48 +00:00
Jamison Lahman
51de334732 chore(playwright): remove chromatic (#8339) 2026-02-11 18:50:12 +00:00
dependabot[bot]
cb72f84209 chore(deps): bump pillow from 12.0.0 to 12.1.1 (#8338)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-11 18:38:42 +00:00
dependabot[bot]
8b24c08467 chore(deps): bump langchain-core from 0.3.81 to 1.2.11 in /backend/requirements (#8334)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-11 18:28:06 +00:00
Wenxi
0a1e043a97 fix(craft): load messages before restore session and feat: timeout restoration operations (#8303) 2026-02-11 18:09:10 +00:00
Raunak Bhagat
466668fed5 feat(opal): add foldable prop to Button + select-variant icon colour (#8300)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:12:53 +00:00
acaprau
41d105faa0 feat(opensearch): Improved migration task 1 - Completely replace old task logic with new (#8323) 2026-02-11 06:28:59 +00:00
SubashMohan
9e581f48e5 refactor(memory): Refactor memories to use ID-based persistence and new memories UI (#8294) 2026-02-11 06:25:53 +00:00
acaprau
48d8e0955a chore(opensearch): Improved migration task 0 - Schema migrations (#8321) 2026-02-11 04:33:51 +00:00
acaprau
a77780d67e chore(devtools): Add comment in AGENTS.md about the limitations of Celery timeouts with threads (#8257) 2026-02-11 03:38:27 +00:00
Justin Tahara
d13511500c chore(llm): Hardening Fallback Tool Call (#8325) 2026-02-11 02:46:01 +00:00
Wenxi
216d486323 fix: allow basic users to share agents (#8269) 2026-02-11 02:34:07 +00:00
dependabot[bot]
a57d399ba5 chore(deps): bump cryptography from 46.0.3 to 46.0.5 (#8319)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-11 02:28:20 +00:00
Nikolas Garza
07324ae0e4 fix(ee): copy license public key into Docker image (#8322) 2026-02-11 02:16:41 +00:00
Wenxi
c8ae07f7c2 feat(craft): narrow file sync to source, prevent concurrent syncs, and use --delete flag on incremental syncs (#8235) 2026-02-11 02:14:40 +00:00
Justin Tahara
f0fd19f110 chore(llm): Adding new Mock LLM Call test (#8290) 2026-02-11 02:07:21 +00:00
Wenxi
6a62406042 chore(craft): bump sandbox limits one last time TM (#8317) 2026-02-11 01:12:52 +00:00
Jamison Lahman
d0be7dd914 chore(deployment): only try to build desktop if semver-like tag (#8316) 2026-02-11 01:03:19 +00:00
Jamison Lahman
6a045db72b chore(devtools): deploy preview frontend builds in CI (#8315) 2026-02-11 00:58:59 +00:00
Wenxi
e5e9dbe2f0 fix: make /health check async (#8314) 2026-02-11 00:54:34 +00:00
Nikolas Garza
50e0a2cf90 feat(slack): add option to include bot messages during indexing (#8309) 2026-02-10 23:10:59 +00:00
Nikolas Garza
50538ce5ac chore(slack): add logging when bot messages are filtered during indexing (#8305) 2026-02-10 22:41:54 +00:00
Raunak Bhagat
6fab7103bf fix(opal): extract interactive container styles to CSS (#8307)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 21:40:40 +00:00
776 changed files with 55675 additions and 14143 deletions

.cursor/mcp.json (new file, +16 lines)

@@ -0,0 +1,16 @@
{
"mcpServers": {
"Playwright": {
"command": "npx",
"args": [
"@playwright/mcp"
]
},
"Linear": {
"url": "https://mcp.linear.app/mcp"
},
"Figma": {
"url": "https://mcp.figma.com/mcp"
}
}
}


@@ -91,8 +91,8 @@ jobs:
BUILD_WEB_CLOUD=true
else
BUILD_WEB=true
# Skip desktop builds on beta tags and nightly runs
if [[ "$IS_BETA" != "true" ]] && [[ "$IS_NIGHTLY" != "true" ]]; then
# Only build desktop for semver tags (excluding beta)
if [[ "$IS_VERSION_TAG" == "true" ]] && [[ "$IS_BETA" != "true" ]]; then
BUILD_DESKTOP=true
fi
fi
@@ -640,6 +640,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
@@ -721,6 +722,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true


@@ -46,6 +46,7 @@ jobs:
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
editions: ${{ steps.set-editions.outputs.editions }}
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
@@ -56,7 +57,7 @@ jobs:
id: set-matrix
run: |
# Find all leaf-level directories in both test directories
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" ! -name "no_vectordb" -exec basename {} \; | sort)
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
# Create JSON array with directory info
@@ -72,6 +73,16 @@ jobs:
all_dirs="[${all_dirs%,}]"
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
- name: Determine editions to test
id: set-editions
run: |
# On PRs, only run EE tests. On merge_group and tags, run both EE and MIT.
if [ "${{ github.event_name }}" = "pull_request" ]; then
echo 'editions=["ee"]' >> $GITHUB_OUTPUT
else
echo 'editions=["ee","mit"]' >> $GITHUB_OUTPUT
fi
build-backend-image:
runs-on:
[
@@ -267,7 +278,7 @@ jobs:
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- ${{ format('run-id={0}-integration-tests-job-{1}', github.run_id, strategy['job-index']) }}
- ${{ format('run-id={0}-integration-tests-{1}-job-{2}', github.run_id, matrix.edition, strategy['job-index']) }}
- extras=ecr-cache
timeout-minutes: 45
@@ -275,6 +286,7 @@ jobs:
fail-fast: false
matrix:
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
edition: ${{ fromJson(needs.discover-test-dirs.outputs.editions) }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -298,12 +310,11 @@ jobs:
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
EDITION: ${{ matrix.edition }}
run: |
# Base config shared by both editions
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
@@ -312,11 +323,20 @@ jobs:
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
EOF
# EE-only config
if [ "$EDITION" = "ee" ]; then
cat <<EOF >> deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
fi
- name: Start Docker containers
run: |
@@ -379,14 +399,14 @@ jobs:
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
- name: Run Integration Tests (${{ matrix.edition }}) for ${{ matrix.test-dir.name }}
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
echo "Running ${{ matrix.edition }} integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
@@ -444,10 +464,143 @@ jobs:
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
name: docker-all-logs-${{ matrix.edition }}-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
no-vectordb-tests:
needs: [build-backend-image, build-integration-image]
runs-on:
[
runs-on,
runner=4cpu-linux-arm64,
"run-id=${{ github.run_id }}-no-vectordb-tests",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create .env file for no-vectordb Docker Compose
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
DISABLE_VECTOR_DB=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
EOF
# Start only the services needed for no-vectordb mode (no Vespa, no model servers)
- name: Start Docker containers (no-vectordb)
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
relational_db \
cache \
minio \
api_server \
background \
-d
id: start_docker_no_vectordb
- name: Wait for services to be ready
run: |
echo "Starting wait-for-service script (no-vectordb)..."
start_time=$(date +%s)
timeout=300
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in $timeout seconds."
exit 1
fi
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "API server is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error; retrying..."
else
echo "Service not ready yet (HTTP $response). Retrying in 5 seconds..."
fi
sleep 5
done
- name: Run No-VectorDB Integration Tests
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running no-vectordb integration tests..."
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e TEST_WEB_HOSTNAME=test-runner \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/tests/no_vectordb
- name: Dump API server logs (no-vectordb)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true
- name: Dump all-container logs (no-vectordb)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true
- name: Upload logs (no-vectordb)
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-no-vectordb
path: ${{ github.workspace }}/docker-compose-no-vectordb.log
- name: Stop Docker containers (no-vectordb)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v
multitenant-tests:
needs:
[build-backend-image, build-model-server-image, build-integration-image]
@@ -587,7 +740,7 @@ jobs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [integration-tests, multitenant-tests]
needs: [integration-tests, no-vectordb-tests, multitenant-tests]
if: ${{ always() }}
steps:
- name: Check job status


@@ -1,443 +0,0 @@
name: Run MIT Integration Tests v2
concurrency:
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
types: [checks_requested]
push:
tags:
- "v*.*.*"
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
run: |
# Find all leaf-level directories in both test directories
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
# Create JSON array with directory info
all_dirs=""
for dir in $tests_dirs; do
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
done
for dir in $connector_dirs; do
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
done
# Remove trailing comma and wrap in array
all_dirs="[${all_dirs%,}]"
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
build-backend-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-backend-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Backend Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-model-server-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Model Server Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
build-integration-image:
runs-on:
[
runs-on,
runner=2cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-integration-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push integration test image with Docker Bake
env:
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
TAG: integration-test-${{ github.run_id }}
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
run: |
docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration
integration-tests-mit:
needs:
[
discover-test-dirs,
build-backend-image,
build-model-server-image,
build-integration-image,
]
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- ${{ format('run-id={0}-integration-tests-mit-job-{1}', github.run_id, strategy['job-index']) }}
- extras=ecr-cache
timeout-minutes: 45
strategy:
fail-fast: false
matrix:
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Create .env file for Docker Compose
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
EOF
- name: Start Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
index \
cache \
minio \
api_server \
inference_model_server \
indexing_model_server \
background \
-d
id: start_docker
- name: Wait for services to be ready
run: |
echo "Starting wait-for-service script..."
wait_for_service() {
local url=$1
local label=$2
local timeout=${3:-300} # default 5 minutes
local start_time
start_time=$(date +%s)
while true; do
local current_time
current_time=$(date +%s)
local elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. ${label} did not become ready in $timeout seconds."
exit 1
fi
local response
response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
if [ "$response" = "200" ]; then
echo "${label} is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..."
else
echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
}
wait_for_service "http://localhost:8080/health" "API server"
echo "Finished waiting for services."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [integration-tests-mit]
if: ${{ always() }}
steps:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1


@@ -52,6 +52,9 @@ env:
MCP_SERVER_PUBLIC_HOST: host.docker.internal
MCP_SERVER_PUBLIC_URL: http://host.docker.internal:8004/mcp
# Visual regression S3 bucket (shared across all jobs)
PLAYWRIGHT_S3_BUCKET: onyx-playwright-artifacts
jobs:
build-web-image:
runs-on:
@@ -239,6 +242,9 @@ jobs:
playwright-tests:
needs: [build-web-image, build-backend-image, build-model-server-image]
name: Playwright Tests (${{ matrix.project }})
permissions:
id-token: write # Required for OIDC-based AWS credential exchange (S3 access)
contents: read
runs-on:
- runs-on
- runner=8cpu-linux-arm64
@@ -428,8 +434,6 @@ jobs:
env:
PROJECT: ${{ matrix.project }}
run: |
# Create test-results directory to ensure it exists for artifact upload
mkdir -p test-results
npx playwright test --project ${PROJECT}
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
@@ -437,9 +441,134 @@ jobs:
with:
# Includes test results and trace.zip files
name: playwright-test-results-${{ matrix.project }}-${{ github.run_id }}
path: ./web/test-results/
path: ./web/output/playwright/
retention-days: 30
- name: Upload screenshots
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
name: playwright-screenshots-${{ matrix.project }}-${{ github.run_id }}
path: ./web/output/screenshots/
retention-days: 30
# --- Visual Regression Diff ---
- name: Configure AWS credentials
if: always()
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Install the latest version of uv
if: always()
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Determine baseline revision
if: always()
id: baseline-rev
env:
EVENT_NAME: ${{ github.event_name }}
BASE_REF: ${{ github.event.pull_request.base.ref }}
MERGE_GROUP_BASE_REF: ${{ github.event.merge_group.base_ref }}
GH_REF: ${{ github.ref }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ "${EVENT_NAME}" = "pull_request" ]; then
# PRs compare against the base branch (e.g. main, release/2.5)
echo "rev=${BASE_REF}" >> "$GITHUB_OUTPUT"
elif [ "${EVENT_NAME}" = "merge_group" ]; then
# Merge queue compares against the target branch (e.g. refs/heads/main -> main)
echo "rev=${MERGE_GROUP_BASE_REF#refs/heads/}" >> "$GITHUB_OUTPUT"
elif [[ "${GH_REF}" == refs/tags/* ]]; then
# Tag builds compare against the tag name
echo "rev=${REF_NAME}" >> "$GITHUB_OUTPUT"
else
# Push builds (main, release/*) compare against the branch name
echo "rev=${REF_NAME}" >> "$GITHUB_OUTPUT"
fi
- name: Generate screenshot diff report
if: always()
env:
PROJECT: ${{ matrix.project }}
PLAYWRIGHT_S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
BASELINE_REV: ${{ steps.baseline-rev.outputs.rev }}
run: |
uv run --no-sync --with onyx-devtools ods screenshot-diff compare \
--project "${PROJECT}" \
--rev "${BASELINE_REV}"
- name: Upload visual diff report to S3
if: always()
env:
PROJECT: ${{ matrix.project }}
PR_NUMBER: ${{ github.event.pull_request.number }}
RUN_ID: ${{ github.run_id }}
run: |
SUMMARY_FILE="web/output/screenshot-diff/${PROJECT}/summary.json"
if [ ! -f "${SUMMARY_FILE}" ]; then
echo "No summary file found — skipping S3 upload."
exit 0
fi
HAS_DIFF=$(jq -r '.has_differences' "${SUMMARY_FILE}")
if [ "${HAS_DIFF}" != "true" ]; then
echo "No visual differences for ${PROJECT} — skipping S3 upload."
exit 0
fi
aws s3 sync "web/output/screenshot-diff/${PROJECT}/" \
"s3://${PLAYWRIGHT_S3_BUCKET}/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/"
- name: Upload visual diff summary
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
name: screenshot-diff-summary-${{ matrix.project }}
path: ./web/output/screenshot-diff/${{ matrix.project }}/summary.json
if-no-files-found: ignore
retention-days: 5
- name: Upload visual diff report artifact
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
name: screenshot-diff-report-${{ matrix.project }}-${{ github.run_id }}
path: ./web/output/screenshot-diff/${{ matrix.project }}/
if-no-files-found: ignore
retention-days: 30
- name: Update S3 baselines
if: >-
success() && (
github.ref == 'refs/heads/main' ||
startsWith(github.ref, 'refs/heads/release/') ||
startsWith(github.ref, 'refs/tags/v') ||
(
github.event_name == 'merge_group' && (
github.event.merge_group.base_ref == 'refs/heads/main' ||
startsWith(github.event.merge_group.base_ref, 'refs/heads/release/')
)
)
)
env:
PROJECT: ${{ matrix.project }}
PLAYWRIGHT_S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
BASELINE_REV: ${{ steps.baseline-rev.outputs.rev }}
run: |
if [ -d "web/output/screenshots/" ] && [ "$(ls -A web/output/screenshots/)" ]; then
uv run --no-sync --with onyx-devtools ods screenshot-diff upload-baselines \
--project "${PROJECT}" \
--rev "${BASELINE_REV}" \
--delete
else
echo "No screenshots to upload for ${PROJECT} — skipping baseline update."
fi
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
@@ -457,6 +586,95 @@ jobs:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
# Post a single combined visual regression comment after all matrix jobs finish
visual-regression-comment:
needs: [playwright-tests]
if: always() && github.event_name == 'pull_request'
runs-on: ubuntu-slim
timeout-minutes: 5
permissions:
pull-requests: write
steps:
- name: Download visual diff summaries
uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
with:
pattern: screenshot-diff-summary-*
path: summaries/
- name: Post combined PR comment
env:
GH_TOKEN: ${{ github.token }}
PR_NUMBER: ${{ github.event.pull_request.number }}
RUN_ID: ${{ github.run_id }}
REPO: ${{ github.repository }}
S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
run: |
MARKER="<!-- visual-regression-report -->"
# Build the markdown table from all summary files
TABLE_HEADER="| Project | Changed | Added | Removed | Unchanged | Report |"
TABLE_DIVIDER="|---------|---------|-------|---------|-----------|--------|"
TABLE_ROWS=""
HAS_ANY_SUMMARY=false
for SUMMARY_DIR in summaries/screenshot-diff-summary-*/; do
SUMMARY_FILE="${SUMMARY_DIR}summary.json"
if [ ! -f "${SUMMARY_FILE}" ]; then
continue
fi
HAS_ANY_SUMMARY=true
PROJECT=$(jq -r '.project' "${SUMMARY_FILE}")
CHANGED=$(jq -r '.changed' "${SUMMARY_FILE}")
ADDED=$(jq -r '.added' "${SUMMARY_FILE}")
REMOVED=$(jq -r '.removed' "${SUMMARY_FILE}")
UNCHANGED=$(jq -r '.unchanged' "${SUMMARY_FILE}")
TOTAL=$(jq -r '.total' "${SUMMARY_FILE}")
HAS_DIFF=$(jq -r '.has_differences' "${SUMMARY_FILE}")
if [ "${TOTAL}" = "0" ]; then
REPORT_LINK="_No screenshots_"
elif [ "${HAS_DIFF}" = "true" ]; then
REPORT_URL="https://${S3_BUCKET}.s3.us-east-2.amazonaws.com/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/index.html"
REPORT_LINK="[View Report](${REPORT_URL})"
else
REPORT_LINK="✅ No changes"
fi
TABLE_ROWS="${TABLE_ROWS}| \`${PROJECT}\` | ${CHANGED} | ${ADDED} | ${REMOVED} | ${UNCHANGED} | ${REPORT_LINK} |\n"
done
if [ "${HAS_ANY_SUMMARY}" = "false" ]; then
echo "No visual diff summaries found — skipping PR comment."
exit 0
fi
BODY=$(printf '%s\n' \
"${MARKER}" \
"### 🖼️ Visual Regression Report" \
"" \
"${TABLE_HEADER}" \
"${TABLE_DIVIDER}" \
"$(printf '%b' "${TABLE_ROWS}")")
# Upsert: find existing comment with the marker, or create a new one
EXISTING_COMMENT_ID=$(gh api \
"repos/${REPO}/issues/${PR_NUMBER}/comments" \
--jq ".[] | select(.body | startswith(\"${MARKER}\")) | .id" \
2>/dev/null | head -1)
if [ -n "${EXISTING_COMMENT_ID}" ]; then
gh api \
--method PATCH \
"repos/${REPO}/issues/comments/${EXISTING_COMMENT_ID}" \
-f body="${BODY}"
else
gh api \
--method POST \
"repos/${REPO}/issues/${PR_NUMBER}/comments" \
-f body="${BODY}"
fi
playwright-required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
@@ -467,48 +685,3 @@ jobs:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.
# Chromatic may be reintroduced in the future for UI diff testing if needed.
# chromatic-tests:
# name: Chromatic Tests
# needs: playwright-tests
# runs-on:
# [
# runs-on,
# runner=32cpu-linux-x64,
# disk=large,
# "run-id=${{ github.run_id }}",
# ]
# steps:
# - name: Checkout code
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
# with:
# fetch-depth: 0
# - name: Setup node
# uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
# with:
# node-version: 22
# - name: Install node dependencies
# working-directory: ./web
# run: npm ci
# - name: Download Playwright test results
# uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # ratchet:actions/download-artifact@v4
# with:
# name: test-results
# path: ./web/test-results
# - name: Run Chromatic
# uses: chromaui/action@latest
# with:
# playwright: true
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
# workingDir: ./web
# env:
# CHROMATIC_ARCHIVE_LOCATION: ./test-results

.github/workflows/preview.yml (new file, +73 lines)

@@ -0,0 +1,73 @@
name: Preview Deployment
env:
VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
VERCEL_PROJECT_ID: ${{ secrets.VERCEL_PROJECT_ID }}
VERCEL_CLI: vercel@50.14.1
on:
push:
branches-ignore:
- main
paths:
- "web/**"
permissions:
contents: read
pull-requests: write
jobs:
Deploy-Preview:
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
with:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm"
cache-dependency-path: ./web/package-lock.json
- name: Pull Vercel Environment Information
run: npx --yes ${{ env.VERCEL_CLI }} pull --yes --environment=preview --token=${{ secrets.VERCEL_TOKEN }}
- name: Build Project Artifacts
run: npx --yes ${{ env.VERCEL_CLI }} build --token=${{ secrets.VERCEL_TOKEN }}
- name: Deploy Project Artifacts to Vercel
id: deploy
run: |
DEPLOYMENT_URL=$(npx --yes ${{ env.VERCEL_CLI }} deploy --prebuilt --token=${{ secrets.VERCEL_TOKEN }})
echo "url=$DEPLOYMENT_URL" >> "$GITHUB_OUTPUT"
- name: Update PR comment with deployment URL
if: always() && steps.deploy.outputs.url
env:
GH_TOKEN: ${{ github.token }}
DEPLOYMENT_URL: ${{ steps.deploy.outputs.url }}
run: |
# Find the PR for this branch
PR_NUMBER=$(gh pr list --head "$GITHUB_REF_NAME" --json number --jq '.[0].number')
if [ -z "$PR_NUMBER" ]; then
echo "No open PR found for branch $GITHUB_REF_NAME, skipping comment."
exit 0
fi
COMMENT_MARKER="<!-- preview-deployment -->"
COMMENT_BODY="$COMMENT_MARKER
**Preview Deployment**
| Status | Preview | Commit | Updated |
| --- | --- | --- | --- |
| ✅ | $DEPLOYMENT_URL | \`${GITHUB_SHA::7}\` | $(date -u '+%Y-%m-%d %H:%M:%S UTC') |"
# Find existing comment by marker
EXISTING_COMMENT_ID=$(gh api "repos/$GITHUB_REPOSITORY/issues/$PR_NUMBER/comments" \
--jq ".[] | select(.body | startswith(\"$COMMENT_MARKER\")) | .id" | head -1)
if [ -n "$EXISTING_COMMENT_ID" ]; then
gh api "repos/$GITHUB_REPOSITORY/issues/comments/$EXISTING_COMMENT_ID" \
--method PATCH --field body="$COMMENT_BODY"
else
gh pr comment "$PR_NUMBER" --body "$COMMENT_BODY"
fi

.github/workflows/sandbox-deployment.yml (new file, +290 lines)

@@ -0,0 +1,290 @@
name: Build and Push Sandbox Image on Tag
on:
push:
tags:
- "experimental-cc4a.*"
# Restrictive defaults; jobs declare what they need.
permissions: {}
jobs:
check-sandbox-changes:
runs-on: ubuntu-slim
timeout-minutes: 10
permissions:
contents: read
outputs:
sandbox-changed: ${{ steps.check.outputs.sandbox-changed }}
new-version: ${{ steps.version.outputs.new-version }}
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
fetch-depth: 0
- name: Check for sandbox-relevant file changes
id: check
run: |
# Get the previous tag to diff against
CURRENT_TAG="${GITHUB_REF_NAME}"
PREVIOUS_TAG=$(git tag --sort=-creatordate | grep '^experimental-cc4a\.' | grep -v "^${CURRENT_TAG}$" | head -n 1)
if [ -z "$PREVIOUS_TAG" ]; then
echo "No previous experimental-cc4a tag found, building unconditionally"
echo "sandbox-changed=true" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "Comparing ${PREVIOUS_TAG}..${CURRENT_TAG}"
# Check if any sandbox-relevant files changed
SANDBOX_PATHS=(
"backend/onyx/server/features/build/sandbox/"
)
CHANGED=false
for path in "${SANDBOX_PATHS[@]}"; do
if git diff --name-only "${PREVIOUS_TAG}..${CURRENT_TAG}" -- "$path" | grep -q .; then
echo "Changes detected in: $path"
CHANGED=true
break
fi
done
echo "sandbox-changed=$CHANGED" >> "$GITHUB_OUTPUT"
- name: Determine new sandbox version
id: version
if: steps.check.outputs.sandbox-changed == 'true'
run: |
# Query Docker Hub for the latest versioned tag
LATEST_TAG=$(curl -s "https://hub.docker.com/v2/repositories/onyxdotapp/sandbox/tags?page_size=100" \
| jq -r '.results[].name' \
| grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
| sort -V \
| tail -n 1)
if [ -z "$LATEST_TAG" ]; then
echo "No existing version tags found on Docker Hub, starting at 0.1.1"
NEW_VERSION="0.1.1"
else
CURRENT_VERSION="${LATEST_TAG#v}"
echo "Latest version on Docker Hub: $CURRENT_VERSION"
# Increment patch version
MAJOR=$(echo "$CURRENT_VERSION" | cut -d. -f1)
MINOR=$(echo "$CURRENT_VERSION" | cut -d. -f2)
PATCH=$(echo "$CURRENT_VERSION" | cut -d. -f3)
NEW_PATCH=$((PATCH + 1))
NEW_VERSION="${MAJOR}.${MINOR}.${NEW_PATCH}"
fi
echo "New version: $NEW_VERSION"
echo "new-version=$NEW_VERSION" >> "$GITHUB_OUTPUT"
build-sandbox-amd64:
needs: check-sandbox-changes
if: needs.check-sandbox-changes.outputs.sandbox-changed == 'true'
runs-on:
- runs-on
- runner=4cpu-linux-x64
- run-id=${{ github.run_id }}-sandbox-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
permissions:
contents: read
id-token: write
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/sandbox
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
platforms: linux/amd64
labels: ${{ steps.meta.outputs.labels }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
build-sandbox-arm64:
needs: check-sandbox-changes
if: needs.check-sandbox-changes.outputs.sandbox-changed == 'true'
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- run-id=${{ github.run_id }}-sandbox-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
permissions:
contents: read
id-token: write
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/sandbox
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
platforms: linux/arm64
labels: ${{ steps.meta.outputs.labels }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: |
type=inline
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
merge-sandbox:
needs:
- check-sandbox-changes
- build-sandbox-amd64
- build-sandbox-arm64
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-sandbox
- extras=ecr-cache
timeout-minutes: 30
environment: release
permissions:
id-token: write
env:
REGISTRY_IMAGE: onyxdotapp/sandbox
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
tags: |
type=raw,value=v${{ needs.check-sandbox-changes.outputs.new-version }}
type=raw,value=latest
- name: Create and push manifest
env:
IMAGE_REPO: ${{ env.REGISTRY_IMAGE }}
AMD64_DIGEST: ${{ needs.build-sandbox-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.build-sandbox-arm64.outputs.digest }}
META_TAGS: ${{ steps.meta.outputs.tags }}
run: |
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
docker buildx imagetools create \
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
$IMAGES

1
.gitignore vendored
View File

@@ -6,6 +6,7 @@
!/.vscode/tasks.template.jsonc
.zed
.cursor
!/.cursor/mcp.json
# macos
.DS_store

6
.vscode/launch.json vendored
View File

@@ -246,7 +246,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup"
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup,opensearch_migration"
],
"presentation": {
"group": "2"
@@ -275,7 +275,7 @@
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete"
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
],
"presentation": {
"group": "2"
@@ -419,7 +419,7 @@
"--loglevel=INFO",
"--hostname=docfetching@%n",
"-Q",
"connector_doc_fetching,user_files_indexing"
"connector_doc_fetching"
],
"presentation": {
"group": "2"

View File

@@ -144,6 +144,10 @@ function.
If you make any updates to a celery worker and you want to test these changes, you will need
to ask me to restart the celery worker. There is no auto-restart on code-change mechanism.
**Task Time Limits**:
Since all tasks are executed in thread pools, the time limit features of Celery are silently
disabled and won't work. Timeout logic must be implemented within the task itself.
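To make this caveat concrete, here is a minimal sketch (editorial, not part of the diff above) of enforcing a timeout inside the task body itself; the task name, the `fetch_documents` generator, and the `JOB_TIMEOUT` value are illustrative assumptions. The permission-sync and external-group-sync diffs later in this compare follow the same pattern.
```python
# Illustrative sketch only: Celery's time_limit/soft_time_limit are no-ops
# under thread pools, so elapsed time is checked manually inside the task.
import time

from celery import shared_task

JOB_TIMEOUT = 3600  # seconds; assumed value, adjust per task


@shared_task
def sync_documents(cc_pair_id: int) -> None:
    start = time.monotonic()
    for doc in fetch_documents(cc_pair_id):  # hypothetical generator
        elapsed = time.monotonic() - start
        if elapsed > JOB_TIMEOUT:
            raise RuntimeError(
                f"sync_documents timed out: cc_pair={cc_pair_id} elapsed={elapsed:.0f}s"
            )
        process_document(doc)  # hypothetical per-item work
```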
### Code Quality
```bash

View File

@@ -2,7 +2,10 @@ Copyright (c) 2023-present DanswerAI, Inc.
Portions of this software are licensed as follows:
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
- backend/ee/LICENSE
- web/src/app/ee/LICENSE
- web/src/ee/LICENSE
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.

View File

@@ -134,6 +134,7 @@ COPY --chown=onyx:onyx ./alembic_tenants /app/alembic_tenants
COPY --chown=onyx:onyx ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
COPY --chown=onyx:onyx ./static /app/static
COPY --chown=onyx:onyx ./keys /app/keys
# Escape hatch scripts
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging

View File

@@ -474,7 +474,7 @@ def run_migrations_online() -> None:
if connectable is not None:
# pytest-alembic is providing an engine - use it directly
logger.info("run_migrations_online starting (pytest-alembic mode).")
logger.debug("run_migrations_online starting (pytest-alembic mode).")
# For pytest-alembic, we use the default schema (public)
schema_name = context.config.attributes.get(

View File

@@ -0,0 +1,33 @@
"""add default_app_mode to user
Revision ID: 114a638452db
Revises: feead2911109
Create Date: 2026-02-09 18:57:08.274640
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "114a638452db"
down_revision = "feead2911109"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"default_app_mode",
sa.String(),
nullable=False,
server_default="CHAT",
),
)
def downgrade() -> None:
op.drop_column("user", "default_app_mode")

View File

@@ -11,7 +11,6 @@ import sqlalchemy as sa
from urllib.parse import urlparse, urlunparse
from httpx import HTTPStatusError
import httpx
from onyx.document_index.factory import get_default_document_index
from onyx.db.search_settings import SearchSettings
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
@@ -519,15 +518,11 @@ def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
def upgrade() -> None:
if SKIP_CANON_DRIVE_IDS:
return
current_search_settings, future_search_settings = active_search_settings()
document_index = get_default_document_index(
current_search_settings,
future_search_settings,
)
current_search_settings, _ = active_search_settings()
# Get the index name
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
if hasattr(current_search_settings, "index_name"):
index_name = current_search_settings.index_name
else:
# Default index name if we can't get it from the document_index
index_name = "danswer_index"

View File

@@ -0,0 +1,71 @@
"""Migrate to contextual rag model
Revision ID: 19c0ccb01687
Revises: 9c54986124c6
Create Date: 2026-02-12 11:21:41.798037
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "19c0ccb01687"
down_revision = "9c54986124c6"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Widen the column to fit 'CONTEXTUAL_RAG' (15 chars); was varchar(10)
# when the table was created with only CHAT/VISION values.
op.alter_column(
"llm_model_flow",
"llm_model_flow_type",
type_=sa.String(length=20),
existing_type=sa.String(length=10),
existing_nullable=False,
)
# For every search_settings row that has contextual rag configured,
# create an llm_model_flow entry. is_default is TRUE if the row
# belongs to the PRESENT search settings, FALSE otherwise.
op.execute(
"""
INSERT INTO llm_model_flow (llm_model_flow_type, model_configuration_id, is_default)
SELECT DISTINCT
'CONTEXTUAL_RAG',
mc.id,
(ss.status = 'PRESENT')
FROM search_settings ss
JOIN llm_provider lp
ON lp.name = ss.contextual_rag_llm_provider
JOIN model_configuration mc
ON mc.llm_provider_id = lp.id
AND mc.name = ss.contextual_rag_llm_name
WHERE ss.enable_contextual_rag = TRUE
AND ss.contextual_rag_llm_name IS NOT NULL
AND ss.contextual_rag_llm_provider IS NOT NULL
ON CONFLICT (llm_model_flow_type, model_configuration_id)
DO UPDATE SET is_default = EXCLUDED.is_default
WHERE EXCLUDED.is_default = TRUE
"""
)
def downgrade() -> None:
op.execute(
"""
DELETE FROM llm_model_flow
WHERE llm_model_flow_type = 'CONTEXTUAL_RAG'
"""
)
op.alter_column(
"llm_model_flow",
"llm_model_flow_type",
type_=sa.String(length=10),
existing_type=sa.String(length=20),
existing_nullable=False,
)

View File

@@ -16,7 +16,6 @@ from typing import Generator
from alembic import op
import sqlalchemy as sa
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.db.search_settings import SearchSettings
from onyx.configs.app_configs import AUTH_TYPE
@@ -126,14 +125,11 @@ def remove_old_tags() -> None:
the document got reindexed, the old tag would not be removed.
This function removes those old tags by comparing it against the tags in vespa.
"""
current_search_settings, future_search_settings = active_search_settings()
document_index = get_default_document_index(
current_search_settings, future_search_settings
)
current_search_settings, _ = active_search_settings()
# Get the index name
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
if hasattr(current_search_settings, "index_name"):
index_name = current_search_settings.index_name
else:
# Default index name if we can't get it from the document_index
index_name = "danswer_index"

View File

@@ -0,0 +1,43 @@
"""add chunk error and vespa count columns to opensearch tenant migration
Revision ID: 93c15d6a6fbb
Revises: d3fd499c829c
Create Date: 2026-02-11 23:07:34.576725
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "93c15d6a6fbb"
down_revision = "d3fd499c829c"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"total_chunks_errored",
sa.Integer(),
nullable=False,
server_default="0",
),
)
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"total_chunks_in_vespa",
sa.Integer(),
nullable=False,
server_default="0",
),
)
def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "total_chunks_in_vespa")
op.drop_column("opensearch_tenant_migration_record", "total_chunks_errored")

View File

@@ -0,0 +1,124 @@
"""add_scim_tables
Revision ID: 9c54986124c6
Revises: b51c6844d1df
Create Date: 2026-02-12 20:29:47.448614
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9c54986124c6"
down_revision = "b51c6844d1df"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"scim_token",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("hashed_token", sa.String(length=64), nullable=False),
sa.Column("token_display", sa.String(), nullable=False),
sa.Column(
"created_by_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column(
"is_active",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(["created_by_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("hashed_token"),
)
op.create_table(
"scim_group_mapping",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_id", sa.String(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["user_group_id"], ["user_group.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_group_id"),
)
op.create_index(
op.f("ix_scim_group_mapping_external_id"),
"scim_group_mapping",
["external_id"],
unique=True,
)
op.create_table(
"scim_user_mapping",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_id", sa.String(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_id"),
)
op.create_index(
op.f("ix_scim_user_mapping_external_id"),
"scim_user_mapping",
["external_id"],
unique=True,
)
def downgrade() -> None:
op.drop_index(
op.f("ix_scim_user_mapping_external_id"),
table_name="scim_user_mapping",
)
op.drop_table("scim_user_mapping")
op.drop_index(
op.f("ix_scim_group_mapping_external_id"),
table_name="scim_group_mapping",
)
op.drop_table("scim_group_mapping")
op.drop_table("scim_token")

View File

@@ -0,0 +1,81 @@
"""seed_memory_tool and add enable_memory_tool to user
Revision ID: b51c6844d1df
Revises: 93c15d6a6fbb
Create Date: 2026-02-11 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b51c6844d1df"
down_revision = "93c15d6a6fbb"
branch_labels = None
depends_on = None
MEMORY_TOOL = {
"name": "MemoryTool",
"display_name": "Add Memory",
"description": "Save memories about the user for future conversations.",
"in_code_tool_id": "MemoryTool",
"enabled": True,
}
def upgrade() -> None:
conn = op.get_bind()
existing = conn.execute(
sa.text(
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id = :in_code_tool_id"
),
{"in_code_tool_id": MEMORY_TOOL["in_code_tool_id"]},
).fetchone()
if existing:
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
MEMORY_TOOL,
)
else:
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
"""
),
MEMORY_TOOL,
)
op.add_column(
"user",
sa.Column(
"enable_memory_tool",
sa.Boolean(),
nullable=False,
server_default=sa.true(),
),
)
def downgrade() -> None:
op.drop_column("user", "enable_memory_tool")
conn = op.get_bind()
conn.execute(
sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": MEMORY_TOOL["in_code_tool_id"]},
)

View File

@@ -0,0 +1,102 @@
"""add_file_reader_tool
Revision ID: d3fd499c829c
Revises: 114a638452db
Create Date: 2026-02-07 19:28:22.452337
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d3fd499c829c"
down_revision = "114a638452db"
branch_labels = None
depends_on = None
FILE_READER_TOOL = {
"name": "read_file",
"display_name": "File Reader",
"description": (
"Read sections of user-uploaded files by character offset. "
"Useful for inspecting large files that cannot fit entirely in context."
),
"in_code_tool_id": "FileReaderTool",
"enabled": True,
}
def upgrade() -> None:
conn = op.get_bind()
# Check if tool already exists
existing = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": FILE_READER_TOOL["in_code_tool_id"]},
).fetchone()
if existing:
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
FILE_READER_TOOL,
)
tool_id = existing[0]
else:
# Insert new tool
result = conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
RETURNING id
"""
),
FILE_READER_TOOL,
)
tool_id = result.scalar_one()
# Attach to the default persona (id=0) if not already attached
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": tool_id},
)
def downgrade() -> None:
conn = op.get_bind()
in_code_tool_id = FILE_READER_TOOL["in_code_tool_id"]
# Remove persona associations first (FK constraint)
conn.execute(
sa.text(
"""
DELETE FROM persona__tool
WHERE tool_id IN (
SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id
)
"""
),
{"in_code_tool_id": in_code_tool_id},
)
conn.execute(
sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": in_code_tool_id},
)

View File

@@ -0,0 +1,69 @@
"""add_opensearch_tenant_migration_columns
Revision ID: feead2911109
Revises: 175ea04c7087
Create Date: 2026-02-10 17:46:34.029937
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "feead2911109"
down_revision = "175ea04c7087"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column("vespa_visit_continuation_token", sa.Text(), nullable=True),
)
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"total_chunks_migrated",
sa.Integer(),
nullable=False,
server_default="0",
),
)
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"migration_completed_at",
sa.DateTime(timezone=True),
nullable=True,
),
)
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"enable_opensearch_retrieval",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "enable_opensearch_retrieval")
op.drop_column("opensearch_tenant_migration_record", "migration_completed_at")
op.drop_column("opensearch_tenant_migration_record", "created_at")
op.drop_column("opensearch_tenant_migration_record", "total_chunks_migrated")
op.drop_column(
"opensearch_tenant_migration_record", "vespa_visit_continuation_token"
)

View File

@@ -1,20 +1,20 @@
The DanswerAI Enterprise license (the Enterprise License)
The Onyx Enterprise License (the "Enterprise License")
Copyright (c) 2023-present DanswerAI, Inc.
With regard to the Onyx Software:
This software and associated documentation files (the "Software") may only be
used in production, if you (and any entity that you represent) have agreed to,
and are in compliance with, the DanswerAI Subscription Terms of Service, available
at https://onyx.app/terms (the Enterprise Terms), or other
and are in compliance with, the Onyx Subscription Terms of Service, available
at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other
agreement governing the use of the Software, as agreed by you and DanswerAI,
and otherwise have a valid Onyx Enterprise license for the
and otherwise have a valid Onyx Enterprise License for the
correct number of user seats. Subject to the foregoing sentence, you are free to
modify this Software and publish patches to the Software. You agree that DanswerAI
and/or its licensors (as applicable) retain all right, title and interest in and
to all such modifications and/or patches, and all such modifications and/or
patches may only be used, copied, modified, displayed, distributed, or otherwise
exploited with a valid Onyx Enterprise license for the correct
exploited with a valid Onyx Enterprise License for the correct
number of user seats. Notwithstanding the foregoing, you may copy and modify
the Software for development and testing purposes, without requiring a
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain

View File

@@ -1,12 +1,15 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.background import celery_app
celery_app.autodiscover_tasks(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.tenant_provisioning",
"ee.onyx.background.celery.tasks.query_history",
]
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.tenant_provisioning",
"ee.onyx.background.celery.tasks.query_history",
]
)
)

View File

@@ -1,11 +1,14 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.heavy import celery_app
celery_app.autodiscover_tasks(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.query_history",
]
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.query_history",
]
)
)

View File

@@ -1,8 +1,11 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.light import celery_app
celery_app.autodiscover_tasks(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
]
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
]
)
)

View File

@@ -1,7 +1,10 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.monitoring import celery_app
celery_app.autodiscover_tasks(
[
"ee.onyx.background.celery.tasks.tenant_provisioning",
]
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.tenant_provisioning",
]
)
)

View File

@@ -1,12 +1,15 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.primary import celery_app
celery_app.autodiscover_tasks(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cloud",
"ee.onyx.background.celery.tasks.ttl_management",
"ee.onyx.background.celery.tasks.usage_reporting",
]
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cloud",
"ee.onyx.background.celery.tasks.ttl_management",
"ee.onyx.background.celery.tasks.usage_reporting",
]
)
)

View File

@@ -536,7 +536,9 @@ def connector_permission_sync_generator_task(
)
redis_connector.permissions.set_fence(new_payload)
callback = PermissionSyncCallback(redis_connector, lock, r)
callback = PermissionSyncCallback(
redis_connector, lock, r, timeout_seconds=JOB_TIMEOUT
)
# pass in the capability to fetch all existing docs for the cc_pair
# this can be used to determine documents that are "missing" and thus
@@ -576,6 +578,13 @@ def connector_permission_sync_generator_task(
tasks_generated = 0
docs_with_errors = 0
for doc_external_access in document_external_accesses:
if callback.should_stop():
raise RuntimeError(
f"Permission sync task timed out or stop signal detected: "
f"cc_pair={cc_pair_id} "
f"tasks_generated={tasks_generated}"
)
result = redis_connector.permissions.update_db(
lock=lock,
new_permissions=[doc_external_access],
@@ -932,6 +941,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
timeout_seconds: int | None = None,
):
super().__init__()
self.redis_connector: RedisConnector = redis_connector
@@ -944,11 +954,26 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
self.last_tag: str = "PermissionSyncCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.start_monotonic = time.monotonic()
self.timeout_seconds = timeout_seconds
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
return True
# Check if the task has exceeded its timeout
# NOTE: Celery's soft_time_limit does not work with thread pools,
# so we must enforce timeouts internally.
if self.timeout_seconds is not None:
elapsed = time.monotonic() - self.start_monotonic
if elapsed > self.timeout_seconds:
logger.warning(
f"PermissionSyncCallback - task timeout exceeded: "
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
f"cc_pair={self.redis_connector.cc_pair_id}"
)
return True
return False
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002

View File

@@ -466,6 +466,7 @@ def connector_external_group_sync_generator_task(
def _perform_external_group_sync(
cc_pair_id: int,
tenant_id: str,
timeout_seconds: int = JOB_TIMEOUT,
) -> None:
# Create attempt record at the start
with get_session_with_current_tenant() as db_session:
@@ -518,9 +519,23 @@ def _perform_external_group_sync(
seen_users: set[str] = set() # Track unique users across all groups
total_groups_processed = 0
total_group_memberships_synced = 0
start_time = time.monotonic()
try:
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
for external_user_group in external_user_group_generator:
# Check if the task has exceeded its timeout
# NOTE: Celery's soft_time_limit does not work with thread pools,
# so we must enforce timeouts internally.
elapsed = time.monotonic() - start_time
if elapsed > timeout_seconds:
raise RuntimeError(
f"External group sync task timed out: "
f"cc_pair={cc_pair_id} "
f"elapsed={elapsed:.0f}s "
f"timeout={timeout_seconds}s "
f"groups_processed={total_groups_processed}"
)
external_user_group_batch.append(external_user_group)
# Track progress

View File

@@ -65,21 +65,7 @@ def github_doc_sync(
# Get all repositories from GitHub API
logger.info("Fetching all repositories from GitHub API")
try:
repos = []
if github_connector.repositories:
if "," in github_connector.repositories:
# Multiple repositories specified
repos = github_connector.get_github_repos(
github_connector.github_client
)
else:
# Single repository
repos = [
github_connector.get_github_repo(github_connector.github_client)
]
else:
# All repositories
repos = github_connector.get_all_repos(github_connector.github_client)
repos = github_connector.fetch_configured_repos()
logger.info(f"Found {len(repos)} repositories to check")
except Exception as e:

View File

@@ -1,12 +1,9 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -46,16 +43,11 @@ def sharepoint_group_sync(
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
# Process each site
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
# Create client context for the site using connector's MSAL app
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
ctx = connector._create_rest_client_context(site_descriptor.url)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)

View File

@@ -77,7 +77,7 @@ def stream_search_query(
# Get document index
search_settings = get_current_search_settings(db_session)
# This flow is for search so we do not get all indices.
document_index = get_default_document_index(search_settings, None)
document_index = get_default_document_index(search_settings, None, db_session)
# Determine queries to execute
original_query = request.search_query

View File

@@ -109,7 +109,9 @@ async def _make_billing_request(
headers = _get_headers(license_data)
try:
async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
async with httpx.AsyncClient(
timeout=_REQUEST_TIMEOUT, follow_redirects=True
) as client:
if method == "GET":
response = await client.get(url, headers=headers, params=params)
else:

View File

@@ -27,6 +27,8 @@ class SearchFlowClassificationResponse(BaseModel):
is_search_flow: bool
# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
class SendSearchQueryRequest(BaseModel):
search_query: str
filters: BaseFilters | None = None

View File

@@ -26,6 +26,7 @@ from onyx.db.models import User
from onyx.llm.factory import get_default_llm
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
from onyx.server.utils import get_json_line
from onyx.server.utils_vector_db import require_vector_db
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
@@ -66,7 +67,13 @@ def search_flow_classification(
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
@router.post("/send-search-message", response_model=None)
# NOTE: This endpoint is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
@router.post(
"/send-search-message",
response_model=None,
dependencies=[Depends(require_vector_db)],
)
def handle_send_search_message(
request: SendSearchQueryRequest,
user: User = Depends(current_user),

View File

@@ -0,0 +1,96 @@
"""SCIM filter expression parser (RFC 7644 §3.4.2.2).
Identity providers (Okta, Azure AD, OneLogin, etc.) use filters to look up
resources before deciding whether to create or update them. For example, when
an admin assigns a user to the Onyx app, the IdP first checks whether that
user already exists::
GET /scim/v2/Users?filter=userName eq "john@example.com"
If zero results come back the IdP creates the user (``POST``); if a match is
found it links to the existing record and uses ``PUT``/``PATCH`` going forward.
The same pattern applies to groups (``displayName eq "Engineering"``).
This module parses the subset of the SCIM filter grammar that identity
providers actually send in practice:
attribute SP operator SP value
Supported operators: ``eq``, ``co`` (contains), ``sw`` (starts with).
Compound filters (``and`` / ``or``) are not supported; if an IdP sends one
the parser returns ``None`` and the caller falls back to an unfiltered list.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from enum import Enum
class ScimFilterOperator(str, Enum):
"""Supported SCIM filter operators."""
EQUAL = "eq"
CONTAINS = "co"
STARTS_WITH = "sw"
@dataclass(frozen=True, slots=True)
class ScimFilter:
"""Parsed SCIM filter expression."""
attribute: str
operator: ScimFilterOperator
value: str
# Matches: attribute operator "value" (with or without quotes around value)
# Groups: (attribute) (operator) ("quoted value" | unquoted_value)
_FILTER_RE = re.compile(
r"^(\S+)\s+(eq|co|sw)\s+" # attribute + operator
r'(?:"([^"]*)"' # quoted value
r"|'([^']*)')" # or single-quoted value
r"$",
re.IGNORECASE,
)
def parse_scim_filter(filter_string: str | None) -> ScimFilter | None:
"""Parse a simple SCIM filter expression.
Args:
filter_string: Raw filter query parameter value, e.g.
``'userName eq "john@example.com"'``
Returns:
A ``ScimFilter`` if the expression is valid and uses a supported
operator, or ``None`` if the input is empty / missing.
Raises:
ValueError: If the filter string is present but malformed or uses
an unsupported operator.
"""
if not filter_string or not filter_string.strip():
return None
match = _FILTER_RE.match(filter_string.strip())
if not match:
raise ValueError(f"Unsupported or malformed SCIM filter: {filter_string}")
return _build_filter(match, filter_string)
def _build_filter(match: re.Match[str], raw: str) -> ScimFilter:
"""Extract fields from a regex match and construct a ScimFilter."""
attribute = match.group(1)
op_str = match.group(2).lower()
# Value is in group 3 (double-quoted) or group 4 (single-quoted)
value = match.group(3) if match.group(3) is not None else match.group(4)
if value is None:
raise ValueError(f"Unsupported or malformed SCIM filter: {raw}")
operator = ScimFilterOperator(op_str)
return ScimFilter(attribute=attribute, operator=operator, value=value)
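A hedged usage sketch for the parser above; the import path is an assumption since the file name is not visible in this view, and the behavior shown follows directly from `parse_scim_filter` as defined above.
```python
# Assumed import path (the file name is not shown in this compare view).
from ee.onyx.server.scim.filters import parse_scim_filter

parsed = parse_scim_filter('userName eq "john@example.com"')
assert parsed is not None
print(parsed.attribute, parsed.operator.value, parsed.value)
# -> userName eq john@example.com

assert parse_scim_filter(None) is None   # missing filter
assert parse_scim_filter("   ") is None  # whitespace-only filter

try:
    parse_scim_filter('userName gt 5')   # unsupported operator
except ValueError as err:
    print(err)  # malformed/unsupported filters raise ValueError
```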

View File

@@ -0,0 +1,255 @@
"""Pydantic schemas for SCIM 2.0 provisioning (RFC 7643 / RFC 7644).
SCIM protocol schemas follow the wire format defined in:
- Core Schema: https://datatracker.ietf.org/doc/html/rfc7643
- Protocol: https://datatracker.ietf.org/doc/html/rfc7644
Admin API schemas are internal to Onyx and used for SCIM token management.
"""
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
# ---------------------------------------------------------------------------
# SCIM Schema URIs (RFC 7643 §8)
# Every SCIM JSON payload includes a "schemas" array identifying its type.
# IdPs like Okta/Azure AD use these URIs to determine how to parse responses.
# ---------------------------------------------------------------------------
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse"
SCIM_PATCH_OP_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error"
SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
# ---------------------------------------------------------------------------
# SCIM Protocol Schemas
# ---------------------------------------------------------------------------
class ScimName(BaseModel):
"""User name components (RFC 7643 §4.1.1)."""
givenName: str | None = None
familyName: str | None = None
formatted: str | None = None
class ScimEmail(BaseModel):
"""Email sub-attribute (RFC 7643 §4.1.2)."""
value: str
type: str | None = None
primary: bool = False
class ScimMeta(BaseModel):
"""Resource metadata (RFC 7643 §3.1)."""
resourceType: str | None = None
created: datetime | None = None
lastModified: datetime | None = None
location: str | None = None
class ScimUserResource(BaseModel):
"""SCIM User resource representation (RFC 7643 §4.1).
This is the JSON shape that IdPs send when creating/updating a user via
SCIM, and the shape we return in GET responses. Field names use camelCase
to match the SCIM wire format (not Python convention).
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
id: str | None = None # Onyx's internal user ID, set on responses
externalId: str | None = None # IdP's identifier for this user
userName: str # Typically the user's email address
name: ScimName | None = None
emails: list[ScimEmail] = Field(default_factory=list)
active: bool = True
meta: ScimMeta | None = None
class ScimGroupMember(BaseModel):
"""Group member reference (RFC 7643 §4.2).
Represents a user within a SCIM group. The IdP sends these when adding
or removing users from groups. ``value`` is the Onyx user ID.
"""
value: str # User ID of the group member
display: str | None = None
class ScimGroupResource(BaseModel):
"""SCIM Group resource representation (RFC 7643 §4.2)."""
schemas: list[str] = Field(default_factory=lambda: [SCIM_GROUP_SCHEMA])
id: str | None = None
externalId: str | None = None
displayName: str
members: list[ScimGroupMember] = Field(default_factory=list)
meta: ScimMeta | None = None
class ScimListResponse(BaseModel):
"""Paginated list response (RFC 7644 §3.4.2)."""
schemas: list[str] = Field(default_factory=lambda: [SCIM_LIST_RESPONSE_SCHEMA])
totalResults: int
startIndex: int = 1
itemsPerPage: int = 100
Resources: list[ScimUserResource | ScimGroupResource] = Field(default_factory=list)
class ScimPatchOperationType(str, Enum):
"""Supported PATCH operations (RFC 7644 §3.5.2)."""
ADD = "add"
REPLACE = "replace"
REMOVE = "remove"
class ScimPatchOperation(BaseModel):
"""Single PATCH operation (RFC 7644 §3.5.2)."""
op: ScimPatchOperationType
path: str | None = None
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
class ScimPatchRequest(BaseModel):
"""PATCH request body (RFC 7644 §3.5.2).
IdPs use PATCH to make incremental changes — e.g. deactivating a user
(replace active=false) or adding/removing group members — instead of
replacing the entire resource with PUT.
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_PATCH_OP_SCHEMA])
Operations: list[ScimPatchOperation]
class ScimError(BaseModel):
"""SCIM error response (RFC 7644 §3.12)."""
schemas: list[str] = Field(default_factory=lambda: [SCIM_ERROR_SCHEMA])
status: str
detail: str | None = None
scimType: str | None = None
# ---------------------------------------------------------------------------
# Service Provider Configuration (RFC 7643 §5)
# ---------------------------------------------------------------------------
class ScimSupported(BaseModel):
"""Generic supported/not-supported flag used in ServiceProviderConfig."""
supported: bool
class ScimFilterConfig(BaseModel):
"""Filter configuration within ServiceProviderConfig (RFC 7643 §5)."""
supported: bool
maxResults: int = 100
class ScimServiceProviderConfig(BaseModel):
"""SCIM ServiceProviderConfig resource (RFC 7643 §5).
Served at GET /scim/v2/ServiceProviderConfig. IdPs fetch this during
initial setup to discover which SCIM features our server supports
(e.g. PATCH yes, bulk no, filtering yes).
"""
schemas: list[str] = Field(
default_factory=lambda: [SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA]
)
patch: ScimSupported = ScimSupported(supported=True)
bulk: ScimSupported = ScimSupported(supported=False)
filter: ScimFilterConfig = ScimFilterConfig(supported=True)
changePassword: ScimSupported = ScimSupported(supported=False)
sort: ScimSupported = ScimSupported(supported=False)
etag: ScimSupported = ScimSupported(supported=False)
authenticationSchemes: list[dict[str, str]] = Field(
default_factory=lambda: [
{
"type": "oauthbearertoken",
"name": "OAuth Bearer Token",
"description": "Authentication scheme using a SCIM bearer token",
}
]
)
class ScimSchemaExtension(BaseModel):
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
model_config = ConfigDict(populate_by_name=True)
schema_: str = Field(alias="schema")
required: bool
class ScimResourceType(BaseModel):
"""SCIM ResourceType resource (RFC 7643 §6).
Served at GET /scim/v2/ResourceTypes. Tells the IdP which resource
types are available (Users, Groups) and their respective endpoints.
"""
model_config = ConfigDict(populate_by_name=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
id: str
name: str
endpoint: str
description: str | None = None
schema_: str = Field(alias="schema")
schemaExtensions: list[ScimSchemaExtension] = Field(default_factory=list)
# ---------------------------------------------------------------------------
# Admin API Schemas (Onyx-internal, for SCIM token management)
# These are NOT part of the SCIM protocol. They power the Onyx admin UI
# where admins create/revoke the bearer tokens that IdPs use to authenticate.
# ---------------------------------------------------------------------------
class ScimTokenCreate(BaseModel):
"""Request to create a new SCIM bearer token."""
name: str
class ScimTokenResponse(BaseModel):
"""SCIM token metadata returned in list/get responses."""
id: int
name: str
token_display: str
is_active: bool
created_at: datetime
last_used_at: datetime | None = None
class ScimTokenCreatedResponse(ScimTokenResponse):
"""Response returned when a new SCIM token is created.
Includes the raw token value which is only available at creation time.
"""
raw_token: str
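A brief serialization sketch for the resource models above, showing the camelCase wire format an IdP would receive; the field values are illustrative, and the import path matches the one used by the PATCH handler later in this compare.
```python
# Values are illustrative; the JSON shape follows from the models above.
from ee.onyx.server.scim.models import ScimEmail, ScimName, ScimUserResource

user = ScimUserResource(
    userName="john@example.com",
    name=ScimName(givenName="John", familyName="Doe"),
    emails=[ScimEmail(value="john@example.com", primary=True)],
)
print(user.model_dump_json(exclude_none=True))
# {"schemas":["urn:ietf:params:scim:schemas:core:2.0:User"],
#  "userName":"john@example.com","name":{"givenName":"John","familyName":"Doe"},
#  "emails":[{"value":"john@example.com","primary":true}],"active":true}
```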

View File

@@ -0,0 +1,256 @@
"""SCIM PATCH operation handler (RFC 7644 §3.5.2).
Identity providers use PATCH to make incremental changes to SCIM resources
instead of replacing the entire resource with PUT. Common operations include:
- Deactivating a user: ``replace`` ``active`` with ``false``
- Adding group members: ``add`` to ``members``
- Removing group members: ``remove`` from ``members[value eq "..."]``
This module applies PATCH operations to Pydantic SCIM resource objects and
returns the modified result. It does NOT touch the database — the caller is
responsible for persisting changes.
"""
from __future__ import annotations
import re
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimUserResource
class ScimPatchError(Exception):
"""Raised when a PATCH operation cannot be applied."""
def __init__(self, detail: str, status: int = 400) -> None:
self.detail = detail
self.status = status
super().__init__(detail)
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
def apply_user_patch(
operations: list[ScimPatchOperation],
current: ScimUserResource,
) -> ScimUserResource:
"""Apply SCIM PATCH operations to a user resource.
Returns a new ``ScimUserResource`` with the modifications applied.
The original object is not mutated.
Raises:
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
name_data = data.get("name") or {}
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_user_replace(op, data, name_data)
elif op.op == ScimPatchOperationType.ADD:
_apply_user_replace(op, data, name_data)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
)
data["name"] = name_data
return ScimUserResource.model_validate(data)
def _apply_user_replace(
op: ScimPatchOperation,
data: dict,
name_data: dict,
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a dict of top-level attributes to set
if isinstance(op.value, dict):
for key, val in op.value.items():
_set_user_field(key.lower(), val, data, name_data)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, data, name_data)
def _set_user_field(
path: str,
value: str | bool | dict | list | None,
data: dict,
name_data: dict,
) -> None:
"""Set a single field on user data by SCIM path."""
if path == "active":
data["active"] = value
elif path == "username":
data["userName"] = value
elif path == "externalid":
data["externalId"] = value
elif path == "name.givenname":
name_data["givenName"] = value
elif path == "name.familyname":
name_data["familyName"] = value
elif path == "name.formatted":
name_data["formatted"] = value
elif path == "displayname":
# Some IdPs send displayName on users; map to formatted name
name_data["formatted"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
) -> tuple[ScimGroupResource, list[str], list[str]]:
"""Apply SCIM PATCH operations to a group resource.
Returns:
A tuple of (modified group, added member IDs, removed member IDs).
The caller uses the member ID lists to update the database.
Raises:
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
current_members: list[dict] = list(data.get("members") or [])
added_ids: list[str] = []
removed_ids: list[str] = []
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
elif op.op == ScimPatchOperationType.ADD:
_apply_group_add(op, current_members, added_ids)
elif op.op == ScimPatchOperationType.REMOVE:
_apply_group_remove(op, current_members, removed_ids)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on Group resource"
)
data["members"] = current_members
group = ScimGroupResource.model_validate(data)
return group, added_ids, removed_ids
def _apply_group_replace(
op: ScimPatchOperation,
data: dict,
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Apply a replace operation to group data."""
path = (op.path or "").lower()
if not path:
if isinstance(op.value, dict):
for key, val in op.value.items():
if key.lower() == "members":
_replace_members(val, current_members, added_ids, removed_ids)
else:
_set_group_field(key.lower(), val, data)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
if path == "members":
_replace_members(op.value, current_members, added_ids, removed_ids)
return
_set_group_field(path, op.value, data)
def _replace_members(
value: str | list | dict | bool | None,
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Replace the entire group member list."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
old_ids = {m["value"] for m in current_members}
new_ids = {m.get("value", "") for m in value}
removed_ids.extend(old_ids - new_ids)
added_ids.extend(new_ids - old_ids)
current_members[:] = value
def _set_group_field(
path: str,
value: str | bool | dict | list | None,
data: dict,
) -> None:
"""Set a single field on group data by SCIM path."""
if path == "displayname":
data["displayName"] = value
elif path == "externalid":
data["externalId"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
def _apply_group_add(
op: ScimPatchOperation,
members: list[dict],
added_ids: list[str],
) -> None:
"""Add members to a group."""
path = (op.path or "").lower()
if path and path != "members":
raise ScimPatchError(f"Unsupported add path '{op.path}' for Group")
if not isinstance(op.value, list):
raise ScimPatchError("Add members requires a list value")
existing_ids = {m["value"] for m in members}
for member_data in op.value:
member_id = member_data.get("value", "")
if member_id and member_id not in existing_ids:
members.append(member_data)
added_ids.append(member_id)
existing_ids.add(member_id)
def _apply_group_remove(
op: ScimPatchOperation,
members: list[dict],
removed_ids: list[str],
) -> None:
"""Remove members from a group."""
if not op.path:
raise ScimPatchError("Remove operation requires a path")
match = _MEMBER_FILTER_RE.match(op.path)
if not match:
raise ScimPatchError(
f"Unsupported remove path '{op.path}'. "
'Expected: members[value eq "user-id"]'
)
target_id = match.group(1)
original_len = len(members)
members[:] = [m for m in members if m.get("value") != target_id]
if len(members) < original_len:
removed_ids.append(target_id)
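A usage sketch for the handler above, showing the common "deactivate user" PATCH an IdP sends; the import path of `apply_user_patch` is an assumption since the file name is not visible here, while the models path matches the module's own imports.
```python
from ee.onyx.server.scim.models import (
    ScimPatchOperation,
    ScimPatchOperationType,
    ScimUserResource,
)
from ee.onyx.server.scim.patch_ops import apply_user_patch  # assumed path

ops = [
    ScimPatchOperation(
        op=ScimPatchOperationType.REPLACE,
        path="active",
        value=False,
    )
]
current = ScimUserResource(id="123", userName="john@example.com")
updated = apply_user_patch(ops, current)
assert updated.active is False  # modified copy is returned...
assert current.active is True   # ...and the original is not mutated
# The caller is responsible for persisting `updated` to the database.
```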

View File

@@ -4,6 +4,7 @@ from redis.exceptions import RedisError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -89,7 +90,11 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
settings.ee_features_enabled = True
else:
# No license = community edition, disable EE features
# No license found.
if ENTERPRISE_EDITION_ENABLED:
# Legacy EE flag is set → prior EE usage (e.g. permission
# syncing) means indexed data may need protection.
settings.application_status = _BLOCKING_STATUS
settings.ee_features_enabled = False
except RedisError as e:
logger.warning(f"Failed to check license metadata for settings: {e}")

View File

@@ -177,7 +177,7 @@ async def forward_to_control_plane(
url = f"{CONTROL_PLANE_API_BASE_URL}{path}"
try:
async with httpx.AsyncClient(timeout=30.0) as client:
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
if method == "GET":
response = await client.get(url, headers=headers, params=params)
elif method == "POST":

View File

@@ -12,12 +12,14 @@ from ee.onyx.db.user_group import prepare_user_group_for_deletion
from ee.onyx.db.user_group import update_user_curator_relationship
from ee.onyx.db.user_group import update_user_group
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroup
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
@@ -45,6 +47,23 @@ def list_user_groups(
return [UserGroup.from_model(user_group) for user_group in user_groups]
@router.get("/user-groups/minimal")
def list_minimal_user_groups(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[MinimalUserGroupSnapshot]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
)
return [
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
]
@router.post("/admin/user-group")
def create_user_group(
user_group: UserGroupCreate,

View File

@@ -76,6 +76,18 @@ class UserGroup(BaseModel):
)
class MinimalUserGroupSnapshot(BaseModel):
id: int
name: str
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
return cls(
id=user_group_model.id,
name=user_group_model.name,
)
class UserGroupCreate(BaseModel):
name: str
user_ids: list[UUID]

View File

@@ -1,7 +1,9 @@
import uuid
from enum import Enum
from typing import Any
from fastapi_users import schemas
from typing_extensions import override
class UserRole(str, Enum):
@@ -41,8 +43,21 @@ class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
tenant_id: str | None = None
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
# Excluded from create_update_dict so it never reaches the DB layer
captcha_token: str | None = None
@override
def create_update_dict(self) -> dict[str, Any]:
d = super().create_update_dict()
d.pop("captcha_token", None)
return d
@override
def create_update_dict_superuser(self) -> dict[str, Any]:
d = super().create_update_dict_superuser()
d.pop("captcha_token", None)
return d
class UserUpdateWithRole(schemas.BaseUserUpdate):
role: UserRole

View File

@@ -60,6 +60,7 @@ from sqlalchemy import nulls_last
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.disposable_email_validator import is_disposable_email
@@ -110,6 +111,7 @@ from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
@@ -272,6 +274,22 @@ def verify_email_domain(email: str) -> None:
)
def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
"""Raise HTTPException(402) if adding users would exceed the seat limit.
No-op for multi-tenant or CE deployments.
"""
if MULTI_TENANT:
return
result = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)(db_session, seats_needed=seats_needed)
if result is not None and not result.available:
raise HTTPException(status_code=402, detail=result.error_message)
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
@@ -401,6 +419,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
):
user_create.role = UserRole.ADMIN
# Check seat availability for new users (single-tenant only)
with get_session_with_current_tenant() as sync_db:
existing = get_user_by_email(user_create.email, sync_db)
if existing is None:
enforce_seat_limit(sync_db)
user_created = False
try:
user = await super().create(user_create, safe=safe, request=request)
@@ -610,6 +634,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
raise exceptions.UserNotExists()
except exceptions.UserNotExists:
# Check seat availability before creating (single-tenant only)
with get_session_with_current_tenant() as sync_db:
enforce_seat_limit(sync_db)
password = self.password_helper.generate()
user_dict = {
"email": account_email,
@@ -1431,6 +1459,7 @@ def get_anonymous_user() -> User:
is_superuser=False,
role=UserRole.LIMITED,
use_memories=False,
enable_memory_tool=False,
)
return user

View File

@@ -26,6 +26,7 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
@@ -525,6 +526,12 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None: # noqa: ARG
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""
if DISABLE_VECTOR_DB:
logger.info(
"DISABLE_VECTOR_DB is set — skipping Vespa/OpenSearch readiness check."
)
return
if not wait_for_vespa_with_timeout():
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
@@ -566,3 +573,31 @@ class LivenessProbe(bootsteps.StartStopStep):
def get_bootsteps() -> list[type]:
return [LivenessProbe]
# Task modules that require a vector DB (Vespa/OpenSearch).
# When DISABLE_VECTOR_DB is True these are excluded from autodiscover lists.
_VECTOR_DB_TASK_MODULES: set[str] = {
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.docfetching",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.opensearch_migration",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.hierarchyfetching",
# EE modules that are vector-DB-dependent
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
}
# NOTE: "onyx.background.celery.tasks.shared" is intentionally NOT in the set
# above. It contains celery_beat_heartbeat (which only writes to Redis) alongside
# document cleanup tasks. The cleanup tasks won't be invoked in minimal mode
# because the periodic tasks that trigger them are in other filtered modules.
def filter_task_modules(modules: list[str]) -> list[str]:
"""Remove vector-DB-dependent task modules when DISABLE_VECTOR_DB is True."""
if not DISABLE_VECTOR_DB:
return modules
return [m for m in modules if m not in _VECTOR_DB_TASK_MODULES]
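A quick, self-contained illustration of the filter's effect. The module names are taken from the set above, but the DISABLE_VECTOR_DB toggle is passed as an argument here instead of being read from app configs.

_VECTOR_DB_MODULES_SKETCH = {
    "onyx.background.celery.tasks.docprocessing",
    "onyx.background.celery.tasks.vespa",
}


def filter_task_modules_sketch(modules: list[str], disable_vector_db: bool) -> list[str]:
    if not disable_vector_db:
        return modules
    return [m for m in modules if m not in _VECTOR_DB_MODULES_SKETCH]


modules = [
    "onyx.background.celery.tasks.monitoring",
    "onyx.background.celery.tasks.vespa",
]
# With the vector DB disabled, only vector-DB-independent modules remain.
assert filter_task_modules_sketch(modules, disable_vector_db=True) == [
    "onyx.background.celery.tasks.monitoring"
]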

View File

@@ -118,23 +118,25 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.opensearch_migration",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
# Docprocessing worker tasks
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
app_base.filter_task_modules(
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.opensearch_migration",
# Docprocessing worker tasks
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)
)

View File

@@ -96,7 +96,9 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.docfetching",
]
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.docfetching",
]
)
)

View File

@@ -107,7 +107,9 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.docprocessing",
]
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.docprocessing",
]
)
)

View File

@@ -96,10 +96,12 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.pruning",
# Sandbox tasks (file sync, cleanup)
"onyx.server.features.build.sandbox.tasks",
"onyx.background.celery.tasks.hierarchyfetching",
]
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.pruning",
# Sandbox tasks (file sync, cleanup)
"onyx.server.features.build.sandbox.tasks",
"onyx.background.celery.tasks.hierarchyfetching",
]
)
)

View File

@@ -110,13 +110,16 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.docprocessing",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.opensearch_migration",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)
)

View File

@@ -94,7 +94,9 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.monitoring",
]
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.monitoring",
]
)
)

View File

@@ -314,17 +314,18 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.opensearch_migration",
]
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.user_file_processing",
]
)
)

View File

@@ -107,7 +107,9 @@ for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.user_file_processing",
]
app_base.filter_task_modules(
[
"onyx.background.celery.tasks.user_file_processing",
]
)
)

View File

@@ -5,17 +5,19 @@ from datetime import timezone
from pathlib import Path
from typing import Any
from typing import cast
from typing import TypeVar
import httpx
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.connector_runner import batched_doc_ids
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
@@ -31,6 +33,54 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
CT = TypeVar("CT", bound=ConnectorCheckpoint)
def _checkpointed_batched_doc_ids(
connector: CheckpointedConnector[CT],
start: float,
end: float,
batch_size: int,
) -> Generator[set[str], None, None]:
"""Loop through all checkpoint steps and yield batched document IDs.
Some checkpointed connectors (e.g. IMAP) are multi-step: the first
checkpoint call may only initialize internal state without yielding
any documents. This function loops until checkpoint.has_more is False
to ensure all document IDs are collected across every step.
"""
checkpoint = connector.build_dummy_checkpoint()
while True:
checkpoint_output = connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
wrapper: CheckpointOutputWrapper[CT] = CheckpointOutputWrapper()
batch: set[str] = set()
for document, _hierarchy_node, failure, next_checkpoint in wrapper(
checkpoint_output
):
if document is not None:
batch.add(document.id)
elif (
failure
and failure.failed_document
and failure.failed_document.document_id
):
batch.add(failure.failed_document.document_id)
if next_checkpoint is not None:
checkpoint = next_checkpoint
if len(batch) >= batch_size:
yield batch
batch = set()
if batch:
yield batch
if not checkpoint.has_more:
break
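To make the reason for looping until has_more is False concrete, here is a toy multi-step source whose first checkpoint step yields nothing; both classes are invented stand-ins, not the real connector interfaces.

from dataclasses import dataclass


@dataclass
class ToyCheckpoint:
    step: int = 0
    has_more: bool = True


class ToySource:
    def load_from_checkpoint(self, checkpoint: ToyCheckpoint) -> tuple[list[str], ToyCheckpoint]:
        # Step 0 only advances internal state; documents arrive on later steps.
        if checkpoint.step == 0:
            return [], ToyCheckpoint(step=1, has_more=True)
        return ["doc-a", "doc-b", "doc-c"], ToyCheckpoint(step=2, has_more=False)


def collect_all_doc_ids(source: ToySource) -> set[str]:
    checkpoint = ToyCheckpoint()
    seen: set[str] = set()
    while True:
        docs, checkpoint = source.load_from_checkpoint(checkpoint)
        seen.update(docs)
        if not checkpoint.has_more:
            break
    return seen


# Stopping after the first call would return an empty set; looping collects all ids.
assert collect_all_doc_ids(ToySource()) == {"doc-a", "doc-b", "doc-c"}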
def document_batch_to_ids(
doc_batch: (
@@ -80,12 +130,8 @@ def extract_ids_from_runnable_connector(
elif isinstance(runnable_connector, CheckpointedConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
checkpoint = runnable_connector.build_dummy_checkpoint()
checkpoint_generator = runnable_connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
doc_batch_id_generator = batched_doc_ids(
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
doc_batch_id_generator = _checkpointed_batched_doc_ids(
runnable_connector, start, end, PRUNING_CHECKPOINTED_BATCH_SIZE
)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")

View File

@@ -6,6 +6,7 @@ from celery.schedules import crontab
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
@@ -215,36 +216,39 @@ if SCHEDULED_EVAL_DATASET_NAMES:
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
beat_task_templates.append(
{
"name": "check-for-documents-for-opensearch-migration",
"task": OnyxCeleryTask.CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK,
"name": "migrate-chunks-from-vespa-to-opensearch",
"task": OnyxCeleryTask.MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK,
# Try to enqueue an invocation of this task with this frequency.
"schedule": timedelta(seconds=120), # 2 minutes
"options": {
"priority": OnyxCeleryPriority.LOW,
# If the task was not dequeued in this time, revoke it.
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.OPENSEARCH_MIGRATION,
},
}
)
beat_task_templates.append(
{
"name": "migrate-documents-from-vespa-to-opensearch",
"task": OnyxCeleryTask.MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK,
# Try to enqueue an invocation of this task with this frequency.
# NOTE: If MIGRATION_TASK_SOFT_TIME_LIMIT_S is greater than this
# value and the task is maximally busy, we can expect to see some
# enqueued tasks be revoked over time. This is ok; by erring on the
# side of "there will probably always be at least one task of this
# type in the queue", we are minimizing this task's idleness while
# still giving chances for other tasks to execute.
"schedule": timedelta(seconds=120), # 2 minutes
"options": {
"priority": OnyxCeleryPriority.LOW,
# If the task was not dequeued in this time, revoke it.
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
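For readers unfamiliar with Celery beat, a generic sketch of the schedule/expires pairing described in the comments above; the app name, broker URL, and task name are assumptions, not project values.

from datetime import timedelta

from celery import Celery

app = Celery("sketch", broker="redis://localhost:6379/0")

# Enqueue roughly every 2 minutes; if a queued invocation is not picked up
# before it expires, Celery revokes it instead of letting stale copies pile up.
app.conf.beat_schedule = {
    "migrate-chunks-sketch": {
        "task": "sketch.migrate_chunks",
        "schedule": timedelta(seconds=120),
        "options": {"expires": 600},
    },
}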
# Beat task names that require a vector DB. Filtered out when DISABLE_VECTOR_DB.
_VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
"check-for-indexing",
"check-for-connector-deletion",
"check-for-vespa-sync",
"check-for-pruning",
"check-for-hierarchy-fetching",
"check-for-checkpoint-cleanup",
"check-for-index-attempt-cleanup",
"check-for-doc-permissions-sync",
"check-for-external-group-sync",
"check-for-documents-for-opensearch-migration",
"migrate-documents-from-vespa-to-opensearch",
}
if DISABLE_VECTOR_DB:
beat_task_templates = [
t for t in beat_task_templates if t["name"] not in _VECTOR_DB_BEAT_TASK_NAMES
]
def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:

View File

@@ -37,6 +37,7 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
timeout_seconds: int | None = None,
):
super().__init__()
self.parent_pid = parent_pid
@@ -51,11 +52,29 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
self.last_lock_monotonic = time.monotonic()
self.last_parent_check = time.monotonic()
self.start_monotonic = time.monotonic()
self.timeout_seconds = timeout_seconds
def should_stop(self) -> bool:
# Check if the associated indexing attempt has been cancelled
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
return bool(self.redis_connector.stop.fenced)
if bool(self.redis_connector.stop.fenced):
return True
# Check if the task has exceeded its timeout
# NOTE: Celery's soft_time_limit does not work with thread pools,
# so we must enforce timeouts internally.
if self.timeout_seconds is not None:
elapsed = time.monotonic() - self.start_monotonic
if elapsed > self.timeout_seconds:
logger.warning(
f"IndexingCallback Docprocessing - task timeout exceeded: "
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
f"cc_pair={self.redis_connector.cc_pair_id}"
)
return True
return False
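Since Celery's soft_time_limit is not honored by thread-pool based workers, the timeout has to be polled inside the work loop itself. A stripped-down sketch of that pattern with invented names and no Redis involved:

import time


class TimeoutStopper:
    def __init__(self, timeout_seconds: float | None = None) -> None:
        self.timeout_seconds = timeout_seconds
        self.start_monotonic = time.monotonic()

    def should_stop(self) -> bool:
        if self.timeout_seconds is None:
            return False
        return (time.monotonic() - self.start_monotonic) > self.timeout_seconds


def long_running_job(stopper: TimeoutStopper) -> None:
    # The loop cooperatively checks the deadline instead of relying on Celery
    # to interrupt it.
    while not stopper.should_stop():
        time.sleep(0.1)  # stand-in for one unit of pruning/indexing work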
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002
"""Amount isn't used yet."""

View File

@@ -0,0 +1,10 @@
"""Celery tasks for hierarchy fetching."""
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
check_for_hierarchy_fetching,
)
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
connector_hierarchy_fetching_task,
)
__all__ = ["check_for_hierarchy_fetching", "connector_hierarchy_fetching_task"]

View File

@@ -146,14 +146,26 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
"""Collect metrics about queue lengths for different Celery queues"""
metrics = []
queue_mappings = {
"celery_queue_length": "celery",
"docprocessing_queue_length": "docprocessing",
"sync_queue_length": "sync",
"deletion_queue_length": "deletion",
"pruning_queue_length": "pruning",
"celery_queue_length": OnyxCeleryQueues.PRIMARY,
"docprocessing_queue_length": OnyxCeleryQueues.DOCPROCESSING,
"docfetching_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
"sync_queue_length": OnyxCeleryQueues.VESPA_METADATA_SYNC,
"deletion_queue_length": OnyxCeleryQueues.CONNECTOR_DELETION,
"pruning_queue_length": OnyxCeleryQueues.CONNECTOR_PRUNING,
"permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
"external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
"permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
"hierarchy_fetching_queue_length": OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
"llm_model_update_queue_length": OnyxCeleryQueues.LLM_MODEL_UPDATE,
"checkpoint_cleanup_queue_length": OnyxCeleryQueues.CHECKPOINT_CLEANUP,
"index_attempt_cleanup_queue_length": OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP,
"csv_generation_queue_length": OnyxCeleryQueues.CSV_GENERATION,
"user_file_processing_queue_length": OnyxCeleryQueues.USER_FILE_PROCESSING,
"user_file_project_sync_queue_length": OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
"user_file_delete_queue_length": OnyxCeleryQueues.USER_FILE_DELETE,
"monitoring_queue_length": OnyxCeleryQueues.MONITORING,
"sandbox_queue_length": OnyxCeleryQueues.SANDBOX,
"opensearch_migration_queue_length": OnyxCeleryQueues.OPENSEARCH_MIGRATION,
}
for name, queue in queue_mappings.items():
@@ -881,7 +893,7 @@ def monitor_celery_queues_helper(
"""A task to monitor all celery queue lengths."""
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
@@ -908,6 +920,26 @@ def monitor_celery_queues_helper(
n_permissions_upsert = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
n_hierarchy_fetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING, r_celery
)
n_llm_model_update = celery_get_queue_length(
OnyxCeleryQueues.LLM_MODEL_UPDATE, r_celery
)
n_checkpoint_cleanup = celery_get_queue_length(
OnyxCeleryQueues.CHECKPOINT_CLEANUP, r_celery
)
n_index_attempt_cleanup = celery_get_queue_length(
OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP, r_celery
)
n_csv_generation = celery_get_queue_length(
OnyxCeleryQueues.CSV_GENERATION, r_celery
)
n_monitoring = celery_get_queue_length(OnyxCeleryQueues.MONITORING, r_celery)
n_sandbox = celery_get_queue_length(OnyxCeleryQueues.SANDBOX, r_celery)
n_opensearch_migration = celery_get_queue_length(
OnyxCeleryQueues.OPENSEARCH_MIGRATION, r_celery
)
n_docfetching_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
@@ -931,6 +963,14 @@ def monitor_celery_queues_helper(
f"permissions_sync={n_permissions_sync} "
f"external_group_sync={n_external_group_sync} "
f"permissions_upsert={n_permissions_upsert} "
f"hierarchy_fetching={n_hierarchy_fetching} "
f"llm_model_update={n_llm_model_update} "
f"checkpoint_cleanup={n_checkpoint_cleanup} "
f"index_attempt_cleanup={n_index_attempt_cleanup} "
f"csv_generation={n_csv_generation} "
f"monitoring={n_monitoring} "
f"sandbox={n_sandbox} "
f"opensearch_migration={n_opensearch_migration} "
)
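As background, with the default Redis broker a Celery queue is roughly a Redis list keyed by the queue name, so waiting-task depth can usually be read with LLEN; prefetched-but-unacknowledged tasks live elsewhere, which is why unacked task ids are inspected separately above. The helper below is an illustrative stand-in, not the project's celery_get_queue_length.

import redis


def queue_length_sketch(client: redis.Redis, queue_name: str) -> int:
    # Length of the broker list == number of messages waiting to be consumed.
    return int(client.llen(queue_name))


# Usage sketch (assumes a local Redis broker):
# client = redis.Redis(host="localhost", port=6379, db=0)
# print(queue_length_sketch(client, "docprocessing"))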

View File

@@ -2,27 +2,12 @@
import time
import traceback
from datetime import datetime
from datetime import timezone
from typing import Any
from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.opensearch_migration.constants import (
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
@@ -42,225 +27,32 @@ from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import OpenSearchDocumentMigrationStatus
from onyx.db.opensearch_migration import create_opensearch_migration_records_with_commit
from onyx.db.opensearch_migration import get_last_opensearch_migration_document_id
from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapping
from onyx.db.opensearch_migration import get_vespa_visit_state
from onyx.db.opensearch_migration import (
get_opensearch_migration_records_needing_migration,
mark_migration_completed_time_if_not_set_with_commit,
)
from onyx.db.opensearch_migration import get_paginated_document_batch
from onyx.db.opensearch_migration import (
increment_num_times_observed_no_additional_docs_to_migrate_with_commit,
)
from onyx.db.opensearch_migration import (
increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit,
)
from onyx.db.opensearch_migration import should_document_migration_be_permanently_failed
from onyx.db.opensearch_migration import (
try_insert_opensearch_tenant_migration_record_with_commit,
)
from onyx.db.opensearch_migration import update_vespa_visit_progress_with_commit
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
def _migrate_single_document(
document_id: str,
opensearch_document_index: OpenSearchDocumentIndex,
vespa_document_index: VespaDocumentIndex,
tenant_state: TenantState,
) -> int:
"""Migrates a single document from Vespa to OpenSearch.
Args:
document_id: The ID of the document to migrate.
opensearch_document_index: The OpenSearch document index to use.
vespa_document_index: The Vespa document index to use.
tenant_state: The tenant state to use.
Raises:
RuntimeError: If no chunks are found for the document in Vespa, or if
the number of candidate chunks to migrate does not match the number
of chunks in Vespa.
Returns:
The number of chunks migrated.
"""
vespa_document_chunks: list[dict[str, Any]] = (
vespa_document_index.get_raw_document_chunks(document_id=document_id)
)
if not vespa_document_chunks:
raise RuntimeError(f"No chunks found for document {document_id} in Vespa.")
opensearch_document_chunks: list[DocumentChunk] = (
transform_vespa_chunks_to_opensearch_chunks(
vespa_document_chunks, tenant_state, document_id
)
)
if len(opensearch_document_chunks) != len(vespa_document_chunks):
raise RuntimeError(
f"Bug: Number of candidate chunks to migrate ({len(opensearch_document_chunks)}) does not match "
f"number of chunks in Vespa ({len(vespa_document_chunks)})."
)
opensearch_document_index.index_raw_chunks(chunks=opensearch_document_chunks)
return len(opensearch_document_chunks)
GET_VESPA_CHUNKS_PAGE_SIZE = 1000
# shared_task allows this task to be shared across celery app instances.
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK,
# Does not store the task's return value in the result backend.
ignore_result=True,
# WARNING: This is here just for rigor but since we use threads for Celery
# this config is not respected and timeout logic must be implemented in the
# task.
soft_time_limit=CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S,
# WARNING: This is here just for rigor but since we use threads for Celery
# this config is not respected and timeout logic must be implemented in the
# task.
time_limit=CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S,
# Passed in self to the task to get task metadata.
bind=True,
)
def check_for_documents_for_opensearch_migration_task(
self: Task, *, tenant_id: str # noqa: ARG001
) -> bool | None:
"""
Periodic task to check for and add documents to the OpenSearch migration
table.
Should not execute meaningful logic at the same time as
migrate_documents_from_vespa_to_opensearch_task.
Effectively tries to populate as many migration records as possible within
CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S seconds. Does so in batches of
1000 documents.
Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
acquired; effectively a no-op. True if the task completed
successfully. False if the task failed.
"""
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
task_logger.warning(
"OpenSearch migration is not enabled, skipping check for documents for the OpenSearch migration task."
)
return None
task_logger.info("Checking for documents for OpenSearch migration.")
task_start_time = time.monotonic()
r = get_redis_client()
# Use a lock to prevent overlapping tasks. Only this task or
# migrate_documents_from_vespa_to_opensearch_task can interact with the
# OpenSearchMigration table at once.
lock: RedisLock = r.lock(
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
# The maximum time the lock can be held for. Will automatically be
# released after this time.
timeout=CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S,
# .acquire will block until the lock is acquired.
blocking=True,
# Time to wait to acquire the lock.
blocking_timeout=CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
if not lock.acquire():
task_logger.warning(
"The OpenSearch migration check task timed out waiting for the lock."
)
return None
else:
task_logger.info(
f"Acquired the OpenSearch migration check lock. Took {time.monotonic() - task_start_time:.3f} seconds. "
f"Token: {lock.local.token}"
)
num_documents_found_for_record_creation = 0
try:
# Double check that tenant info is correct.
if tenant_id != get_current_tenant_id():
err_str = (
f"Tenant ID mismatch in the OpenSearch migration check task: "
f"{tenant_id} != {get_current_tenant_id()}. This should never happen."
)
task_logger.error(err_str)
return False
while (
time.monotonic() - task_start_time
< CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
with get_session_with_current_tenant() as db_session:
# For pagination, get the last ID we've inserted into
# OpenSearchMigration.
last_opensearch_migration_document_id = (
get_last_opensearch_migration_document_id(db_session)
)
# Now get the next batch of doc IDs starting after the last ID.
# We'll do 1000 documents per transaction/timeout check.
document_ids = get_paginated_document_batch(
db_session,
limit=1000,
prev_ending_document_id=last_opensearch_migration_document_id,
)
if not document_ids:
task_logger.info(
"No more documents to insert for OpenSearch migration."
)
increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit(
db_session
)
# TODO(andrei): Once we've done this enough times and the
# number of documents matches the number of migration
# records, we can be done with this task and update
# document_migration_record_table_population_status.
return True
# Create the migration records for the next batch of documents
# with status PENDING.
create_opensearch_migration_records_with_commit(
db_session, document_ids
)
num_documents_found_for_record_creation += len(document_ids)
# Try to create the singleton row in
# OpenSearchTenantMigrationRecord if it doesn't already exist.
# This is a reasonable place to put it because we already have a
# lock, a session, and error handling, at the cost of running
# this small set of logic for every batch.
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
except Exception:
task_logger.exception("Error in the OpenSearch migration check task.")
return False
finally:
if lock.owned():
lock.release()
else:
task_logger.warning(
"The OpenSearch migration lock was not owned on completion of the check task."
)
task_logger.info(
f"Finished checking for documents for OpenSearch migration. Found {num_documents_found_for_record_creation} documents "
f"to create migration records for in {time.monotonic() - task_start_time:.3f} seconds. However, this may include "
"documents for which there already exist records."
)
return True
# shared_task allows this task to be shared across celery app instances.
@shared_task(
name=OnyxCeleryTask.MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK,
name=OnyxCeleryTask.MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK,
# Does not store the task's return value in the result backend.
ignore_result=True,
# WARNING: This is here just for rigor but since we use threads for Celery
@@ -274,18 +66,21 @@ def check_for_documents_for_opensearch_migration_task(
# Passed in self to the task to get task metadata.
bind=True,
)
def migrate_documents_from_vespa_to_opensearch_task(
def migrate_chunks_from_vespa_to_opensearch_task(
self: Task, # noqa: ARG001
*,
tenant_id: str,
) -> bool | None:
"""Periodic task to migrate documents from Vespa to OpenSearch.
"""
Periodic task to migrate chunks from Vespa to OpenSearch via the Visit API.
Should not execute meaningful logic at the same time as
check_for_documents_for_opensearch_migration_task.
Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
per-document), transform them, and index them into OpenSearch. Progress is
tracked via a continuation token stored in the
OpenSearchTenantMigrationRecord.
Effectively tries to migrate as many documents as possible within
MIGRATION_TASK_SOFT_TIME_LIMIT_S seconds. Does so in batches of 5 documents.
The first time we see no continuation token and non-zero chunks migrated, we
consider the migration complete and all subsequent invocations are no-ops.
Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
@@ -294,16 +89,13 @@ def migrate_documents_from_vespa_to_opensearch_task(
"""
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
task_logger.warning(
"OpenSearch migration is not enabled, skipping trying to migrate documents from Vespa to OpenSearch."
"OpenSearch migration is not enabled, skipping chunk migration task."
)
return None
task_logger.info("Trying a migration batch from Vespa to OpenSearch.")
task_logger.info("Starting chunk-level migration from Vespa to OpenSearch.")
task_start_time = time.monotonic()
r = get_redis_client()
# Use a lock to prevent overlapping tasks. Only this task or
# check_for_documents_for_opensearch_migration_task can interact with the
# OpenSearchMigration table at once.
lock: RedisLock = r.lock(
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
# The maximum time the lock can be held for. Will automatically be
@@ -325,9 +117,8 @@ def migrate_documents_from_vespa_to_opensearch_task(
f"Token: {lock.local.token}"
)
num_documents_migrated = 0
num_chunks_migrated = 0
num_documents_failed = 0
total_chunks_migrated_this_task = 0
total_chunks_errored_this_task = 0
try:
# Double check that tenant info is correct.
if tenant_id != get_current_tenant_id():
@@ -337,97 +128,100 @@ def migrate_documents_from_vespa_to_opensearch_task(
)
task_logger.error(err_str)
return False
while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
with get_session_with_current_tenant() as db_session:
# We'll do 5 documents per transaction/timeout check.
records_needing_migration = (
get_opensearch_migration_records_needing_migration(
db_session, limit=5
)
)
if not records_needing_migration:
with get_session_with_current_tenant() as db_session:
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
search_settings = get_current_search_settings(db_session)
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
opensearch_document_index = OpenSearchDocumentIndex(
index_name=search_settings.index_name, tenant_state=tenant_state
)
vespa_document_index = VespaDocumentIndex(
index_name=search_settings.index_name,
tenant_state=tenant_state,
large_chunks_enabled=False,
)
sanitized_doc_start_time = time.monotonic()
# We reconstruct this mapping for every task invocation because a
# document may have been added in the time between two tasks.
sanitized_to_original_doc_id_mapping = (
build_sanitized_to_original_doc_id_mapping(db_session)
)
task_logger.debug(
f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)
while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
(
continuation_token,
total_chunks_migrated,
) = get_vespa_visit_state(db_session)
if continuation_token is None and total_chunks_migrated > 0:
task_logger.info(
"No documents found that need to be migrated from Vespa to OpenSearch."
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
f"Total chunks migrated: {total_chunks_migrated}."
)
increment_num_times_observed_no_additional_docs_to_migrate_with_commit(
db_session
mark_migration_completed_time_if_not_set_with_commit(db_session)
break
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token: {continuation_token}"
)
get_vespa_chunks_start_time = time.monotonic()
raw_vespa_chunks, next_continuation_token = (
vespa_document_index.get_all_raw_document_chunks_paginated(
continuation_token=continuation_token,
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
)
# TODO(andrei): Once we've done this enough times and
# document_migration_record_table_population_status is done, we
# can be done with this task and update
# overall_document_migration_status accordingly. Note that this
# includes marking connectors as needing reindexing if some
# migrations failed.
return True
search_settings = get_current_search_settings(db_session)
tenant_state = TenantState(
tenant_id=tenant_id, multitenant=MULTI_TENANT
)
opensearch_document_index = OpenSearchDocumentIndex(
index_name=search_settings.index_name, tenant_state=tenant_state
)
vespa_document_index = VespaDocumentIndex(
index_name=search_settings.index_name,
tenant_state=tenant_state,
large_chunks_enabled=False,
task_logger.debug(
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
f"seconds. Next continuation token: {next_continuation_token}"
)
for record in records_needing_migration:
try:
# If the Document's chunk count is not known, it was
# probably just indexed so fail here to give it a chance to
# sync. In the rare event that this Document has not been
# re-indexed in a very long time and is still under the
# "old" embedding/indexing logic where chunk count was never
# stored, we will eventually permanently fail and thus force
# a re-index of this doc, which is a desirable outcome.
if record.document.chunk_count is None:
raise RuntimeError(
f"Document {record.document_id} has no chunk count."
)
opensearch_document_chunks, errored_chunks = (
transform_vespa_chunks_to_opensearch_chunks(
raw_vespa_chunks,
tenant_state,
sanitized_to_original_doc_id_mapping,
)
)
if len(opensearch_document_chunks) != len(raw_vespa_chunks):
task_logger.error(
f"Migration task error: Number of candidate chunks to migrate ({len(opensearch_document_chunks)}) does "
f"not match number of chunks in Vespa ({len(raw_vespa_chunks)}). {len(errored_chunks)} chunks "
"errored."
)
chunks_migrated = _migrate_single_document(
document_id=record.document_id,
opensearch_document_index=opensearch_document_index,
vespa_document_index=vespa_document_index,
tenant_state=tenant_state,
)
index_opensearch_chunks_start_time = time.monotonic()
opensearch_document_index.index_raw_chunks(
chunks=opensearch_document_chunks
)
task_logger.debug(
f"Indexed {len(opensearch_document_chunks)} chunks into OpenSearch in "
f"{time.monotonic() - index_opensearch_chunks_start_time:.3f} seconds."
)
# If the number of chunks in Vespa is not in sync with the
# Document table for this doc let's not consider this
# completed and let's let a subsequent run take care of it.
if chunks_migrated != record.document.chunk_count:
raise RuntimeError(
f"Number of chunks migrated ({chunks_migrated}) does not match number of expected chunks "
f"in Vespa ({record.document.chunk_count}) for document {record.document_id}."
)
total_chunks_migrated_this_task += len(opensearch_document_chunks)
total_chunks_errored_this_task += len(errored_chunks)
update_vespa_visit_progress_with_commit(
db_session,
continuation_token=next_continuation_token,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
)
record.status = OpenSearchDocumentMigrationStatus.COMPLETED
num_documents_migrated += 1
num_chunks_migrated += chunks_migrated
except Exception:
record.status = OpenSearchDocumentMigrationStatus.FAILED
record.error_message = f"Attempt {record.attempts_count + 1}:\n{traceback.format_exc()}"
task_logger.exception(
f"Error migrating document {record.document_id} from Vespa to OpenSearch."
)
num_documents_failed += 1
finally:
record.attempts_count += 1
record.last_attempt_at = datetime.now(timezone.utc)
if should_document_migration_be_permanently_failed(record):
record.status = (
OpenSearchDocumentMigrationStatus.PERMANENTLY_FAILED
)
# TODO(andrei): Not necessarily here but if this happens
# we'll need to mark the connector as needing reindex.
db_session.commit()
if next_continuation_token is None and len(raw_vespa_chunks) == 0:
task_logger.info("Vespa reported no more chunks to migrate.")
break
except Exception:
traceback.print_exc()
task_logger.exception("Error in the OpenSearch migration task.")
return False
finally:
@@ -439,9 +233,11 @@ def migrate_documents_from_vespa_to_opensearch_task(
)
task_logger.info(
f"Finished a migration batch from Vespa to OpenSearch. Migrated {num_chunks_migrated} chunks "
f"from {num_documents_migrated} documents in {time.monotonic() - task_start_time:.3f} seconds. "
f"Failed to migrate {num_documents_failed} documents."
f"OpenSearch chunk migration task pausing (time limit reached). "
f"Total chunks migrated this task: {total_chunks_migrated_this_task}. "
f"Total chunks errored this task: {total_chunks_errored_this_task}. "
f"Elapsed: {time.monotonic() - task_start_time:.3f}s. "
"Will resume from continuation token on next invocation."
)
return True
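At its core the new task is a resumable continuation-token loop with a time budget. A generic sketch of that pattern follows; every name is invented, and the real task additionally holds a Redis lock and persists the token in OpenSearchTenantMigrationRecord.

import time
from collections.abc import Callable

Page = tuple[list[dict], str | None]  # (items, next_continuation_token)


def run_migration_slice(
    fetch_page: Callable[[str | None], Page],
    load_state: Callable[[], tuple[str | None, int]],
    save_state: Callable[[str | None, int], None],
    time_budget_s: float,
) -> None:
    start = time.monotonic()
    while time.monotonic() - start < time_budget_s:
        token, total_done = load_state()
        if token is None and total_done > 0:
            return  # a previous slice already finished the migration
        items, next_token = fetch_page(token)
        # ... transform and index `items` here ...
        save_state(next_token, total_done + len(items))
        if next_token is None and not items:
            return  # the source reports nothing left to migrate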

View File

@@ -1,3 +1,4 @@
import traceback
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -140,9 +141,7 @@ def _transform_vespa_acl_to_opensearch_acl(
vespa_acl: dict[str, int] | None,
) -> tuple[bool, list[str]]:
if not vespa_acl:
raise ValueError(
"Missing ACL in Vespa chunk. This does not make sense as it implies the document is never searchable by anyone ever."
)
return False, []
acl_list = list(vespa_acl.keys())
is_public = PUBLIC_DOC_PAT in acl_list
if is_public:
@@ -153,133 +152,163 @@ def _transform_vespa_acl_to_opensearch_acl(
def transform_vespa_chunks_to_opensearch_chunks(
vespa_chunks: list[dict[str, Any]],
tenant_state: TenantState,
document_id: str,
) -> list[DocumentChunk]:
sanitized_to_original_doc_id_mapping: dict[str, str],
) -> tuple[list[DocumentChunk], list[dict[str, Any]]]:
result: list[DocumentChunk] = []
errored_chunks: list[dict[str, Any]] = []
for vespa_chunk in vespa_chunks:
# This should exist; fail loudly if it does not.
vespa_document_id: str = vespa_chunk[DOCUMENT_ID]
if not vespa_document_id:
raise ValueError("Missing document_id in Vespa chunk.")
# Vespa doc IDs were sanitized using replace_invalid_doc_id_characters.
# This was a poor design choice and we don't want this in OpenSearch;
# whatever restrictions there may be on indexed chunk ID should have no
# bearing on the chunk's document ID field, even if document ID is an
# argument to the chunk ID. Deliberately choose to use the real doc ID
# supplied to this function.
if vespa_document_id != document_id:
logger.warning(
f"Vespa document ID {vespa_document_id} does not match the document ID supplied {document_id}. "
"The Vespa ID will be discarded."
try:
# This should exist; fail loudly if it does not.
vespa_document_id: str = vespa_chunk[DOCUMENT_ID]
if not vespa_document_id:
raise ValueError("Missing document_id in Vespa chunk.")
# Vespa doc IDs were sanitized using
# replace_invalid_doc_id_characters. This was a poor design choice
# and we don't want this in OpenSearch; whatever restrictions there
# may be on indexed chunk ID should have no bearing on the chunk's
# document ID field, even if document ID is an argument to the chunk
# ID. Deliberately choose to use the real doc ID supplied to this
# function.
if vespa_document_id in sanitized_to_original_doc_id_mapping:
logger.warning(
f"Migration warning: Vespa document ID {vespa_document_id} does not match the document ID supplied "
f"{sanitized_to_original_doc_id_mapping[vespa_document_id]}. "
"The Vespa ID will be discarded."
)
document_id = sanitized_to_original_doc_id_mapping.get(
vespa_document_id, vespa_document_id
)
# This should exist; fail loudly if it does not.
chunk_index: int = vespa_chunk[CHUNK_ID]
# This should exist; fail loudly if it does not.
chunk_index: int = vespa_chunk[CHUNK_ID]
title: str | None = vespa_chunk.get(TITLE)
# WARNING: Should supply format.tensors=short-value to the Vespa client
# in order to get a supported format for the tensors.
title_vector: list[float] | None = _extract_title_vector(
vespa_chunk.get(TITLE_EMBEDDING)
)
# This should exist; fail loudly if it does not.
content: str = vespa_chunk[CONTENT]
if not content:
raise ValueError("Missing content in Vespa chunk.")
# This should exist; fail loudly if it does not.
# WARNING: Should supply format.tensors=short-value to the Vespa client
# in order to get a supported format for the tensors.
content_vector: list[float] = _extract_content_vector(vespa_chunk[EMBEDDINGS])
if not content_vector:
raise ValueError("Missing content_vector in Vespa chunk.")
# This should exist; fail loudly if it does not.
source_type: str = vespa_chunk[SOURCE_TYPE]
if not source_type:
raise ValueError("Missing source_type in Vespa chunk.")
metadata_list: list[str] | None = vespa_chunk.get(METADATA_LIST)
_raw_doc_updated_at: int | None = vespa_chunk.get(DOC_UPDATED_AT)
last_updated: datetime | None = (
datetime.fromtimestamp(_raw_doc_updated_at, tz=timezone.utc)
if _raw_doc_updated_at is not None
else None
)
hidden: bool = vespa_chunk.get(HIDDEN, False)
# This should exist; fail loudly if it does not.
global_boost: int = vespa_chunk[BOOST]
# This should exist; fail loudly if it does not.
semantic_identifier: str = vespa_chunk[SEMANTIC_IDENTIFIER]
if not semantic_identifier:
raise ValueError("Missing semantic_identifier in Vespa chunk.")
image_file_id: str | None = vespa_chunk.get(IMAGE_FILE_NAME)
source_links: str | None = vespa_chunk.get(SOURCE_LINKS)
blurb: str = vespa_chunk.get(BLURB, "")
doc_summary: str = vespa_chunk.get(DOC_SUMMARY, "")
chunk_context: str = vespa_chunk.get(CHUNK_CONTEXT, "")
metadata_suffix: str | None = vespa_chunk.get(METADATA_SUFFIX)
document_sets: list[str] | None = (
_transform_vespa_document_sets_to_opensearch_document_sets(
vespa_chunk.get(DOCUMENT_SETS)
title: str | None = vespa_chunk.get(TITLE)
# WARNING: Should supply format.tensors=short-value to the Vespa
# client in order to get a supported format for the tensors.
title_vector: list[float] | None = _extract_title_vector(
vespa_chunk.get(TITLE_EMBEDDING)
)
)
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
# This should exist; fail loudly if it does not; this function will
# raise in that event.
is_public, acl_list = _transform_vespa_acl_to_opensearch_acl(
vespa_chunk.get(ACCESS_CONTROL_LIST)
)
chunk_tenant_id: str | None = vespa_chunk.get(TENANT_ID)
if MULTI_TENANT:
if not chunk_tenant_id:
# This should exist; fail loudly if it does not.
content: str = vespa_chunk[CONTENT]
if not content:
raise ValueError(
"Missing tenant_id in Vespa chunk in a multi-tenant environment."
f"Missing content in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}."
)
if chunk_tenant_id != tenant_state.tenant_id:
# This should exist; fail loudly if it does not.
# WARNING: Should supply format.tensors=short-value to the Vespa
# client in order to get a supported format for the tensors.
content_vector: list[float] = _extract_content_vector(
vespa_chunk[EMBEDDINGS]
)
if not content_vector:
raise ValueError(
f"Chunk tenant_id {chunk_tenant_id} does not match expected tenant_id {tenant_state.tenant_id}"
f"Missing content_vector in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}."
)
opensearch_chunk = DocumentChunk(
# We deliberately choose to use the doc ID supplied to this function
# over the Vespa doc ID.
document_id=document_id,
chunk_index=chunk_index,
title=title,
title_vector=title_vector,
content=content,
content_vector=content_vector,
source_type=source_type,
metadata_list=metadata_list,
last_updated=last_updated,
public=is_public,
access_control_list=acl_list,
hidden=hidden,
global_boost=global_boost,
semantic_identifier=semantic_identifier,
image_file_id=image_file_id,
source_links=source_links,
blurb=blurb,
doc_summary=doc_summary,
chunk_context=chunk_context,
metadata_suffix=metadata_suffix,
document_sets=document_sets,
user_projects=user_projects,
primary_owners=primary_owners,
secondary_owners=secondary_owners,
tenant_id=tenant_state,
)
# This should exist; fail loudly if it does not.
source_type: str = vespa_chunk[SOURCE_TYPE]
if not source_type:
raise ValueError(
f"Missing source_type in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}."
)
result.append(opensearch_chunk)
metadata_list: list[str] | None = vespa_chunk.get(METADATA_LIST)
return result
_raw_doc_updated_at: int | None = vespa_chunk.get(DOC_UPDATED_AT)
last_updated: datetime | None = (
datetime.fromtimestamp(_raw_doc_updated_at, tz=timezone.utc)
if _raw_doc_updated_at is not None
else None
)
hidden: bool = vespa_chunk.get(HIDDEN, False)
# This should exist; fail loudly if it does not.
global_boost: int = vespa_chunk[BOOST]
# This should exist; fail loudly if it does not.
semantic_identifier: str = vespa_chunk[SEMANTIC_IDENTIFIER]
if not semantic_identifier:
raise ValueError(
f"Missing semantic_identifier in Vespa chunk with document ID {vespa_document_id} and chunk "
f"index {chunk_index}."
)
image_file_id: str | None = vespa_chunk.get(IMAGE_FILE_NAME)
source_links: str | None = vespa_chunk.get(SOURCE_LINKS)
blurb: str = vespa_chunk.get(BLURB, "")
doc_summary: str = vespa_chunk.get(DOC_SUMMARY, "")
chunk_context: str = vespa_chunk.get(CHUNK_CONTEXT, "")
metadata_suffix: str | None = vespa_chunk.get(METADATA_SUFFIX)
document_sets: list[str] | None = (
_transform_vespa_document_sets_to_opensearch_document_sets(
vespa_chunk.get(DOCUMENT_SETS)
)
)
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
is_public, acl_list = _transform_vespa_acl_to_opensearch_acl(
vespa_chunk.get(ACCESS_CONTROL_LIST)
)
if not is_public and not acl_list:
logger.warning(
f"Migration warning: Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index} has no "
"public ACL and no access control list. This does not make sense as it implies the document is never "
"searchable. Continuing with the migration..."
)
chunk_tenant_id: str | None = vespa_chunk.get(TENANT_ID)
if MULTI_TENANT:
if not chunk_tenant_id:
raise ValueError(
"Missing tenant_id in Vespa chunk in a multi-tenant environment."
)
if chunk_tenant_id != tenant_state.tenant_id:
raise ValueError(
f"Chunk tenant_id {chunk_tenant_id} does not match expected tenant_id {tenant_state.tenant_id}"
)
opensearch_chunk = DocumentChunk(
# We deliberately choose to use the doc ID supplied to this function
# over the Vespa doc ID.
document_id=document_id,
chunk_index=chunk_index,
title=title,
title_vector=title_vector,
content=content,
content_vector=content_vector,
source_type=source_type,
metadata_list=metadata_list,
last_updated=last_updated,
public=is_public,
access_control_list=acl_list,
hidden=hidden,
global_boost=global_boost,
semantic_identifier=semantic_identifier,
image_file_id=image_file_id,
source_links=source_links,
blurb=blurb,
doc_summary=doc_summary,
chunk_context=chunk_context,
metadata_suffix=metadata_suffix,
document_sets=document_sets,
user_projects=user_projects,
primary_owners=primary_owners,
secondary_owners=secondary_owners,
tenant_id=tenant_state,
)
result.append(opensearch_chunk)
except Exception:
traceback.print_exc()
logger.exception(
f"Migration error: Error transforming Vespa chunk with document ID {vespa_chunk.get(DOCUMENT_ID)} "
f"and chunk index {vespa_chunk.get(CHUNK_ID)} into an OpenSearch chunk. Continuing with "
"the migration..."
)
errored_chunks.append(vespa_chunk)
return result, errored_chunks
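A toy sketch of the ID-restoration step discussed in the comments above. The sanitizer rule here is invented purely to show why a reverse mapping is needed; the real mapping is built from the Document table.

def toy_sanitize(doc_id: str) -> str:
    # Invented rule standing in for replace_invalid_doc_id_characters.
    return doc_id.replace("/", "_")


original_ids = ["web/page-1", "drive/file-2"]
sanitized_to_original = {toy_sanitize(d): d for d in original_ids}


def resolve_doc_id(vespa_doc_id: str) -> str:
    # Fall back to the Vespa ID when there is nothing to restore, mirroring
    # .get(vespa_document_id, vespa_document_id) in the transform above.
    return sanitized_to_original.get(vespa_doc_id, vespa_doc_id)


assert resolve_doc_id("web_page-1") == "web/page-1"
assert resolve_doc_id("untouched-id") == "untouched-id"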

View File

@@ -0,0 +1,8 @@
"""Celery tasks for connector pruning."""
from onyx.background.celery.tasks.pruning.tasks import check_for_pruning # noqa: F401
from onyx.background.celery.tasks.pruning.tasks import ( # noqa: F401
connector_pruning_generator_task,
)
__all__ = ["check_for_pruning", "connector_pruning_generator_task"]

View File

@@ -523,6 +523,7 @@ def connector_pruning_generator_task(
redis_connector,
lock,
r,
timeout_seconds=JOB_TIMEOUT,
)
# a list of docs in the source

View File

@@ -10,10 +10,12 @@ from celery import Task
from redis.lock import Lock as RedisLock
from retry import retry
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
@@ -37,6 +39,7 @@ from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.utils import store_user_file_plaintext
from onyx.file_store.utils import user_file_id_to_plaintext_file_name
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
@@ -163,6 +166,132 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
def _process_user_file_without_vector_db(
uf: UserFile,
documents: list[Document],
db_session: Session,
) -> None:
"""Process a user file when the vector DB is disabled.
Extracts raw text and computes a token count, stores the plaintext in
the file store, and marks the file as COMPLETED. Skips embedding and
the indexing pipeline entirely.
"""
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm_tokenizer_encode_func
# Combine section text from all document sections
combined_text = " ".join(
section.text for doc in documents for section in doc.sections if section.text
)
# Compute token count using the user's default LLM tokenizer
try:
llm = get_default_llm()
encode = get_llm_tokenizer_encode_func(llm)
token_count: int | None = len(encode(combined_text))
except Exception:
task_logger.warning(
f"_process_user_file_without_vector_db - "
f"Failed to compute token count for {uf.id}, falling back to None"
)
token_count = None
# Persist plaintext for fast FileReaderTool loads
store_user_file_plaintext(
user_file_id=uf.id,
plaintext_content=combined_text,
)
# Update the DB record
if uf.status != UserFileStatus.DELETING:
uf.status = UserFileStatus.COMPLETED
uf.token_count = token_count
uf.chunk_count = 0 # no chunks without vector DB
uf.last_project_sync_at = datetime.datetime.now(datetime.timezone.utc)
db_session.add(uf)
db_session.commit()
task_logger.info(
f"_process_user_file_without_vector_db - "
f"Completed id={uf.id} tokens={token_count}"
)
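A hedged sketch of the token-count fallback used above: if the default LLM tokenizer cannot be loaded or fails, record None rather than failing the whole file. The encoder here is a trivial stand-in.

from collections.abc import Callable


def safe_token_count(text: str, get_encoder: Callable[[], Callable[[str], list]]) -> int | None:
    try:
        encode = get_encoder()
        return len(encode(text))
    except Exception:
        # Mirrors the "fall back to None" behavior rather than raising.
        return None


# Trivial stand-in encoder (whitespace split) for demonstration only.
assert safe_token_count("hello world", lambda: (lambda s: s.split())) == 2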
def _process_user_file_with_indexing(
uf: UserFile,
user_file_id: str,
documents: list[Document],
tenant_id: str,
db_session: Session,
) -> None:
"""Process a user file through the full indexing pipeline (vector DB path)."""
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
search_settings_list = get_active_search_settings_list(db_session)
current_search_settings = next(
(ss for ss in search_settings_list if ss.status.is_current()),
None,
)
if current_search_settings is None:
raise RuntimeError(
f"_process_user_file_with_indexing - "
f"No current search settings found for tenant={tenant_id}"
)
adapter = UserFileIndexingAdapter(
tenant_id=tenant_id,
db_session=db_session,
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=current_search_settings,
)
document_indices = get_all_document_indices(
current_search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
)
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
document_indices=document_indices,
ignore_time_skip=True,
db_session=db_session,
tenant_id=tenant_id,
document_batch=documents,
request_id=None,
adapter=adapter,
)
task_logger.info(
f"_process_user_file_with_indexing - "
f"Indexing pipeline completed ={index_pipeline_result}"
)
if (
index_pipeline_result.failures
or index_pipeline_result.total_docs != len(documents)
or index_pipeline_result.total_chunks == 0
):
task_logger.error(
f"_process_user_file_with_indexing - "
f"Indexing pipeline failed id={user_file_id}"
)
if uf.status != UserFileStatus.DELETING:
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
bind=True,
@@ -205,97 +334,34 @@ def process_single_user_file(
connector = LocalFileConnector(
file_locations=[uf.file_id],
file_names=[uf.name] if uf.name else None,
zip_metadata={},
)
connector.load_credentials({})
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
search_settings_list = get_active_search_settings_list(db_session)
current_search_settings = next(
(
search_settings_instance
for search_settings_instance in search_settings_list
if search_settings_instance.status.is_current()
),
None,
)
if current_search_settings is None:
raise RuntimeError(
f"process_single_user_file - No current search settings found for tenant={tenant_id}"
)
try:
for batch in connector.load_from_state():
documents.extend(
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
)
adapter = UserFileIndexingAdapter(
tenant_id=tenant_id,
db_session=db_session,
)
# Set up indexing pipeline components
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=current_search_settings,
)
# This flow is for indexing so we get all indices.
document_indices = get_all_document_indices(
current_search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
)
# update the doument id to userfile id in the documents
# update the document id to userfile id in the documents
for document in documents:
document.id = str(user_file_id)
document.source = DocumentSource.USER_FILE
# real work happens here!
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
document_indices=document_indices,
ignore_time_skip=True,
db_session=db_session,
tenant_id=tenant_id,
document_batch=documents,
request_id=None,
adapter=adapter,
)
task_logger.info(
f"process_single_user_file - Indexing pipeline completed ={index_pipeline_result}"
)
if (
index_pipeline_result.failures
or index_pipeline_result.total_docs != len(documents)
or index_pipeline_result.total_chunks == 0
):
task_logger.error(
f"process_single_user_file - Indexing pipeline failed id={user_file_id}"
if DISABLE_VECTOR_DB:
_process_user_file_without_vector_db(
uf=uf,
documents=documents,
db_session=db_session,
)
else:
_process_user_file_with_indexing(
uf=uf,
user_file_id=user_file_id,
documents=documents,
tenant_id=tenant_id,
db_session=db_session,
)
# don't update the status if the user file is being deleted
# Re-fetch to avoid mypy error
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if (
current_user_file
and current_user_file.status != UserFileStatus.DELETING
):
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
return None
except Exception as e:
task_logger.exception(
@@ -409,28 +475,6 @@ def process_single_user_file_delete(
return None
try:
with get_session_with_current_tenant() as db_session:
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
active_search_settings = get_active_search_settings(db_session)
# This flow is for deletion so we get all indices.
document_indices = get_all_document_indices(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_document_indices: list[RetryDocumentIndex] = [
RetryDocumentIndex(document_index)
for document_index in document_indices
]
index_name = active_search_settings.primary.index_name
selection = f"{index_name}.document_id=='{user_file_id}'"
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if not user_file:
task_logger.info(
@@ -438,22 +482,43 @@ def process_single_user_file_delete(
)
return None
# 1) Delete Vespa chunks for the document
chunk_count = 0
if user_file.chunk_count is None or user_file.chunk_count == 0:
chunk_count = _get_document_chunk_count(
index_name=index_name,
selection=selection,
)
else:
chunk_count = user_file.chunk_count
# 1) Delete vector DB chunks (skip when disabled)
if not DISABLE_VECTOR_DB:
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
for retry_document_index in retry_document_indices:
retry_document_index.delete_single(
doc_id=user_file_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
active_search_settings = get_active_search_settings(db_session)
document_indices = get_all_document_indices(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_document_indices: list[RetryDocumentIndex] = [
RetryDocumentIndex(document_index)
for document_index in document_indices
]
index_name = active_search_settings.primary.index_name
selection = f"{index_name}.document_id=='{user_file_id}'"
chunk_count = 0
if user_file.chunk_count is None or user_file.chunk_count == 0:
chunk_count = _get_document_chunk_count(
index_name=index_name,
selection=selection,
)
else:
chunk_count = user_file.chunk_count
for retry_document_index in retry_document_indices:
retry_document_index.delete_single(
doc_id=user_file_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
file_store = get_default_file_store()
@@ -565,27 +630,6 @@ def process_single_user_file_project_sync(
try:
with get_session_with_current_tenant() as db_session:
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
active_search_settings = get_active_search_settings(db_session)
# This flow is for updates so we get all indices.
document_indices = get_all_document_indices(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_document_indices: list[RetryDocumentIndex] = [
RetryDocumentIndex(document_index)
for document_index in document_indices
]
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if not user_file:
task_logger.info(
@@ -593,15 +637,35 @@ def process_single_user_file_project_sync(
)
return None
project_ids = [project.id for project in user_file.projects]
for retry_document_index in retry_document_indices:
retry_document_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(user_projects=project_ids),
# Sync project metadata to vector DB (skip when disabled)
if not DISABLE_VECTOR_DB:
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
active_search_settings = get_active_search_settings(db_session)
document_indices = get_all_document_indices(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_document_indices: list[RetryDocumentIndex] = [
RetryDocumentIndex(document_index)
for document_index in document_indices
]
project_ids = [project.id for project in user_file.projects]
for retry_document_index in retry_document_indices:
retry_document_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(user_projects=project_ids),
)
task_logger.info(
f"process_single_user_file_project_sync - User file id={user_file_id}"

View File

@@ -677,7 +677,6 @@ def connector_document_extraction(
logger.debug(f"Indexing batch of documents: {batch_description}")
memory_tracer.increment_and_maybe_trace()
# cc4a
if processing_mode == ProcessingMode.FILE_SYSTEM:
# File system only - write directly to persistent storage,
# skip chunking/embedding/Vespa but still track documents in DB
@@ -817,17 +816,19 @@ def connector_document_extraction(
if processing_mode == ProcessingMode.FILE_SYSTEM:
creator_id = index_attempt.connector_credential_pair.creator_id
if creator_id:
source_value = db_connector.source.value
app.send_task(
OnyxCeleryTask.SANDBOX_FILE_SYNC,
kwargs={
"user_id": str(creator_id),
"tenant_id": tenant_id,
"source": source_value,
},
queue=OnyxCeleryQueues.SANDBOX,
)
logger.info(
f"Triggered sandbox file sync for user {creator_id} "
f"after indexing complete"
f"source={source_value} after indexing complete"
)
except Exception as e:

View File

@@ -9,10 +9,8 @@ Summaries are stored as `ChatMessage` records with two key fields:
- `parent_message_id` → last message when compression triggered (places summary in the tree)
- `last_summarized_message_id` → pointer to an older message up the chain (the cutoff). Messages after this are kept verbatim.
### Timestamp-Based Ordering
Messages are filtered by `time_sent` (not ID) so the logic remains intact if IDs are changed to UUIDs in the future.
**Why store the summary as a separate message?** If we embedded the summary in the `last_summarized_message_id` message itself, that message would contain context from messages that came after it—context that doesn't exist in other branches. By creating the summary as a new message attached to the branch tip, it applies only to the specific branch where compression occurred, and only that branch points back to it. This is necessary because we keep the last few messages verbatim and must support branching.
### Progressive Summarization
Subsequent compressions incorporate the existing summary text + new messages, preventing information loss in very long conversations.
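A minimal sketch of this layout, using stand-in dataclasses rather than the real SQLAlchemy `ChatMessage` model (field values are illustrative):

```python
from dataclasses import dataclass
from datetime import datetime, timedelta


@dataclass
class Msg:
    id: int
    message: str
    time_sent: datetime
    parent_message_id: int | None = None
    last_summarized_message_id: int | None = None  # set only on summary messages


t0 = datetime(2026, 1, 1)
history = [
    Msg(i, f"message {i}", t0 + timedelta(minutes=i), parent_message_id=(i - 1) or None)
    for i in range(1, 7)
]

# The summary is a NEW message attached to the branch tip (id=6); its cutoff
# pointer references an older message (id=4), so messages 5-6 stay verbatim.
summary = Msg(
    id=100,
    message="<summary of messages 1-4>",
    time_sent=history[-1].time_sent + timedelta(seconds=1),
    parent_message_id=history[-1].id,
    last_summarized_message_id=4,
)

# Timestamp-based ordering: verbatim messages are those sent after the cutoff.
cutoff_time = next(
    m.time_sent for m in history if m.id == summary.last_summarized_message_id
)
kept_verbatim = [m for m in history if m.time_sent > cutoff_time]
assert [m.id for m in kept_verbatim] == [5, 6]
```

On the next compression of this branch, only messages after `last_summarized_message_id` are considered, and the existing summary text is folded into the new one.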
@@ -26,10 +24,11 @@ Context window breakdown:
- `max_context_tokens` — LLM's total context window
- `reserved_tokens` — space for system prompt, tools, files, etc.
- Available for chat history = `max_context_tokens - reserved_tokens`
Note: If a large share of the window is reserved, chat compression can trigger fairly frequently, which is costly, slow, and leads to a bad user experience. This is a possible area for future improvement.
Configurable ratios:
- `COMPRESSION_TRIGGER_RATIO` (default 0.75) — compress when chat history exceeds this ratio of available space
- `RECENT_MESSAGES_RATIO` (default 0.25) — portion of chat history to keep verbatim when compressing
- `RECENT_MESSAGES_RATIO` (default 0.2) — portion of chat history to keep verbatim when compressing
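A worked example of the arithmetic, treating both ratios as fractions of the space available for chat history (the window and reservation numbers below are made up):

```python
COMPRESSION_TRIGGER_RATIO = 0.75
RECENT_MESSAGES_RATIO = 0.2

max_context_tokens = 128_000  # hypothetical model context window
reserved_tokens = 28_000      # hypothetical reservation for system prompt, tools, files

available_for_history = max_context_tokens - reserved_tokens                  # 100_000
compression_trigger = int(available_for_history * COMPRESSION_TRIGGER_RATIO)  # 75_000
tokens_for_recent = int(available_for_history * RECENT_MESSAGES_RATIO)        # 20_000

# Compression kicks in once chat history exceeds ~75k tokens; roughly the most
# recent ~20k tokens of messages are kept verbatim and the rest is summarized.
```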
## Flow

View File

@@ -3,32 +3,26 @@ from collections.abc import Callable
from typing import cast
from uuid import UUID
from fastapi import HTTPException
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from onyx.auth.users import is_user_admin
from onyx.chat.models import ChatHistoryResult
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import ToolCallSimple
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.chat import create_chat_session
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_or_create_root_message
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
from onyx.db.llm import fetch_existing_doc_sets
from onyx.db.llm import fetch_existing_tools
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.projects import check_project_ownership
from onyx.file_processing.extract_file_text import extract_file_text
@@ -45,9 +39,6 @@ from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
@@ -276,70 +267,6 @@ def extract_headers(
return extracted_headers
def create_temporary_persona(
persona_config: PersonaOverrideConfig, db_session: Session, user: User
) -> Persona:
if not is_user_admin(user):
raise HTTPException(
status_code=403,
detail="User is not authorized to create a persona in one shot queries",
)
"""Create a temporary Persona object from the provided configuration."""
persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=RecencyBiasSetting.BASE_DECAY,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)
if persona_config.prompts:
# Use the first prompt from the override config for embedded prompt fields
first_prompt = persona_config.prompts[0]
persona.system_prompt = first_prompt.system_prompt
persona.task_prompt = first_prompt.task_prompt
persona.datetime_aware = first_prompt.datetime_aware
persona.tools = []
if persona_config.custom_tools_openapi:
from onyx.chat.emitter import get_default_emitter
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
tool_id=0, # dummy tool id
openapi_schema=schema,
emitter=get_default_emitter(),
),
)
persona.tools.extend(tools)
if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)
if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)
fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs
return persona
def process_kg_commands(
message: str, persona_name: str, tenant_id: str, db_session: Session # noqa: ARG001
) -> None:
@@ -502,15 +429,22 @@ def convert_chat_history(
additional_context: str | None,
token_counter: Callable[[str], int],
tool_id_to_name_map: dict[int, str],
) -> list[ChatMessageSimple]:
) -> ChatHistoryResult:
"""Convert ChatMessage history to ChatMessageSimple format.
For user messages: includes attached files (images attached to message, text files as separate messages)
For assistant messages with tool calls: creates ONE ASSISTANT message with tool_calls array,
followed by N TOOL_CALL_RESPONSE messages (OpenAI parallel tool calling format)
For assistant messages without tool calls: creates a simple ASSISTANT message
Every injected text-file message is tagged with ``file_id`` and its
metadata is collected in ``ChatHistoryResult.all_injected_file_metadata``.
After context-window truncation, callers compare surviving ``file_id`` tags
against this map to discover "forgotten" files and provide their metadata
to the FileReaderTool.
"""
simple_messages: list[ChatMessageSimple] = []
all_injected_file_metadata: dict[str, FileToolMetadata] = {}
# Create a mapping of file IDs to loaded files for quick lookup
file_map = {str(f.file_id): f for f in files}
@@ -539,7 +473,9 @@ def convert_chat_history(
# Text files (DOC, PLAIN_TEXT, CSV) are added as separate messages
text_files.append(loaded_file)
# Add text files as separate messages before the user message
# Add text files as separate messages before the user message.
# Each message is tagged with ``file_id`` so that forgotten files
# can be detected after context-window truncation.
for text_file in text_files:
file_text = text_file.content_text or ""
filename = text_file.filename
@@ -554,8 +490,14 @@ def convert_chat_history(
token_count=text_file.token_count,
message_type=MessageType.USER,
image_files=None,
file_id=text_file.file_id,
)
)
all_injected_file_metadata[text_file.file_id] = FileToolMetadata(
file_id=text_file.file_id,
filename=filename or "unknown",
approx_char_count=len(file_text),
)
# Sum token counts from image files (excluding project image files)
image_token_count = (
@@ -664,32 +606,41 @@ def convert_chat_history(
f"Invalid message type when constructing simple history: {chat_message.message_type}"
)
return simple_messages
return ChatHistoryResult(
simple_messages=simple_messages,
all_injected_file_metadata=all_injected_file_metadata,
)
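A minimal sketch of the forgotten-file detection this enables; plain dicts stand in for `ChatMessageSimple` tags and `FileToolMetadata`:

```python
all_injected_file_metadata = {
    "file-a": {"filename": "report.txt", "approx_char_count": 120_000},
    "file-b": {"filename": "notes.md", "approx_char_count": 3_400},
}

# Suppose context-window truncation only kept the message tagged with "file-b".
surviving_file_ids = {"file-b"}

forgotten = [
    meta
    for file_id, meta in all_injected_file_metadata.items()
    if file_id not in surviving_file_ids
]
# The forgotten metadata ("report.txt" here) is surfaced to the FileReaderTool
# so the LLM can still read the dropped file on demand.
assert [m["filename"] for m in forgotten] == ["report.txt"]
```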
def get_custom_agent_prompt(persona: Persona, chat_session: ChatSession) -> str | None:
"""Get the custom agent prompt from persona or project instructions.
"""Get the custom agent prompt from persona or project instructions. If it's replacing the base system prompt,
it does not count as a custom agent prompt (logic exists later also to drop it in this case).
Chat Sessions in Projects that are using a custom agent will retain the custom agent prompt.
Priority: persona.system_prompt > chat_session.project.instructions > None
Priority: persona.system_prompt (if not default Agent) > chat_session.project.instructions
# NOTE: Logic elsewhere allows saving empty strings for other purposes, but when constructing prompts
# we never want to return an empty string, so it is translated into an explicit None.
Args:
persona: The Persona object
chat_session: The ChatSession object
Returns:
The custom agent prompt string, or None if neither persona nor project has one
The prompt to use for the custom Agent portion of the final prompt, or None if there is none.
"""
# Not considered a custom agent if it's the default behavior persona
if persona.id == DEFAULT_PERSONA_ID:
return None
# If using a custom Agent, always respect its prompt, even if in a Project, and even if it's an empty custom prompt.
if persona.id != DEFAULT_PERSONA_ID:
# Later logic also drops it in this case, but returning None here is strictly correct anyhow.
if persona.replace_base_system_prompt:
return None
return persona.system_prompt or None
if persona.system_prompt:
return persona.system_prompt
elif chat_session.project and chat_session.project.instructions:
# If in a project and using the default Agent, respect the project instructions.
if chat_session.project and chat_session.project.instructions:
return chat_session.project.instructions
else:
return None
return None
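Because the hunk above interleaves the old and new return paths, here is the intended priority as a hedged sketch; `SimpleNamespace` objects stand in for the real `Persona`/`ChatSession` models and `DEFAULT_PERSONA_ID = 0` is assumed for illustration:

```python
from types import SimpleNamespace

DEFAULT_PERSONA_ID = 0  # assumed sentinel for the default Agent


def custom_agent_prompt_sketch(persona, chat_session) -> str | None:
    if persona.id != DEFAULT_PERSONA_ID:
        # Custom Agent: always respect its prompt, even inside a Project...
        if persona.replace_base_system_prompt:
            # ...unless it replaces the base system prompt entirely.
            return None
        return persona.system_prompt or None  # empty string -> None
    # Default Agent: fall back to Project instructions when available.
    if chat_session.project and chat_session.project.instructions:
        return chat_session.project.instructions
    return None


custom = SimpleNamespace(id=7, replace_base_system_prompt=False, system_prompt="Be terse.")
default = SimpleNamespace(id=0, replace_base_system_prompt=False, system_prompt=None)
in_project = SimpleNamespace(project=SimpleNamespace(instructions="Cite every claim."))
no_project = SimpleNamespace(project=None)

assert custom_agent_prompt_sketch(custom, in_project) == "Be terse."
assert custom_agent_prompt_sketch(default, in_project) == "Cite every claim."
assert custom_agent_prompt_sketch(default, no_project) is None
```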
def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) -> bool:

View File

@@ -17,20 +17,26 @@ from onyx.configs.chat_configs import COMPRESSION_TRIGGER_RATIO
from onyx.configs.constants import MessageType
from onyx.db.models import ChatMessage
from onyx.llm.interfaces import LLM
from onyx.llm.models import AssistantMessage
from onyx.llm.models import ChatCompletionMessage
from onyx.llm.models import SystemMessage
from onyx.llm.models import UserMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.compression_prompts import PROGRESSIVE_SUMMARY_PROMPT
from onyx.prompts.compression_prompts import PROGRESSIVE_SUMMARY_SYSTEM_PROMPT_BLOCK
from onyx.prompts.compression_prompts import PROGRESSIVE_USER_REMINDER
from onyx.prompts.compression_prompts import SUMMARIZATION_CUTOFF_MARKER
from onyx.prompts.compression_prompts import SUMMARIZATION_PROMPT
from onyx.prompts.compression_prompts import USER_FINAL_REMINDER
from onyx.prompts.compression_prompts import USER_REMINDER
from onyx.tracing.framework.create import ensure_trace
from onyx.tracing.llm_utils import llm_generation_span
from onyx.tracing.llm_utils import record_llm_response
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# Ratio of available context to allocate for recent messages after compression
RECENT_MESSAGES_RATIO = 0.25
RECENT_MESSAGES_RATIO = 0.2
class CompressionResult(BaseModel):
@@ -187,6 +193,11 @@ def get_messages_to_summarize(
recent_messages.insert(0, msg)
tokens_used += msg_tokens
# Ensure cutoff is right before a user message by moving any leading
# non-user messages from recent_messages to older_messages
while recent_messages and recent_messages[0].message_type != MessageType.USER:
recent_messages.pop(0)
# Everything else gets summarized
recent_ids = {m.id for m in recent_messages}
older_messages = [m for m in messages if m.id not in recent_ids]
@@ -196,31 +207,47 @@ def get_messages_to_summarize(
)
def format_messages_for_summary(
def _build_llm_messages_for_summarization(
messages: list[ChatMessage],
tool_id_to_name: dict[int, str],
) -> str:
"""Format messages into a string for the summarization prompt.
) -> list[UserMessage | AssistantMessage]:
"""Convert ChatMessage objects to LLM message format for summarization.
Tool call messages are formatted compactly to save tokens.
This is intentionally different from translate_history_to_llm_format in llm_step.py:
- Compacts tool calls to "[Used tools: tool1, tool2]" to save tokens in summaries
- Skips TOOL_CALL_RESPONSE messages entirely (tool usage captured in assistant message)
- No image/multimodal handling (summaries are text-only)
- No caching or LLMConfig-specific behavior needed
"""
formatted = []
result: list[UserMessage | AssistantMessage] = []
for msg in messages:
# Format assistant messages with tool calls compactly
if msg.message_type == MessageType.ASSISTANT and msg.tool_calls:
tool_names = [
tool_id_to_name.get(tc.tool_id, "unknown") for tc in msg.tool_calls
]
formatted.append(f"[assistant used tools: {', '.join(tool_names)}]")
# Skip empty messages
if not msg.message:
continue
# Handle assistant messages with tool calls compactly
if msg.message_type == MessageType.ASSISTANT:
if msg.tool_calls:
tool_names = [
tool_id_to_name.get(tc.tool_id, "unknown") for tc in msg.tool_calls
]
result.append(
AssistantMessage(content=f"[Used tools: {', '.join(tool_names)}]")
)
else:
result.append(AssistantMessage(content=msg.message))
continue
# Skip tool call response messages - tool calls are captured above via assistant messages
if msg.message_type == MessageType.TOOL_CALL_RESPONSE:
continue
role = msg.message_type.value
formatted.append(f"[{role}]: {msg.message}")
return "\n\n".join(formatted)
# Handle user messages
if msg.message_type == MessageType.USER:
result.append(UserMessage(content=msg.message))
return result
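An illustrative before/after for this compaction; the tool names and contents are hypothetical, and `SimpleNamespace` objects stand in for `ChatMessage` rows:

```python
from types import SimpleNamespace

# Input chat history (message_type shown informally as strings):
history = [
    SimpleNamespace(message_type="USER", message="Find the Q3 revenue numbers", tool_calls=None),
    SimpleNamespace(
        message_type="ASSISTANT",
        message="Searching...",
        tool_calls=[SimpleNamespace(tool_id=1), SimpleNamespace(tool_id=2)],
    ),
    SimpleNamespace(message_type="TOOL_CALL_RESPONSE", message="<long results>", tool_calls=None),
    SimpleNamespace(message_type="ASSISTANT", message="Q3 revenue was $4.2M.", tool_calls=None),
]
tool_id_to_name = {1: "internal_search", 2: "web_search"}

# Expected LLM-facing output of the compaction:
#   UserMessage("Find the Q3 revenue numbers")
#   AssistantMessage("[Used tools: internal_search, web_search]")  # tool call collapsed
#   (the TOOL_CALL_RESPONSE message is skipped entirely)
#   AssistantMessage("Q3 revenue was $4.2M.")
```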
def generate_summary(
@@ -236,6 +263,9 @@ def generate_summary(
The cutoff marker tells the LLM to summarize only older messages,
while using recent messages as context to inform what's important.
Messages are sent as separate UserMessage/AssistantMessage objects rather
than being concatenated into a single message.
Args:
older_messages: Messages to compress into summary (before cutoff)
recent_messages: Messages kept verbatim (after cutoff, for context only)
@@ -246,37 +276,54 @@ def generate_summary(
Returns:
Summary text
"""
older_messages_str = format_messages_for_summary(older_messages, tool_id_to_name)
recent_messages_str = format_messages_for_summary(recent_messages, tool_id_to_name)
# Build user prompt with cutoff marker
# Build system prompt
system_content = SUMMARIZATION_PROMPT
if existing_summary:
# Progressive summarization: include existing summary
user_prompt = PROGRESSIVE_SUMMARY_PROMPT.format(
existing_summary=existing_summary
# Progressive summarization: append existing summary to system prompt
system_content += PROGRESSIVE_SUMMARY_SYSTEM_PROMPT_BLOCK.format(
previous_summary=existing_summary
)
user_prompt += f"\n\n{older_messages_str}"
final_reminder = PROGRESSIVE_USER_REMINDER
else:
# Initial summarization
user_prompt = older_messages_str
final_reminder = USER_FINAL_REMINDER
final_reminder = USER_REMINDER
# Add cutoff marker and recent messages as context
user_prompt += f"\n\n{SUMMARIZATION_CUTOFF_MARKER}"
if recent_messages_str:
user_prompt += f"\n\n{recent_messages_str}"
# Convert messages to LLM format (using compression-specific conversion)
older_llm_messages = _build_llm_messages_for_summarization(
older_messages, tool_id_to_name
)
recent_llm_messages = _build_llm_messages_for_summarization(
recent_messages, tool_id_to_name
)
# Build message list with separate messages
input_messages: list[ChatCompletionMessage] = [
SystemMessage(content=system_content),
]
# Add older messages (to be summarized)
input_messages.extend(older_llm_messages)
# Add cutoff marker as a user message
input_messages.append(UserMessage(content=SUMMARIZATION_CUTOFF_MARKER))
# Add recent messages (for context only)
input_messages.extend(recent_llm_messages)
# Add final reminder
user_prompt += f"\n\n{final_reminder}"
input_messages.append(UserMessage(content=final_reminder))
response = llm.invoke(
[
SystemMessage(content=SUMMARIZATION_PROMPT),
UserMessage(content=user_prompt),
]
)
return response.choice.message.content or ""
with llm_generation_span(
llm=llm,
flow="chat_history_summarization",
input_messages=input_messages,
) as span_generation:
response = llm.invoke(input_messages)
record_llm_response(span_generation, response)
content = response.choice.message.content
if not (content and content.strip()):
raise ValueError("LLM returned empty summary")
return content.strip()
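Schematically, the `input_messages` list assembled above looks like this in the progressive case (the marker and prompt texts live in `onyx/prompts/compression_prompts.py`):

```python
# Not literal code -- a schematic of the message ordering built by generate_summary:
#
#   SystemMessage(SUMMARIZATION_PROMPT
#                 + PROGRESSIVE_SUMMARY_SYSTEM_PROMPT_BLOCK.format(previous_summary=...))
#   ... older messages as UserMessage / AssistantMessage (to be summarized) ...
#   UserMessage(SUMMARIZATION_CUTOFF_MARKER)
#   ... recent messages (context only; they stay verbatim in the chat) ...
#   UserMessage(PROGRESSIVE_USER_REMINDER)   # USER_REMINDER when there is no prior summary
```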
def compress_chat_history(
@@ -292,6 +339,19 @@ def compress_chat_history(
The summary message's parent_message_id points to the last message in
chat_history, making it branch-aware via the tree structure.
Note: This takes the entire chat history as input, splits it into older
messages (to summarize) and recent messages (kept verbatim within the
token budget), generates a summary of the older part, and persists the
new summary message with its parent set to the last message in history.
Past summary is taken into context (progressive summarization): we find
at most one existing summary for this branch. If present, only messages
after that summary's last_summarized_message_id are considered; the
existing summary text is passed into the LLM so the new summary
incorporates it instead of summarizing from scratch.
For more details, see the COMPRESSION.md file.
Args:
db_session: Database session
chat_history: Branch-aware list of messages
@@ -305,74 +365,84 @@ def compress_chat_history(
if not chat_history:
return CompressionResult(summary_created=False, messages_summarized=0)
chat_session_id = chat_history[0].chat_session_id
logger.info(
f"Starting compression for session {chat_history[0].chat_session_id}, "
f"Starting compression for session {chat_session_id}, "
f"history_len={len(chat_history)}, tokens_for_recent={compression_params.tokens_for_recent}"
)
try:
# Find existing summary for this branch
existing_summary = find_summary_for_branch(db_session, chat_history)
with ensure_trace(
"chat_history_compression",
group_id=str(chat_session_id),
metadata={
"tenant_id": get_current_tenant_id(),
"chat_session_id": str(chat_session_id),
},
):
try:
# Find existing summary for this branch
existing_summary = find_summary_for_branch(db_session, chat_history)
# Get messages to summarize
summary_content = get_messages_to_summarize(
chat_history,
existing_summary,
tokens_for_recent=compression_params.tokens_for_recent,
)
# Get messages to summarize
summary_content = get_messages_to_summarize(
chat_history,
existing_summary,
tokens_for_recent=compression_params.tokens_for_recent,
)
if not summary_content.older_messages:
logger.debug("No messages to summarize, skipping compression")
return CompressionResult(summary_created=False, messages_summarized=0)
if not summary_content.older_messages:
logger.debug("No messages to summarize, skipping compression")
return CompressionResult(summary_created=False, messages_summarized=0)
# Generate summary (incorporate existing summary if present)
existing_summary_text = existing_summary.message if existing_summary else None
summary_text = generate_summary(
older_messages=summary_content.older_messages,
recent_messages=summary_content.recent_messages,
llm=llm,
tool_id_to_name=tool_id_to_name,
existing_summary=existing_summary_text,
)
# Generate summary (incorporate existing summary if present)
existing_summary_text = (
existing_summary.message if existing_summary else None
)
summary_text = generate_summary(
older_messages=summary_content.older_messages,
recent_messages=summary_content.recent_messages,
llm=llm,
tool_id_to_name=tool_id_to_name,
existing_summary=existing_summary_text,
)
# Calculate token count for the summary
tokenizer = get_tokenizer(None, None)
summary_token_count = len(tokenizer.encode(summary_text))
logger.debug(
f"Generated summary ({summary_token_count} tokens): {summary_text[:200]}..."
)
# Calculate token count for the summary
tokenizer = get_tokenizer(None, None)
summary_token_count = len(tokenizer.encode(summary_text))
logger.debug(
f"Generated summary ({summary_token_count} tokens): {summary_text[:200]}..."
)
# Create new summary as a ChatMessage
# Parent is the last message in history - this makes the summary branch-aware
summary_message = ChatMessage(
chat_session_id=chat_history[0].chat_session_id,
message_type=MessageType.ASSISTANT,
message=summary_text,
token_count=summary_token_count,
parent_message_id=chat_history[-1].id,
last_summarized_message_id=summary_content.older_messages[-1].id,
)
db_session.add(summary_message)
db_session.commit()
# Create new summary as a ChatMessage
# Parent is the last message in history - this makes the summary branch-aware
summary_message = ChatMessage(
chat_session_id=chat_session_id,
message_type=MessageType.ASSISTANT,
message=summary_text,
token_count=summary_token_count,
parent_message_id=chat_history[-1].id,
last_summarized_message_id=summary_content.older_messages[-1].id,
)
db_session.add(summary_message)
db_session.commit()
logger.info(
f"Compressed {len(summary_content.older_messages)} messages into summary "
f"(session_id={chat_history[0].chat_session_id}, "
f"summary_tokens={summary_token_count})"
)
logger.info(
f"Compressed {len(summary_content.older_messages)} messages into summary "
f"(session_id={chat_session_id}, "
f"summary_tokens={summary_token_count})"
)
return CompressionResult(
summary_created=True,
messages_summarized=len(summary_content.older_messages),
)
return CompressionResult(
summary_created=True,
messages_summarized=len(summary_content.older_messages),
)
except Exception as e:
logger.exception(
f"Compression failed for session {chat_history[0].chat_session_id}: {e}"
)
db_session.rollback()
return CompressionResult(
summary_created=False,
messages_summarized=0,
error=str(e),
)
except Exception as e:
logger.exception(f"Compression failed for session {chat_session_id}: {e}")
db_session.rollback()
return CompressionResult(
summary_created=False,
messages_summarized=0,
error=str(e),
)

View File

@@ -1,5 +1,7 @@
import json
import time
from collections.abc import Callable
from typing import Literal
from sqlalchemy.orm import Session
@@ -14,6 +16,7 @@ from onyx.chat.llm_step import extract_tool_calls_from_response_text
from onyx.chat.llm_step import run_llm_step
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import LlmStepResult
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ToolCallSimple
@@ -27,12 +30,14 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.memory import add_memory
from onyx.db.memory import update_memory_at_index
from onyx.db.memory import UserMemoryContext
from onyx.db.models import Persona
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_needs_formatting_reenabled
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
from onyx.server.query_and_chat.placement import Placement
@@ -43,12 +48,14 @@ from onyx.server.query_and_chat.streaming_models import TopLevelBranching
from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
from onyx.tools.interface import Tool
from onyx.tools.models import MemoryToolResponseSnapshot
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -60,6 +67,28 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def _should_keep_bedrock_tool_definitions(
llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
"""Bedrock requires tool config when history includes toolUse/toolResult blocks."""
model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
if model_provider not in {
LlmProviderNames.BEDROCK,
LlmProviderNames.BEDROCK_CONVERSE,
}:
return False
return any(
(
msg.message_type == MessageType.ASSISTANT
and msg.tool_calls
and len(msg.tool_calls) > 0
)
or msg.message_type == MessageType.TOOL_CALL_RESPONSE
for msg in simple_chat_history
)
def _try_fallback_tool_extraction(
llm_step_result: LlmStepResult,
tool_choice: ToolChoiceOptions,
@@ -179,6 +208,35 @@ def _build_project_file_citation_mapping(
return citation_mapping
def _build_project_message(
project_files: ExtractedProjectFiles | None,
token_counter: Callable[[str], int] | None,
) -> list[ChatMessageSimple]:
"""Build messages for project / tool-backed files.
Returns up to two messages:
1. The full-text project files message (if project_file_texts is populated).
2. A lightweight metadata message for files the LLM should access via the
FileReaderTool (e.g. oversized chat-attached files or project files that
don't fit in context).
"""
if not project_files:
return []
messages: list[ChatMessageSimple] = []
if project_files.project_file_texts:
messages.append(
_create_project_files_message(project_files, token_counter=None)
)
if project_files.file_metadata_for_tool and token_counter:
messages.append(
_create_file_tool_metadata_message(
project_files.file_metadata_for_tool, token_counter
)
)
return messages
def construct_message_history(
system_prompt: ChatMessageSimple | None,
custom_agent_prompt: ChatMessageSimple | None,
@@ -187,6 +245,8 @@ def construct_message_history(
project_files: ExtractedProjectFiles | None,
available_tokens: int,
last_n_user_messages: int | None = None,
token_counter: Callable[[str], int] | None = None,
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
) -> list[ChatMessageSimple]:
if last_n_user_messages is not None:
if last_n_user_messages <= 0:
@@ -194,13 +254,17 @@ def construct_message_history(
"filtering chat history by last N user messages must be a value greater than 0"
)
# Build the project / file-metadata messages up front so we can use their
# actual token counts for the budget.
project_messages = _build_project_message(project_files, token_counter)
project_messages_tokens = sum(m.token_count for m in project_messages)
history_token_budget = available_tokens
history_token_budget -= system_prompt.token_count if system_prompt else 0
history_token_budget -= (
custom_agent_prompt.token_count if custom_agent_prompt else 0
)
if project_files:
history_token_budget -= project_files.total_token_count
history_token_budget -= project_messages_tokens
history_token_budget -= reminder_message.token_count if reminder_message else 0
if history_token_budget < 0:
@@ -214,11 +278,7 @@ def construct_message_history(
result = [system_prompt] if system_prompt else []
if custom_agent_prompt:
result.append(custom_agent_prompt)
if project_files and project_files.project_file_texts:
project_message = _create_project_files_message(
project_files, token_counter=None
)
result.append(project_message)
result.extend(project_messages)
if reminder_message:
result.append(reminder_message)
return result
@@ -277,8 +337,11 @@ def construct_message_history(
# Calculate remaining budget for history before the last user message
remaining_budget = history_token_budget - required_tokens
# Truncate history_before_last_user from the top to fit in remaining budget.
# Track dropped file messages so we can provide their metadata to the
# FileReaderTool instead.
truncated_history_before: list[ChatMessageSimple] = []
dropped_file_ids: list[str] = []
current_token_count = 0
for msg in reversed(history_before_last_user):
@@ -287,9 +350,67 @@ def construct_message_history(
truncated_history_before.insert(0, msg)
current_token_count += msg.token_count
else:
# Can't fit this message, stop truncating.
# This message and everything older is dropped.
break
# Collect file_ids from ALL dropped messages (those not in
# truncated_history_before). The truncation loop above keeps the most
# recent messages, so the dropped ones are at the start of the original
# list up to (len(history) - len(kept)).
num_kept = len(truncated_history_before)
for msg in history_before_last_user[: len(history_before_last_user) - num_kept]:
if msg.file_id is not None:
dropped_file_ids.append(msg.file_id)
# Also treat "orphaned" metadata entries as dropped -- these are files
# from messages removed by summary truncation (before convert_chat_history
# ran), so no ChatMessageSimple was ever tagged with their file_id.
if all_injected_file_metadata:
surviving_file_ids = {
msg.file_id for msg in simple_chat_history if msg.file_id is not None
}
for fid in all_injected_file_metadata:
if fid not in surviving_file_ids and fid not in dropped_file_ids:
dropped_file_ids.append(fid)
# Build a forgotten-files metadata message if any file messages were
# dropped AND we have metadata for them (meaning the FileReaderTool is
# available). Reserve tokens for this message in the budget.
forgotten_files_message: ChatMessageSimple | None = None
if dropped_file_ids and all_injected_file_metadata and token_counter:
forgotten_meta = [
all_injected_file_metadata[fid]
for fid in dropped_file_ids
if fid in all_injected_file_metadata
]
if forgotten_meta:
logger.debug(
f"FileReader: building forgotten-files message for "
f"{[(m.file_id, m.filename) for m in forgotten_meta]}"
)
forgotten_files_message = _create_file_tool_metadata_message(
forgotten_meta, token_counter
)
# Shrink the remaining budget. If the metadata message doesn't
# fit we may need to drop more history messages.
remaining_budget -= forgotten_files_message.token_count
while truncated_history_before and current_token_count > remaining_budget:
evicted = truncated_history_before.pop(0)
current_token_count -= evicted.token_count
# If the evicted message is itself a file, add it to the
# forgotten metadata (it's now dropped too).
if (
evicted.file_id is not None
and evicted.file_id in all_injected_file_metadata
and evicted.file_id not in {m.file_id for m in forgotten_meta}
):
forgotten_meta.append(all_injected_file_metadata[evicted.file_id])
# Rebuild the message with the new entry
forgotten_files_message = _create_file_tool_metadata_message(
forgotten_meta, token_counter
)
# Attach project images to the last user message
if project_files and project_files.project_image_files:
existing_images = last_user_message.image_files or []
@@ -302,7 +423,7 @@ def construct_message_history(
# Build the final message list according to README ordering:
# [system], [history_before_last_user], [custom_agent], [project_files],
# [last_user_message], [messages_after_last_user], [reminder]
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
result = [system_prompt] if system_prompt else []
# 1. Add truncated history before last user message
@@ -312,26 +433,52 @@ def construct_message_history(
if custom_agent_prompt:
result.append(custom_agent_prompt)
# 3. Add project files message (inserted before last user message)
if project_files and project_files.project_file_texts:
project_message = _create_project_files_message(
project_files, token_counter=None
)
result.append(project_message)
# 3. Add project files / file-metadata messages (inserted before last user message)
result.extend(project_messages)
# 4. Add last user message (with project images attached)
# 4. Add forgotten-files metadata (right before the user's question)
if forgotten_files_message:
result.append(forgotten_files_message)
# 5. Add last user message (with project images attached)
result.append(last_user_message)
# 5. Add messages after last user message (tool calls, responses, etc.)
# 6. Add messages after last user message (tool calls, responses, etc.)
result.extend(messages_after_last_user)
# 6. Add reminder message at the very end
# 7. Add reminder message at the very end
if reminder_message:
result.append(reminder_message)
return result
def _create_file_tool_metadata_message(
file_metadata: list[FileToolMetadata],
token_counter: Callable[[str], int],
) -> ChatMessageSimple:
"""Build a lightweight metadata-only message listing files available via FileReaderTool.
Used when files are too large to fit in context and the vector DB is
disabled, so the LLM must use ``read_file`` to inspect them.
"""
lines = [
"You have access to the following files. Use the read_file tool to "
"read sections of any file:"
]
for meta in file_metadata:
lines.append(
f'- {meta.file_id}: "{meta.filename}" (~{meta.approx_char_count:,} chars)'
)
message_content = "\n".join(lines)
return ChatMessageSimple(
message=message_content,
token_count=token_counter(message_content),
message_type=MessageType.USER,
)
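For reference, the message body produced here looks roughly like the following (file IDs and filenames are hypothetical):

```python
example_message = (
    "You have access to the following files. Use the read_file tool to "
    "read sections of any file:\n"
    '- file-123: "quarterly_report.pdf" (~84,512 chars)\n'
    '- file-456: "meeting_notes.txt" (~12,030 chars)'
)
```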
def _create_project_files_message(
project_files: ExtractedProjectFiles,
token_counter: Callable[[str], int] | None, # noqa: ARG001
@@ -379,6 +526,8 @@ def run_llm_loop(
user_identity: LLMUserIdentity | None = None,
chat_session_id: str | None = None,
include_citations: bool = True,
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
inject_memories_in_prompt: bool = True,
) -> None:
with trace(
"run_llm_loop",
@@ -444,6 +593,7 @@ def run_llm_loop(
reasoning_cycles = 0
for llm_cycle_count in range(MAX_LLM_CYCLES):
# Handling tool calls based on cycle count and past cycle conditions
out_of_cycles = llm_cycle_count == MAX_LLM_CYCLES - 1
if forced_tool_id:
# Needs to be just the single one because the "required" currently doesn't have a specified tool, just a binary
@@ -455,11 +605,17 @@ def run_llm_loop(
elif out_of_cycles or ran_image_gen:
# Last cycle, no tools allowed, just answer!
tool_choice = ToolChoiceOptions.NONE
final_tools = []
# Bedrock requires tool config in requests that include toolUse/toolResult history.
final_tools = (
tools
if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
else []
)
else:
tool_choice = ToolChoiceOptions.AUTO
final_tools = tools
# Handling the system prompt and custom agent prompt
# The section below calculates the available tokens for history a bit more accurately
# now that project files are loaded in.
if persona and persona.replace_base_system_prompt:
@@ -477,18 +633,22 @@ def run_llm_loop(
else:
# If it's an empty string, we assume the user does not want to include it as an empty System message
if default_base_system_prompt:
open_ai_formatting_enabled = model_needs_formatting_reenabled(
llm.config.model_name
prompt_memory_context = (
user_memory_context
if inject_memories_in_prompt
else (
user_memory_context.without_memories()
if user_memory_context
else None
)
)
system_prompt_str = build_system_prompt(
base_system_prompt=default_base_system_prompt,
datetime_aware=persona.datetime_aware if persona else True,
user_memory_context=user_memory_context,
user_memory_context=prompt_memory_context,
tools=tools,
should_cite_documents=should_cite_documents
or always_cite_documents,
open_ai_formatting_enabled=open_ai_formatting_enabled,
)
system_prompt = ChatMessageSimple(
message=system_prompt_str,
@@ -541,7 +701,7 @@ def run_llm_loop(
ChatMessageSimple(
message=reminder_message_text,
token_count=token_counter(reminder_message_text),
message_type=MessageType.USER,
message_type=MessageType.USER_REMINDER,
)
if reminder_message_text
else None
@@ -554,6 +714,8 @@ def run_llm_loop(
reminder_message=reminder_msg,
project_files=project_files,
available_tokens=available_tokens,
token_counter=token_counter,
all_injected_file_metadata=all_injected_file_metadata,
)
# This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
@@ -645,6 +807,7 @@ def run_llm_loop(
max_concurrent_tools=None,
skip_search_query_expansion=has_called_search_tool,
url_snippet_map=extract_url_snippet_map(gathered_documents or []),
inject_memories_in_prompt=inject_memories_in_prompt,
)
tool_responses = parallel_tool_call_results.tool_responses
citation_mapping = parallel_tool_call_results.updated_citation_mapping
@@ -709,11 +872,44 @@ def run_llm_loop(
):
generated_images = tool_response.rich_response.generated_images
saved_response = (
tool_response.rich_response
if isinstance(tool_response.rich_response, str)
else tool_response.llm_facing_response
)
# Persist memory if this is a memory tool response
memory_snapshot: MemoryToolResponseSnapshot | None = None
if isinstance(tool_response.rich_response, MemoryToolResponse):
persisted_memory_id: int | None = None
if user_memory_context and user_memory_context.user_id:
if tool_response.rich_response.index_to_replace is not None:
memory = update_memory_at_index(
user_id=user_memory_context.user_id,
index=tool_response.rich_response.index_to_replace,
new_text=tool_response.rich_response.memory_text,
db_session=db_session,
)
persisted_memory_id = memory.id if memory else None
else:
memory = add_memory(
user_id=user_memory_context.user_id,
memory_text=tool_response.rich_response.memory_text,
db_session=db_session,
)
persisted_memory_id = memory.id
operation: Literal["add", "update"] = (
"update"
if tool_response.rich_response.index_to_replace is not None
else "add"
)
memory_snapshot = MemoryToolResponseSnapshot(
memory_text=tool_response.rich_response.memory_text,
operation=operation,
memory_id=persisted_memory_id,
index=tool_response.rich_response.index_to_replace,
)
if memory_snapshot:
saved_response = json.dumps(memory_snapshot.model_dump())
elif isinstance(tool_response.rich_response, str):
saved_response = tool_response.rich_response
else:
saved_response = tool_response.llm_facing_response
tool_call_info = ToolCallInfo(
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message

View File

@@ -36,6 +36,10 @@ from onyx.llm.models import ToolCall
from onyx.llm.models import ToolMessage
from onyx.llm.models import UserMessage
from onyx.llm.prompt_cache.processor import process_with_prompt_cache
from onyx.llm.utils import model_needs_formatting_reenabled
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_CLOSE
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_OPEN
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
@@ -332,26 +336,48 @@ def extract_tool_calls_from_response_text(
# Find all JSON objects in the response text
json_objects = find_all_json_objects(response_text)
tool_calls: list[ToolCallKickoff] = []
tab_index = 0
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
prev_json_obj: dict[str, Any] | None = None
prev_tool_call: tuple[str, dict[str, Any]] | None = None
for json_obj in json_objects:
matched_tool_call = _try_match_json_to_tool(json_obj, tool_name_to_def)
if matched_tool_call:
tool_name, tool_args = matched_tool_call
tool_calls.append(
ToolCallKickoff(
tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}",
tool_name=tool_name,
tool_args=tool_args,
placement=Placement(
turn_index=placement.turn_index,
tab_index=tab_index,
sub_turn_index=placement.sub_turn_index,
),
)
if not matched_tool_call:
continue
# `find_all_json_objects` can return both an outer tool-call object and
# its nested arguments object. If both resolve to the same tool call,
# drop only this nested duplicate artifact.
if (
prev_json_obj is not None
and prev_tool_call is not None
and matched_tool_call == prev_tool_call
and _is_nested_arguments_duplicate(
previous_json_obj=prev_json_obj,
current_json_obj=json_obj,
tool_name_to_def=tool_name_to_def,
)
tab_index += 1
):
continue
matched_tool_calls.append(matched_tool_call)
prev_json_obj = json_obj
prev_tool_call = matched_tool_call
tool_calls: list[ToolCallKickoff] = []
for tab_index, (tool_name, tool_args) in enumerate(matched_tool_calls):
tool_calls.append(
ToolCallKickoff(
tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}",
tool_name=tool_name,
tool_args=tool_args,
placement=Placement(
turn_index=placement.turn_index,
tab_index=tab_index,
sub_turn_index=placement.sub_turn_index,
),
)
)
logger.info(
f"Extracted {len(tool_calls)} tool call(s) from response text as fallback"
@@ -433,6 +459,42 @@ def _try_match_json_to_tool(
return None
def _is_nested_arguments_duplicate(
previous_json_obj: dict[str, Any],
current_json_obj: dict[str, Any],
tool_name_to_def: dict[str, dict],
) -> bool:
"""Detect when current object is the nested args object from previous tool call."""
extracted_args = _extract_nested_arguments_obj(previous_json_obj, tool_name_to_def)
return extracted_args is not None and current_json_obj == extracted_args
def _extract_nested_arguments_obj(
json_obj: dict[str, Any],
tool_name_to_def: dict[str, dict],
) -> dict[str, Any] | None:
# Format 1: {"name": "...", "arguments": {...}} or {"name": "...", "parameters": {...}}
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
args_obj = json_obj.get("arguments", json_obj.get("parameters"))
if isinstance(args_obj, dict):
return args_obj
# Format 2: {"function": {"name": "...", "arguments": {...}}}
if "function" in json_obj and isinstance(json_obj["function"], dict):
function_obj = json_obj["function"]
if "name" in function_obj and function_obj["name"] in tool_name_to_def:
args_obj = function_obj.get("arguments", function_obj.get("parameters"))
if isinstance(args_obj, dict):
return args_obj
# Format 3: {"tool_name": {...arguments...}}
for tool_name in tool_name_to_def:
if tool_name in json_obj and isinstance(json_obj[tool_name], dict):
return json_obj[tool_name]
return None
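For example (the `run_search` tool and its arguments are made up), `find_all_json_objects` can return both of the objects below for a single tool call; the second is the nested-arguments artifact that gets dropped:

```python
outer = {"name": "run_search", "arguments": {"query": "onyx release notes"}}  # Format 1
nested_args = {"query": "onyx release notes"}  # also found as its own JSON object

# Both may resolve to the same (tool_name, tool_args) pair, so only `outer`
# yields a ToolCallKickoff; `nested_args` matches
# _extract_nested_arguments_obj(outer, ...) and is skipped.
#
# The other accepted shapes (Formats 2 and 3 above) look like:
#   {"function": {"name": "run_search", "arguments": {"query": "..."}}}
#   {"run_search": {"query": "..."}}
```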
def translate_history_to_llm_format(
history: list[ChatMessageSimple],
llm_config: LLMConfig,
@@ -451,6 +513,7 @@ def translate_history_to_llm_format(
if PROMPT_CACHE_CHAT_HISTORY and msg.message_type in [
MessageType.SYSTEM,
MessageType.USER,
MessageType.USER_REMINDER,
MessageType.ASSISTANT,
MessageType.TOOL_CALL_RESPONSE,
]:
@@ -512,6 +575,16 @@ def translate_history_to_llm_format(
)
messages.append(user_msg_text)
elif msg.message_type == MessageType.USER_REMINDER:
# User reminder messages are wrapped with system-reminder tags
# and converted to UserMessage (LLM APIs don't have a native reminder type)
wrapped_content = f"{SYSTEM_REMINDER_TAG_OPEN}\n{msg.message}\n{SYSTEM_REMINDER_TAG_CLOSE}"
reminder_msg = UserMessage(
role="user",
content=wrapped_content,
)
messages.append(reminder_msg)
elif msg.message_type == MessageType.ASSISTANT:
tool_calls_list: list[ToolCall] | None = None
if msg.tool_calls:
@@ -552,6 +625,17 @@ def translate_history_to_llm_format(
f"Unknown message type {msg.message_type} in history. Skipping message."
)
# Apply model-specific formatting when translating to LLM format (e.g. OpenAI
# reasoning models need CODE_BLOCK_MARKDOWN prefix for correct markdown generation)
if model_needs_formatting_reenabled(llm_config.model_name):
for i, m in enumerate(messages):
if isinstance(m, SystemMessage):
messages[i] = SystemMessage(
role="system",
content=CODE_BLOCK_MARKDOWN + m.content,
)
break
# prompt caching: rely on should_cache in ChatMessageSimple to
# pick the split point for the cacheable prefix and suffix
if last_cacheable_msg_idx != -1:

View File

@@ -1,17 +1,13 @@
from collections.abc import Callable
from collections.abc import Iterator
from enum import Enum
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from onyx.configs.constants import MessageType
from onyx.context.search.enums import SearchType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import Packet
@@ -20,54 +16,6 @@ from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
FINISHED = "finished"
class StreamType(Enum):
SUB_QUESTIONS = "sub_questions"
SUB_ANSWER = "sub_answer"
MAIN_ANSWER = "main_answer"
class StreamStopInfo(BaseModel):
stop_reason: StreamStopReason
stream_type: StreamType = StreamType.MAIN_ANSWER
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
data["stop_reason"] = self.stop_reason.name
return data
class UserKnowledgeFilePacket(BaseModel):
user_files: list[FileDescriptor]
class RelevanceAnalysis(BaseModel):
relevant: bool
content: str | None = None
class DocumentRelevance(BaseModel):
"""Contains all relevance information for a given search"""
relevance_summaries: dict[str, RelevanceAnalysis]
class OnyxAnswerPiece(BaseModel):
# A small piece of a complete answer. Used for streaming back answers.
answer_piece: str | None # if None, specifies the end of an Answer
class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
class StreamingError(BaseModel):
error: str
stack_trace: str | None = None
@@ -78,23 +26,11 @@ class StreamingError(BaseModel):
details: dict | None = None # Additional context (tool name, model name, etc.)
class OnyxAnswer(BaseModel):
answer: str | None
class FileChatDisplay(BaseModel):
file_ids: list[str]
class CustomToolResponse(BaseModel):
response: ToolResultType
tool_name: str
class ToolConfig(BaseModel):
id: int
class ProjectSearchConfig(BaseModel):
"""Configuration for search tool availability in project context."""
@@ -102,83 +38,15 @@ class ProjectSearchConfig(BaseModel):
disable_forced_tool: bool
class PromptOverrideConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
datetime_aware: bool = True
include_citations: bool = True
class PersonaOverrideConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
# Note: prompt_ids removed - prompts are now embedded in personas
document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo
| FileChatDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo
)
class CreateChatSessionID(BaseModel):
chat_session_id: UUID
AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
class LLMMetricsContainer(BaseModel):
prompt_tokens: int
response_tokens: int
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
AnswerStreamPart = (
Packet
| StreamStopInfo
| MessageResponseIDInfo
| StreamingError
| UserKnowledgeFilePacket
| CreateChatSessionID
)
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
AnswerStream = Iterator[AnswerStreamPart]
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str
answer_citationless: str
top_documents: list[SearchDoc]
error_msg: str | None
message_id: int
citation_info: list[CitationInfo]
class ToolCallResponse(BaseModel):
"""Tool call with full details for non-streaming response."""
@@ -191,8 +59,23 @@ class ToolCallResponse(BaseModel):
pre_reasoning: str | None = None
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str
answer_citationless: str
top_documents: list[SearchDoc]
error_msg: str | None
message_id: int
citation_info: list[CitationInfo]
class ChatFullResponse(BaseModel):
"""Complete non-streaming response with all available data."""
"""Complete non-streaming response with all available data.
NOTE: This model is used for the core flow of the Onyx application; any changes to it should be reviewed and approved by an
experienced team member. It is very important to 1. avoid bloat and 2. keep this backwards compatible across versions.
"""
# Core response fields
answer: str
@@ -244,6 +127,9 @@ class ChatMessageSimple(BaseModel):
# represents the end of the cacheable prefix
# used for prompt caching
should_cache: bool = False
# When this message represents an injected text file, this is the file's ID.
# Used to detect which file messages survive context-window truncation.
file_id: str | None = None
class ProjectFileMetadata(BaseModel):
@@ -254,6 +140,33 @@ class ProjectFileMetadata(BaseModel):
file_content: str
class FileToolMetadata(BaseModel):
"""Lightweight metadata for exposing files to the FileReaderTool.
Used when files cannot be loaded directly into context (project too large
or persona-attached user_files without direct-load path). The LLM receives
a listing of these so it knows which files it can read via ``read_file``.
"""
file_id: str
filename: str
approx_char_count: int
class ChatHistoryResult(BaseModel):
"""Result of converting chat history to simple format.
Bundles the simple messages with metadata for every text file that was
injected into the history. After context-window truncation drops older
messages, callers compare surviving ``file_id`` tags against this map
to discover "forgotten" files whose metadata should be provided to the
FileReaderTool.
"""
simple_messages: list[ChatMessageSimple]
all_injected_file_metadata: dict[str, FileToolMetadata]
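A minimal sketch of the forgotten-file check described above; `find_forgotten_files` is a hypothetical helper name, not the function the LLM loop actually uses:

def find_forgotten_files(
    surviving_messages: list[ChatMessageSimple],
    all_injected_file_metadata: dict[str, FileToolMetadata],
) -> list[FileToolMetadata]:
    # file_id tags still carried by messages that survived context-window truncation
    surviving_ids = {m.file_id for m in surviving_messages if m.file_id is not None}
    # anything injected earlier whose tag no longer survives is "forgotten"
    return [
        meta
        for fid, meta in all_injected_file_metadata.items()
        if fid not in surviving_ids
    ]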
class ExtractedProjectFiles(BaseModel):
project_file_texts: list[str]
project_image_files: list[ChatLoadedFile]
@@ -263,6 +176,9 @@ class ExtractedProjectFiles(BaseModel):
project_file_metadata: list[ProjectFileMetadata]
# None if not a project
project_uncapped_token_count: int | None
# Lightweight metadata for files exposed via FileReaderTool
# (populated when files don't fit in context and vector DB is disabled)
file_metadata_for_tool: list[FileToolMetadata] = []
class LlmStepResult(BaseModel):

View File

@@ -4,12 +4,12 @@ An overview can be found in the README.md file in this directory.
"""
import re
import time
import traceback
from collections.abc import Callable
from contextvars import Token
from uuid import UUID
from pydantic import BaseModel
from redis.client import Redis
from sqlalchemy.orm import Session
@@ -35,7 +35,7 @@ from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import CreateChatSessionID
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ProjectSearchConfig
from onyx.chat.models import StreamingError
@@ -44,6 +44,7 @@ from onyx.chat.prompt_utils import calculate_reserved_tokens
from onyx.chat.save_chat import save_chat_turn
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
from onyx.chat.stop_signal_checker import reset_cancel_status
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import DocumentSource
@@ -60,6 +61,7 @@ from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.projects import get_project_token_count
from onyx.db.projects import get_user_files_from_project
from onyx.db.tools import get_tools
@@ -77,8 +79,7 @@ from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.redis.redis_pool import get_redis_client
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import OptionalSearchSetting
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
@@ -90,7 +91,11 @@ from onyx.tools.interface import Tool
from onyx.tools.models import SearchToolUsage
from onyx.tools.tool_constructor import construct_tools
from onyx.tools.tool_constructor import CustomToolConfig
from onyx.tools.tool_constructor import FileReaderToolConfig
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.file_reader.file_reader_tool import (
FileReaderTool,
)
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
@@ -100,6 +105,53 @@ logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
class _AvailableFiles(BaseModel):
"""Separated file IDs for the FileReaderTool so it knows which loader to use."""
# IDs from the ``user_file`` table (project / persona-attached files).
user_file_ids: list[UUID] = []
# IDs from the ``file_record`` table (chat-attached files).
chat_file_ids: list[UUID] = []
def _collect_available_file_ids(
chat_history: list[ChatMessage],
project_id: int | None,
user_id: UUID | None,
db_session: Session,
) -> _AvailableFiles:
"""Collect all file IDs the FileReaderTool should be allowed to access.
Returns *separate* lists for chat-attached files (``file_record`` IDs) and
project/user files (``user_file`` IDs) so the tool can pick the right
loader without a try/except fallback."""
chat_file_ids: set[UUID] = set()
user_file_ids: set[UUID] = set()
for msg in chat_history:
if not msg.files:
continue
for fd in msg.files:
try:
chat_file_ids.add(UUID(fd["id"]))
except (ValueError, KeyError):
pass
if project_id:
project_files = get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
for uf in project_files:
user_file_ids.add(uf.id)
return _AvailableFiles(
user_file_ids=list(user_file_ids),
chat_file_ids=list(chat_file_ids),
)
def _should_enable_slack_search(
persona: Persona,
filters: BaseFilters | None,
@@ -232,6 +284,24 @@ def _extract_project_file_texts_and_images(
)
project_image_files.append(chat_loaded_file)
else:
if DISABLE_VECTOR_DB:
# Without a vector DB we can't use project-as-filter search.
# Instead, build lightweight metadata so the LLM can call the
# FileReaderTool to inspect individual files on demand.
file_metadata_for_tool = _build_file_tool_metadata_for_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
return ExtractedProjectFiles(
project_file_texts=[],
project_image_files=[],
project_as_filter=False,
total_token_count=0,
project_file_metadata=[],
project_uncapped_token_count=project_tokens,
file_metadata_for_tool=file_metadata_for_tool,
)
project_as_filter = True
return ExtractedProjectFiles(
@@ -244,6 +314,49 @@ def _extract_project_file_texts_and_images(
)
APPROX_CHARS_PER_TOKEN = 4
def _build_file_tool_metadata_for_project(
project_id: int,
user_id: UUID | None,
db_session: Session,
) -> list[FileToolMetadata]:
"""Build lightweight FileToolMetadata for every file in a project.
Used when files are too large to fit in context and the vector DB is
disabled, so the LLM needs to know which files it can read via the
FileReaderTool.
"""
project_user_files = get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
return [
FileToolMetadata(
file_id=str(uf.id),
filename=uf.name,
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
)
for uf in project_user_files
]
def _build_file_tool_metadata_for_user_files(
user_files: list[UserFile],
) -> list[FileToolMetadata]:
"""Build lightweight FileToolMetadata from a list of UserFile records."""
return [
FileToolMetadata(
file_id=str(uf.id),
filename=uf.name,
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
)
for uf in user_files
]
def _get_project_search_availability(
project_id: int | None,
persona_id: int | None,
@@ -317,7 +430,6 @@ def handle_stream_message_objects(
external_state_container: ChatStateContainer | None = None,
) -> AnswerStream:
tenant_id = get_current_tenant_id()
processing_start_time = time.monotonic()
mock_response_token: Token[str | None] | None = None
llm: LLM | None = None
@@ -330,12 +442,10 @@ def handle_stream_message_objects(
else:
llm_user_identifier = user.email or str(user_id)
if new_msg_req.mock_llm_response is not None:
if not INTEGRATION_TESTS_MODE:
raise ValueError(
"mock_llm_response can only be used when INTEGRATION_TESTS_MODE=true"
)
mock_response_token = set_llm_mock_response(new_msg_req.mock_llm_response)
if new_msg_req.mock_llm_response is not None and not INTEGRATION_TESTS_MODE:
raise ValueError(
"mock_llm_response can only be used when INTEGRATION_TESTS_MODE=true"
)
try:
if not new_msg_req.chat_session_id:
@@ -463,24 +573,68 @@ def handle_stream_message_objects(
chat_history.append(user_message)
# Collect file IDs for the file reader tool *before* summary
# truncation so that files attached to older (summarized-away)
# messages are still accessible via the FileReaderTool.
available_files = _collect_available_file_ids(
chat_history=chat_history,
project_id=chat_session.project_id,
user_id=user_id,
db_session=db_session,
)
# Find applicable summary for the current branch
# Summary applies if its parent_message_id is in current chat_history
summary_message = find_summary_for_branch(db_session, chat_history)
# Collect file metadata from messages that will be dropped by
# summary truncation. These become "pre-summarized" file metadata
# so the forgotten-file mechanism can still tell the LLM about them.
summarized_file_metadata: dict[str, FileToolMetadata] = {}
if summary_message and summary_message.last_summarized_message_id:
cutoff_id = summary_message.last_summarized_message_id
for msg in chat_history:
if msg.id > cutoff_id or not msg.files:
continue
for fd in msg.files:
file_id = fd.get("id")
if not file_id:
continue
summarized_file_metadata[file_id] = FileToolMetadata(
file_id=file_id,
filename=fd.get("name") or "unknown",
# We don't know the exact size without loading the
# file, but 0 signals "unknown" to the LLM.
approx_char_count=0,
)
# Filter chat_history to only messages after the cutoff
chat_history = [m for m in chat_history if m.id > cutoff_id]
user_memory_context = get_memories(user, db_session)
# This is the custom prompt, which may come from the Agent or Project. We fetch it early because the inner loop
# (run_llm_loop and run_deep_research_llm_loop) should not need to be aware of the chat history in its DB form as processed
# here; we also need this early for token reservation.
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
# When use_memories is disabled, strip memories from the prompt context
# but keep user info/preferences. The full context is still passed
# to the LLM loop for memory tool persistence.
prompt_memory_context = (
user_memory_context
if user.use_memories
else user_memory_context.without_memories()
)
max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
custom_agent_prompt or ""
)
reserved_token_count = calculate_reserved_tokens(
db_session=db_session,
persona_system_prompt=custom_agent_prompt or "",
persona_system_prompt=max_reserved_system_prompt_tokens_str,
token_counter=token_counter,
files=new_msg_req.file_descriptors,
user_memory_context=user_memory_context,
user_memory_context=prompt_memory_context,
)
# Process projects; if all of the files fit in the context, there is no need to use RAG
@@ -492,6 +646,16 @@ def handle_stream_message_objects(
db_session=db_session,
)
# When the vector DB is disabled, persona-attached user_files have no
# search pipeline path. Inject them as file_metadata_for_tool so the
# LLM can read them via the FileReaderTool.
if DISABLE_VECTOR_DB and persona.user_files:
persona_file_metadata = _build_file_tool_metadata_for_user_files(
persona.user_files
)
# Merge persona file metadata into the extracted project files
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
# Build a mapping of tool_id to tool_name for history reconstruction
all_tools = get_tools(db_session)
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
@@ -518,6 +682,13 @@ def handle_stream_message_objects(
emitter = get_default_emitter()
# Also grant access to persona-attached user files
if persona.user_files:
existing = set(available_files.user_file_ids)
for uf in persona.user_files:
if uf.id not in existing:
available_files.user_file_ids.append(uf.id)
# Construct tools based on the persona configurations
tool_dict = construct_tools(
persona=persona,
@@ -544,6 +715,10 @@ def handle_stream_message_objects(
additional_headers=custom_tool_additional_headers,
mcp_headers=mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
user_file_ids=available_files.user_file_ids,
chat_file_ids=available_files.chat_file_ids,
),
allowed_tool_ids=new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=project_search_config.search_usage,
)
@@ -573,9 +748,12 @@ def handle_stream_message_objects(
reserved_assistant_message_id=assistant_response.id,
)
# Check whether the FileReaderTool is among the constructed tools.
has_file_reader_tool = any(isinstance(t, FileReaderTool) for t in tools)
# Convert the chat history into a simple format that is free of any DB objects
# and is easy to parse for the agent loop
simple_chat_history = convert_chat_history(
chat_history_result = convert_chat_history(
chat_history=chat_history,
files=files,
project_image_files=extracted_project_files.project_image_files,
@@ -583,6 +761,32 @@ def handle_stream_message_objects(
token_counter=token_counter,
tool_id_to_name_map=tool_id_to_name_map,
)
simple_chat_history = chat_history_result.simple_messages
# Metadata for every text file injected into the history. After
# context-window truncation drops older messages, the LLM loop
# compares surviving file_id tags against this map to discover
# "forgotten" files and provide their metadata to FileReaderTool.
all_injected_file_metadata: dict[str, FileToolMetadata] = (
chat_history_result.all_injected_file_metadata
if has_file_reader_tool
else {}
)
# Merge in file metadata from messages dropped by summary
# truncation. These files are no longer in simple_chat_history
# so they would otherwise be invisible to the forgotten-file
# mechanism. They will always appear as "forgotten" since no
# surviving message carries their file_id tag.
if summarized_file_metadata:
for fid, meta in summarized_file_metadata.items():
all_injected_file_metadata.setdefault(fid, meta)
if all_injected_file_metadata:
logger.debug(
"FileReader: file metadata for LLM: "
f"{[(fid, m.filename) for fid, m in all_injected_file_metadata.items()]}"
)
# Prepend summary message if compression exists
if summary_message is not None:
@@ -623,9 +827,13 @@ def handle_stream_message_objects(
assistant_message=assistant_response,
llm=llm,
reserved_tokens=reserved_token_count,
processing_start_time=processing_start_time,
)
# The stream generator can resume on a different worker thread after early yields.
# Set this right before launching the LLM loop so run_in_background copies the right context.
if new_msg_req.mock_llm_response is not None:
mock_response_token = set_llm_mock_response(new_msg_req.mock_llm_response)
# Run the LLM loop with explicit wrapper for stop signal handling
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
# for stop signals. run_llm_loop itself doesn't know about stopping.
@@ -654,6 +862,7 @@ def handle_stream_message_objects(
skip_clarification=skip_clarification,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
all_injected_file_metadata=all_injected_file_metadata,
)
else:
yield from run_chat_loop_with_state_containers(
@@ -675,6 +884,8 @@ def handle_stream_message_objects(
user_identity=user_identity,
chat_session_id=str(chat_session.id),
include_citations=new_msg_req.include_citations,
all_injected_file_metadata=all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
except ValueError as e:
@@ -748,7 +959,6 @@ def llm_loop_completion_handle(
assistant_message: ChatMessage,
llm: LLM,
reserved_tokens: int,
processing_start_time: float | None = None, # noqa: ARG001
) -> None:
chat_session_id = assistant_message.chat_session_id
@@ -811,68 +1021,6 @@ def llm_loop_completion_handle(
)
def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User,
db_session: Session,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
bypass_acl: bool = False,
# Additional context that should be included in the chat history, for example:
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
# messages. Both of the below are used for Slack
# NOTE: is not stored in the database, only passed in to the LLM as context
additional_context: str | None = None,
# Slack context for federated Slack search
slack_context: SlackContext | None = None,
) -> AnswerStream:
forced_tool_id = (
new_msg_req.forced_tool_ids[0] if new_msg_req.forced_tool_ids else None
)
if (
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
):
all_tools = get_tools(db_session)
search_tool_id = next(
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
None,
)
forced_tool_id = search_tool_id
translated_new_msg_req = SendMessageRequest(
message=new_msg_req.message,
llm_override=new_msg_req.llm_override,
mock_llm_response=new_msg_req.mock_llm_response,
allowed_tool_ids=new_msg_req.allowed_tool_ids,
forced_tool_id=forced_tool_id,
file_descriptors=new_msg_req.file_descriptors,
internal_search_filters=(
new_msg_req.retrieval_options.filters
if new_msg_req.retrieval_options
else None
),
deep_research=new_msg_req.deep_research,
parent_message_id=new_msg_req.parent_message_id,
chat_session_id=new_msg_req.chat_session_id,
origin=new_msg_req.origin,
include_citations=new_msg_req.include_citations,
)
return handle_stream_message_objects(
new_msg_req=translated_new_msg_req,
user=user,
db_session=db_session,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
bypass_acl=bypass_acl,
additional_context=additional_context,
slack_context=slack_context,
)
def remove_answer_citations(answer: str) -> str:
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"

View File

@@ -9,13 +9,13 @@ from onyx.db.persona import get_default_behavior_persona
from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
from onyx.prompts.chat_prompts import CITATION_REMINDER
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.prompt_utils import get_company_context
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
from onyx.prompts.prompt_utils import replace_reminder_tag
from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import MEMORY_GUIDANCE
@@ -25,7 +25,12 @@ from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
from onyx.prompts.user_info import TEAM_INFORMATION_PROMPT
from onyx.prompts.user_info import USER_INFORMATION_HEADER
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
from onyx.prompts.user_info import USER_ROLE_PROMPT
from onyx.tools.interface import Tool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
@@ -131,6 +136,59 @@ def build_reminder_message(
return reminder if reminder else None
def _build_user_information_section(
user_memory_context: UserMemoryContext | None,
company_context: str | None,
) -> str:
"""Build the complete '# User Information' section with all sub-sections
in the correct order: Basic Info → Team Info → Preferences → Memories."""
sections: list[str] = []
if user_memory_context:
ctx = user_memory_context
has_basic_info = ctx.user_info.name or ctx.user_info.email or ctx.user_info.role
if has_basic_info:
role_line = (
USER_ROLE_PROMPT.format(user_role=ctx.user_info.role).strip()
if ctx.user_info.role
else ""
)
if role_line:
role_line = "\n" + role_line
sections.append(
BASIC_INFORMATION_PROMPT.format(
user_name=ctx.user_info.name or "",
user_email=ctx.user_info.email or "",
user_role=role_line,
)
)
if company_context:
sections.append(
TEAM_INFORMATION_PROMPT.format(team_information=company_context.strip())
)
if user_memory_context:
ctx = user_memory_context
if ctx.user_preferences:
sections.append(
USER_PREFERENCES_PROMPT.format(user_preferences=ctx.user_preferences)
)
if ctx.memories:
formatted_memories = "\n".join(f"- {memory}" for memory in ctx.memories)
sections.append(
USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
)
if not sections:
return ""
return USER_INFORMATION_HEADER + "".join(sections)
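An illustrative stand-in for the assembly above (Basic Info, then Team Info, then Preferences, then Memories); the template strings and the `build_user_info` helper are made up for this sketch and are not the real prompt constants:

# Hypothetical templates standing in for BASIC_INFORMATION_PROMPT and friends.
HEADER = "\n\n# User Information"
BASIC = "\n- Name: {name} ({email})"
TEAM = "\n- Team: {team}"
PREFS = "\n- Preferences: {prefs}"
MEMS = "\n- Memories:\n{memories}"

def build_user_info(
    name: str, email: str, team: str, prefs: str, memories: list[str]
) -> str:
    sections: list[str] = []
    if name or email:
        sections.append(BASIC.format(name=name, email=email))
    if team:
        sections.append(TEAM.format(team=team))
    if prefs:
        sections.append(PREFS.format(prefs=prefs))
    if memories:
        sections.append(MEMS.format(memories="\n".join(f"  - {m}" for m in memories)))
    # Empty sub-sections are skipped; with nothing to show, no header is emitted at all.
    return HEADER + "".join(sections) if sections else ""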
def build_system_prompt(
base_system_prompt: str,
datetime_aware: bool = False,
@@ -138,18 +196,12 @@ def build_system_prompt(
tools: Sequence[Tool] | None = None,
should_cite_documents: bool = False,
include_all_guidance: bool = False,
open_ai_formatting_enabled: bool = False,
) -> str:
"""Should only be called with the default behavior system prompt.
If the user has replaced the default behavior prompt with their custom agent prompt, do not call this function.
"""
system_prompt = handle_onyx_date_awareness(base_system_prompt, datetime_aware)
# See https://simonwillison.net/tags/markdown/ for context on why this is needed
# for OpenAI reasoning models to have correct markdown generation
if open_ai_formatting_enabled:
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
# Replace citation guidance placeholder if present
system_prompt, should_append_citation_guidance = replace_citation_guidance_tag(
system_prompt,
@@ -157,16 +209,14 @@ def build_system_prompt(
include_all_guidance=include_all_guidance,
)
# Replace reminder tag placeholder if present
system_prompt = replace_reminder_tag(system_prompt)
company_context = get_company_context()
formatted_user_context = (
user_memory_context.as_formatted_prompt() if user_memory_context else ""
user_info_section = _build_user_information_section(
user_memory_context, company_context
)
if company_context or formatted_user_context:
system_prompt += USER_INFORMATION_HEADER
if company_context:
system_prompt += company_context
if formatted_user_context:
system_prompt += formatted_user_context
system_prompt += user_info_section
# Append citation guidance after company context if placeholder was not present
# This maintains backward compatibility and ensures citations are always enforced when needed

View File

@@ -50,6 +50,17 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
# Controls whether users can use User Knowledge (personal documents) in assistants
DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() == "true"
# Disables vector DB (Vespa/OpenSearch) entirely. When True, connectors and RAG search
# are disabled but core chat, tools, user file uploads, and Projects still work.
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
# Defaults to 100k tokens (or 10M when vector DB is disabled).
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
FILE_TOKEN_COUNT_THRESHOLD = int(
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
)
# If set to true, will show extra/uncommon connectors in the "Other" category
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
@@ -75,7 +86,7 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
# Auth Configs
#####
# Upgrades users from disabled auth to basic auth and shows warning.
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
if _auth_type_str == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
@@ -225,11 +236,32 @@ DOCUMENT_INDEX_NAME = "danswer_index"
# OpenSearch Configs
OPENSEARCH_HOST = os.environ.get("OPENSEARCH_HOST") or "localhost"
OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 9200)
# TODO(andrei): 60 seconds is too much; we're just setting a high default
# timeout for now while we examine why queries are slow.
# NOTE: This timeout applies to all requests the client makes, including bulk
# indexing.
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S = int(
os.environ.get("DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S") or 60
)
# TODO(andrei): 50 seconds is too much; we're just setting a high default
# timeout for now while we examine why queries are slow.
# NOTE: To get useful partial results, this value should be less than the client
# timeout above.
DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
)
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
USING_AWS_MANAGED_OPENSEARCH = (
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
# Profiling adds some overhead to OpenSearch operations. This overhead is
# unknown right now. It is enabled by default so we can get useful logs for
# investigating slow queries. We may never disable it if the overhead is
# minimal.
OPENSEARCH_PROFILING_DISABLED = (
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
)
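A hedged sketch of how these settings could feed an opensearch-py client and a search body; this is assumed wiring for illustration (including the hypothetical index name), not how the Onyx OpenSearch client is actually constructed:

from opensearchpy import OpenSearch

client = OpenSearch(
    hosts=[{"host": OPENSEARCH_HOST, "port": OPENSEARCH_REST_API_PORT}],
    http_auth=(OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
    # Applies to every request the client makes, including bulk indexing.
    timeout=DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
)

body = {
    "query": {"match_all": {}},
    # Keep the per-query timeout below the client timeout to get useful partial results.
    "timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
    # Profiling is on by default so slow queries leave useful logs.
    "profile": not OPENSEARCH_PROFILING_DISABLED,
}
results = client.search(index="example-index", body=body)  # hypothetical index name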
# This is the "base" config for now, the idea is that at least for our dev
# environments we always want to be dual indexing into both OpenSearch and Vespa
@@ -900,6 +932,9 @@ MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
# Limit on number of users a free trial tenant can invite (cloud only)
NUM_FREE_TRIAL_USER_INVITES = int(os.environ.get("NUM_FREE_TRIAL_USER_INVITES", "10"))
# Security and authentication
DATA_PLANE_SECRET = os.environ.get(
"DATA_PLANE_SECRET", ""
@@ -942,6 +977,7 @@ API_KEY_HASH_ROUNDS = (
# MCP Server Configs
#####
MCP_SERVER_ENABLED = os.environ.get("MCP_SERVER_ENABLED", "").lower() == "true"
MCP_SERVER_HOST = os.environ.get("MCP_SERVER_HOST", "0.0.0.0")
MCP_SERVER_PORT = int(os.environ.get("MCP_SERVER_PORT") or 8090)
# CORS origins for MCP clients (comma-separated)

View File

@@ -102,7 +102,6 @@ DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service"
# Key-Value store keys
KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_PENDING_USERS_KEY = "PENDING_USERS"
@@ -160,6 +159,8 @@ CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
TMP_DRALPHA_PERSONA_NAME = "KG Beta"
@@ -226,6 +227,9 @@ class DocumentSource(str, Enum):
MOCK_CONNECTOR = "mock_connector"
# Special case for user files
USER_FILE = "user_file"
# Raw files for Craft sandbox access (xlsx, pptx, docx, etc.)
# Uses RAW_BINARY processing mode - no text extraction
CRAFT_FILE = "craft_file"
class FederatedConnectorSource(str, Enum):
@@ -307,6 +311,7 @@ class MessageType(str, Enum):
USER = "user" # HumanMessage
ASSISTANT = "assistant" # AIMessage - Can include tool_calls field for parallel tool calling
TOOL_CALL_RESPONSE = "tool_call_response"
USER_REMINDER = "user_reminder" # Custom Onyx message type which is translated into a USER message when passed to the LLM
class ChatMessageSimpleType(str, Enum):
@@ -331,6 +336,7 @@ class FileOrigin(str, Enum):
CHAT_UPLOAD = "chat_upload"
CHAT_IMAGE_GEN = "chat_image_gen"
CONNECTOR = "connector"
CONNECTOR_METADATA = "connector_metadata"
GENERATED_REPORT = "generated_report"
INDEXING_CHECKPOINT = "indexing_checkpoint"
PLAINTEXT_CACHE = "plaintext_cache"
@@ -396,6 +402,8 @@ class OnyxCeleryQueues:
# Sandbox processing queue
SANDBOX = "sandbox"
OPENSEARCH_MIGRATION = "opensearch_migration"
class OnyxRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
@@ -447,6 +455,9 @@ class OnyxRedisLocks:
CLEANUP_IDLE_SANDBOXES_BEAT_LOCK = "da_lock:cleanup_idle_sandboxes_beat"
CLEANUP_OLD_SNAPSHOTS_BEAT_LOCK = "da_lock:cleanup_old_snapshots_beat"
# Sandbox file sync
SANDBOX_FILE_SYNC_LOCK_PREFIX = "da_lock:sandbox_file_sync"
class OnyxRedisSignals:
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
@@ -577,6 +588,9 @@ class OnyxCeleryTask:
MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK = (
"migrate_documents_from_vespa_to_opensearch_task"
)
MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK = (
"migrate_chunks_from_vespa_to_opensearch_task"
)
# this needs to correspond to the matching entry in supervisord

View File

@@ -1,4 +1,5 @@
import contextvars
import re
from concurrent.futures import as_completed
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
@@ -14,6 +15,7 @@ from retry import retry
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
@@ -62,11 +64,44 @@ class AirtableClientNotSetUpError(PermissionError):
super().__init__("Airtable Client is not set up, was load_credentials called?")
# Matches URLs like https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide
# Captures: base_id (appXXX), table_id (tblYYY), and optionally view_id (viwZZZ)
_AIRTABLE_URL_PATTERN = re.compile(
r"https?://airtable\.com/(app[A-Za-z0-9]+)/(tbl[A-Za-z0-9]+)(?:/(viw[A-Za-z0-9]+))?",
)
def parse_airtable_url(
url: str,
) -> tuple[str, str, str | None]:
"""Parse an Airtable URL into (base_id, table_id, view_id).
Accepts URLs like:
https://airtable.com/appXXX/tblYYY
https://airtable.com/appXXX/tblYYY/viwZZZ
https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide
Returns:
(base_id, table_id, view_id or None)
Raises:
ValueError if the URL doesn't match the expected format.
"""
match = _AIRTABLE_URL_PATTERN.search(url.strip())
if not match:
raise ValueError(
f"Could not parse Airtable URL: '{url}'. "
"Expected format: https://airtable.com/appXXX/tblYYY[/viwZZZ]"
)
return match.group(1), match.group(2), match.group(3)
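For example, with made-up IDs:

base_id, table_id, view_id = parse_airtable_url(
    "https://airtable.com/app12345/tblABCDE/viwFGHIJ?blocks=hide"
)
# -> ("app12345", "tblABCDE", "viwFGHIJ")
# view_id is None when the URL has no /viw... segment.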
class AirtableConnector(LoadConnector):
def __init__(
self,
base_id: str,
table_name_or_id: str,
base_id: str = "",
table_name_or_id: str = "",
airtable_url: str = "",
treat_all_non_attachment_fields_as_metadata: bool = False,
view_id: str | None = None,
share_id: str | None = None,
@@ -75,16 +110,33 @@ class AirtableConnector(LoadConnector):
"""Initialize an AirtableConnector.
Args:
base_id: The ID of the Airtable base to connect to
table_name_or_id: The name or ID of the table to index
base_id: The ID of the Airtable base (not required when airtable_url is set)
table_name_or_id: The name or ID of the table (not required when airtable_url is set)
airtable_url: An Airtable URL to parse base_id, table_id, and view_id from.
Overrides base_id, table_name_or_id, and view_id if provided.
treat_all_non_attachment_fields_as_metadata: If True, all fields except attachments will be treated as metadata.
If False, only fields with types in DEFAULT_METADATA_FIELD_TYPES will be treated as metadata.
view_id: Optional ID of a specific view to use
share_id: Optional ID of a "share" to use for generating record URLs (https://airtable.com/developers/web/api/list-shares)
share_id: Optional ID of a "share" to use for generating record URLs
batch_size: Number of records to process in each batch
Mode is auto-detected: if a specific table is identified (via URL or
base_id + table_name_or_id), the connector indexes that single table.
Otherwise, it discovers and indexes all accessible bases and tables.
"""
# If a URL is provided, parse it to extract base_id, table_id, and view_id
if airtable_url:
parsed_base_id, parsed_table_id, parsed_view_id = parse_airtable_url(
airtable_url
)
base_id = parsed_base_id
table_name_or_id = parsed_table_id
if parsed_view_id:
view_id = parsed_view_id
self.base_id = base_id
self.table_name_or_id = table_name_or_id
self.index_all = not (base_id and table_name_or_id)
self.view_id = view_id
self.share_id = share_id
self.batch_size = batch_size
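A hedged usage sketch of the two auto-detected modes (credential loading omitted; only constructor arguments visible in this diff are used, and the URL is made up):

# Single-table mode: a URL (or base_id + table_name_or_id) pins one table.
single_table = AirtableConnector(
    airtable_url="https://airtable.com/app12345/tblABCDE/viwFGHIJ"
)
assert single_table.index_all is False

# Index-all mode: no base/table given, so every accessible base and table is discovered.
everything = AirtableConnector()
assert everything.index_all is True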
@@ -103,6 +155,33 @@ class AirtableConnector(LoadConnector):
raise AirtableClientNotSetUpError()
return self._airtable_client
def validate_connector_settings(self) -> None:
if self.index_all:
try:
bases = self.airtable_client.bases()
if not bases:
raise ConnectorValidationError(
"No bases found. Ensure your API token has access to at least one base."
)
except ConnectorValidationError:
raise
except Exception as e:
raise ConnectorValidationError(f"Failed to list Airtable bases: {e}")
else:
if not self.base_id or not self.table_name_or_id:
raise ConnectorValidationError(
"A valid Airtable URL or base_id and table_name_or_id are required "
"when not using index_all mode."
)
try:
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
table.schema()
except Exception as e:
raise ConnectorValidationError(
f"Failed to access table '{self.table_name_or_id}' "
f"in base '{self.base_id}': {e}"
)
@classmethod
def _get_record_url(
cls,
@@ -267,6 +346,7 @@ class AirtableConnector(LoadConnector):
field_name: str,
field_info: Any,
field_type: str,
base_id: str,
table_id: str,
view_id: str | None,
record_id: str,
@@ -291,7 +371,7 @@ class AirtableConnector(LoadConnector):
field_name=field_name,
field_info=field_info,
field_type=field_type,
base_id=self.base_id,
base_id=base_id,
table_id=table_id,
view_id=view_id,
record_id=record_id,
@@ -326,15 +406,17 @@ class AirtableConnector(LoadConnector):
record: RecordDict,
table_schema: TableSchema,
primary_field_name: str | None,
base_id: str,
base_name: str | None = None,
) -> Document | None:
"""Process a single Airtable record into a Document.
Args:
record: The Airtable record to process
table_schema: Schema information for the table
table_name: Name of the table
table_id: ID of the table
primary_field_name: Name of the primary field, if any
base_id: The ID of the base this record belongs to
base_name: The name of the base (used in semantic ID for index_all mode)
Returns:
Document object representing the record
@@ -367,6 +449,7 @@ class AirtableConnector(LoadConnector):
field_name=field_name,
field_info=field_val,
field_type=field_type,
base_id=base_id,
table_id=table_id,
view_id=view_id,
record_id=record_id,
@@ -379,11 +462,26 @@ class AirtableConnector(LoadConnector):
logger.warning(f"No sections found for record {record_id}")
return None
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
else table_name
)
# Include base name in semantic ID only in index_all mode
if self.index_all and base_name:
semantic_id = (
f"{base_name} > {table_name}: {primary_field_value}"
if primary_field_value
else f"{base_name} > {table_name}"
)
else:
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
else table_name
)
# Build hierarchy source_path for Craft file system subdirectory structure.
# This creates: airtable/{base_name}/{table_name}/record.json
source_path: list[str] = []
if base_name:
source_path.append(base_name)
source_path.append(table_name)
return Document(
id=f"airtable__{record_id}",
@@ -391,19 +489,39 @@ class AirtableConnector(LoadConnector):
source=DocumentSource.AIRTABLE,
semantic_identifier=semantic_id,
metadata=metadata,
doc_metadata={
"hierarchy": {
"source_path": source_path,
"base_id": base_id,
"table_id": table_id,
"table_name": table_name,
**({"base_name": base_name} if base_name else {}),
}
},
)
def load_from_state(self) -> GenerateDocumentsOutput:
"""
Fetch all records from the table.
NOTE: Airtable does not support filtering by time updated, so
we have to fetch all records every time.
"""
if not self.airtable_client:
raise AirtableClientNotSetUpError()
def _resolve_base_name(self, base_id: str) -> str | None:
"""Try to resolve a human-readable base name from the API."""
try:
for base_info in self.airtable_client.bases():
if base_info.id == base_id:
return base_info.name
except Exception:
logger.debug(f"Could not resolve base name for {base_id}")
return None
def _index_table(
self,
base_id: str,
table_name_or_id: str,
base_name: str | None = None,
) -> GenerateDocumentsOutput:
"""Index all records from a single table. Yields batches of Documents."""
# Resolve base name for hierarchy if not provided
if base_name is None:
base_name = self._resolve_base_name(base_id)
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
table = self.airtable_client.table(base_id, table_name_or_id)
records = table.all()
table_schema = table.schema()
@@ -415,21 +533,25 @@ class AirtableConnector(LoadConnector):
primary_field_name = field.name
break
logger.info(f"Starting to process Airtable records for {table.name}.")
logger.info(
f"Processing {len(records)} records from table "
f"'{table_schema.name}' in base '{base_name or base_id}'."
)
if not records:
return
# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 8
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
record_documents: list[Document | HierarchyNode] = []
# Process records in batches
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
record_documents = []
record_documents: list[Document | HierarchyNode] = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit batch tasks
future_to_record: dict[Future, RecordDict] = {}
future_to_record: dict[Future[Document | None], RecordDict] = {}
for record in batch_records:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
@@ -440,6 +562,8 @@ class AirtableConnector(LoadConnector):
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
base_id=base_id,
base_name=base_name,
)
] = record
@@ -454,9 +578,58 @@ class AirtableConnector(LoadConnector):
logger.exception(f"Failed to process record {record['id']}")
raise e
yield record_documents
record_documents = []
if record_documents:
yield record_documents
# Yield any remaining records
if record_documents:
yield record_documents
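The `contextvars.copy_context()` submission above is what carries the current tenant ID into worker threads; here is a standalone sketch of that pattern using an illustrative context variable, not the real Onyx tenant var:

import contextvars
from concurrent.futures import ThreadPoolExecutor

tenant_id_var: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "tenant_id", default=None
)

def process_record(record_id: str) -> str:
    # Runs inside the copied context, so it sees the submitting thread's tenant.
    return f"{tenant_id_var.get()}:{record_id}"

tenant_id_var.set("tenant_a")
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [
        executor.submit(contextvars.copy_context().run, process_record, rid)
        for rid in ("rec1", "rec2")
    ]
    print([f.result() for f in futures])  # ['tenant_a:rec1', 'tenant_a:rec2']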
def load_from_state(self) -> GenerateDocumentsOutput:
"""
Fetch all records from one or all tables.
NOTE: Airtable does not support filtering by time updated, so
we have to fetch all records every time.
"""
if not self.airtable_client:
raise AirtableClientNotSetUpError()
if self.index_all:
yield from self._load_all()
else:
yield from self._index_table(
base_id=self.base_id,
table_name_or_id=self.table_name_or_id,
)
def _load_all(self) -> GenerateDocumentsOutput:
"""Discover all bases and tables, then index everything."""
bases = self.airtable_client.bases()
logger.info(f"Discovered {len(bases)} Airtable base(s).")
for base_info in bases:
base_id = base_info.id
base_name = base_info.name
logger.info(f"Listing tables for base '{base_name}' ({base_id}).")
try:
base = self.airtable_client.base(base_id)
tables = base.tables()
except Exception:
logger.exception(
f"Failed to list tables for base '{base_name}' ({base_id}), skipping."
)
continue
logger.info(f"Found {len(tables)} table(s) in base '{base_name}'.")
for table in tables:
try:
yield from self._index_table(
base_id=base_id,
table_name_or_id=table.id,
base_name=base_name,
)
except Exception:
logger.exception(
f"Failed to index table '{table.name}' ({table.id}) "
f"in base '{base_name}' ({base_id}), skipping."
)
continue

View File

@@ -171,6 +171,7 @@ def process_onyx_metadata(
return (
OnyxMetadata(
document_id=metadata.get("id"),
source_type=source_type,
link=metadata.get("link"),
file_display_name=metadata.get("file_display_name"),

View File

@@ -1,3 +1,4 @@
import json
import os
from datetime import datetime
from datetime import timezone
@@ -107,7 +108,7 @@ def _process_file(
# These metadata items are not settable by the user
source_type = onyx_metadata.source_type or DocumentSource.FILE
doc_id = f"FILE_CONNECTOR__{file_id}"
doc_id = onyx_metadata.document_id or f"FILE_CONNECTOR__{file_id}"
title = metadata.get("title") or file_display_name
# 1) If the file itself is an image, handle that scenario quickly
@@ -240,29 +241,49 @@ class LocalFileConnector(LoadConnector):
self,
file_locations: list[Path | str],
file_names: list[str] | None = None, # noqa: ARG002
zip_metadata: dict[str, Any] | None = None,
zip_metadata_file_id: str | None = None,
zip_metadata: dict[str, Any] | None = None, # Deprecated, for backwards compat
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.file_locations = [str(loc) for loc in file_locations]
self.batch_size = batch_size
self.pdf_pass: str | None = None
self.zip_metadata = zip_metadata or {}
self._zip_metadata_file_id = zip_metadata_file_id
self._zip_metadata_deprecated = zip_metadata
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.pdf_pass = credentials.get("pdf_password")
return None
def _get_file_metadata(self, file_name: str) -> dict[str, Any]:
return self.zip_metadata.get(file_name, {}) or self.zip_metadata.get(
os.path.basename(file_name), {}
)
def load_from_state(self) -> GenerateDocumentsOutput:
"""
Iterates over each file path, fetches from Postgres, tries to parse text
or images, and yields Document batches.
"""
# Load metadata dict at start (from file store or deprecated inline format)
zip_metadata: dict[str, Any] = {}
if self._zip_metadata_file_id:
try:
file_store = get_default_file_store()
metadata_io = file_store.read_file(
file_id=self._zip_metadata_file_id, mode="b"
)
metadata_bytes = metadata_io.read()
loaded_metadata = json.loads(metadata_bytes)
if isinstance(loaded_metadata, list):
zip_metadata = {d["filename"]: d for d in loaded_metadata}
else:
zip_metadata = loaded_metadata
except Exception as e:
logger.warning(f"Failed to load metadata from file store: {e}")
elif self._zip_metadata_deprecated:
logger.warning(
"Using deprecated inline zip_metadata dict. "
"Re-upload files to use the new file store format."
)
zip_metadata = self._zip_metadata_deprecated
documents: list[Document | HierarchyNode] = []
for file_id in self.file_locations:
@@ -273,7 +294,9 @@ class LocalFileConnector(LoadConnector):
logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
continue
metadata = self._get_file_metadata(file_record.display_name)
metadata = zip_metadata.get(
file_record.display_name, {}
) or zip_metadata.get(os.path.basename(file_record.display_name), {})
file_io = file_store.read_file(file_id=file_id, mode="b")
new_docs = _process_file(
file_id=file_id,
@@ -298,7 +321,6 @@ if __name__ == "__main__":
connector = LocalFileConnector(
file_locations=[os.environ["TEST_FILE"]],
file_names=[os.environ["TEST_FILE"]],
zip_metadata={},
)
connector.load_credentials({"pdf_password": os.environ.get("PDF_PASSWORD")})
doc_batches = connector.load_from_state()
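For reference, the two zip-metadata JSON shapes the loader above accepts, inferred from its parsing logic (filenames and values are made up):

# List form: normalized into a dict keyed by "filename".
list_form = [
    {"filename": "report.pdf", "link": "https://example.com/report"},
]

# Dict form: already keyed by filename and used as-is.
dict_form = {
    "report.pdf": {"link": "https://example.com/report"},
}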

View File

@@ -523,6 +523,22 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
sleep_after_rate_limit_exception(github_client)
return self.get_all_repos(github_client, attempt_num + 1)
def fetch_configured_repos(self) -> list[Repository.Repository]:
"""
Fetch the configured repositories based on the connector settings.
Returns:
list[Repository.Repository]: The configured repositories.
"""
assert self.github_client is not None # mypy
if self.repositories:
if "," in self.repositories:
return self.get_github_repos(self.github_client)
else:
return [self.get_github_repo(self.github_client)]
else:
return self.get_all_repos(self.github_client)
def _pull_requests_func(
self, repo: Repository.Repository
) -> Callable[[], PaginatedList[PullRequest]]:
@@ -551,17 +567,7 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
# First run of the connector, fetch all repos and store in checkpoint
if checkpoint.cached_repo_ids is None:
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self.get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self.get_github_repo(self.github_client)]
else:
# All repositories
repos = self.get_all_repos(self.github_client)
repos = self.fetch_configured_repos()
if not repos:
checkpoint.has_more = False
return checkpoint

View File

@@ -474,8 +474,9 @@ class ConnectorStopSignal(Exception):
class OnyxMetadata(BaseModel):
# Note that doc_id cannot be overriden here as it may cause issues
# with the display functionalities in the UI. Ask @chris if clarification is needed.
# Careful overriding the document_id, may cause visual issues in the UI.
# Kept here for API based use cases mostly
document_id: str | None = None
source_type: DocumentSource | None = None
link: str | None = None
file_display_name: str | None = None

View File

@@ -79,6 +79,13 @@ SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()}
ASPX_EXTENSION = ".aspx"
# The office365 library's ClientContext caches the access token from
# The office365 library's ClientContext caches the access token from its
# first request and never re-invokes the token callback. Microsoft access
# tokens live ~60-75 minutes, so we recreate the cached ClientContext every
# 30 minutes to let MSAL transparently handle token refresh.
_REST_CTX_MAX_AGE_S = 30 * 60
class SiteDescriptor(BaseModel):
"""Data class for storing SharePoint site information.
@@ -114,11 +121,10 @@ def sleep_and_retry(
try:
return query_obj.execute_query()
except ClientRequestException as e:
if (
e.response is not None
and e.response.status_code in [429, 503]
and attempt < max_retries
):
status = e.response.status_code if e.response is not None else None
# 429 / 503 — rate limit or transient error. Back off and retry.
if status in (429, 503) and attempt < max_retries:
logger.warning(
f"Rate limit exceeded on {method_name}, attempt {attempt + 1}/{max_retries + 1}, sleeping and retrying"
)
@@ -131,13 +137,15 @@ def sleep_and_retry(
logger.info(f"Sleeping for {sleep_time} seconds before retry")
time.sleep(sleep_time)
else:
# Either not a rate limit error, or we've exhausted retries
if e.response is not None and e.response.status_code == 429:
logger.error(
f"Rate limit retry exhausted for {method_name} after {max_retries} attempts"
)
raise e
continue
# Non-retryable error or retries exhausted — log details and raise.
if e.response is not None:
logger.error(
f"SharePoint request failed for {method_name}: "
f"status={status}, "
)
raise e
class SharepointConnectorCheckpoint(ConnectorCheckpoint):
@@ -713,6 +721,10 @@ class SharepointConnector(
self.include_site_pages = include_site_pages
self.include_site_documents = include_site_documents
self.sp_tenant_domain: str | None = None
self._credential_json: dict[str, Any] | None = None
self._cached_rest_ctx: ClientContext | None = None
self._cached_rest_ctx_url: str | None = None
self._cached_rest_ctx_created_at: float = 0.0
def validate_connector_settings(self) -> None:
# Validate that at least one content type is enabled
@@ -738,6 +750,44 @@ class SharepointConnector(
return self._graph_client
def _create_rest_client_context(self, site_url: str) -> ClientContext:
"""Return a ClientContext for SharePoint REST API calls, with caching.
The office365 library's ClientContext caches the access token from its
first request and never re-invokes the token callback. We cache the
context and recreate it when the site URL changes or after
``_REST_CTX_MAX_AGE_S``. On recreation we also call
``load_credentials`` to build a fresh MSAL app with an empty token
cache, guaranteeing a brand-new token from Azure AD."""
elapsed = time.monotonic() - self._cached_rest_ctx_created_at
if (
self._cached_rest_ctx is not None
and self._cached_rest_ctx_url == site_url
and elapsed <= _REST_CTX_MAX_AGE_S
):
return self._cached_rest_ctx
if self._credential_json:
logger.info(
"Rebuilding SharePoint REST client context "
"(elapsed=%.0fs, site_changed=%s)",
elapsed,
self._cached_rest_ctx_url != site_url,
)
self.load_credentials(self._credential_json)
if not self.msal_app or not self.sp_tenant_domain:
raise RuntimeError("MSAL app or tenant domain is not set")
msal_app = self.msal_app
sp_tenant_domain = self.sp_tenant_domain
self._cached_rest_ctx = ClientContext(site_url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
self._cached_rest_ctx_url = site_url
self._cached_rest_ctx_created_at = time.monotonic()
return self._cached_rest_ctx
@staticmethod
def _strip_share_link_tokens(path: str) -> list[str]:
# Share links often include a token prefix like /:f:/r/ or /:x:/r/.
@@ -1177,21 +1227,6 @@ class SharepointConnector(
# goes over all urls, converts them into SlimDocument objects and then yields them in batches
doc_batch: list[SlimDocument | HierarchyNode] = []
for site_descriptor in site_descriptors:
ctx: ClientContext | None = None
if self.msal_app and self.sp_tenant_domain:
msal_app = self.msal_app
sp_tenant_domain = self.sp_tenant_domain
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
else:
raise RuntimeError("MSAL app or tenant domain is not set")
if ctx is None:
logger.warning("ClientContext is not set, skipping permissions")
continue
site_url = site_descriptor.url
# Yield site hierarchy node using helper
@@ -1230,6 +1265,7 @@ class SharepointConnector(
try:
logger.debug(f"Processing: {driveitem.web_url}")
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_driveitem_to_slim_document(
driveitem, drive_name, ctx, self.graph_client
@@ -1249,6 +1285,7 @@ class SharepointConnector(
logger.debug(
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
)
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_sitepage_to_slim_document(
site_page, ctx, self.graph_client
@@ -1260,6 +1297,7 @@ class SharepointConnector(
yield doc_batch
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._credential_json = credentials
auth_method = credentials.get(
"authentication_method", SharepointAuthMethod.CLIENT_SECRET.value
)
@@ -1676,17 +1714,6 @@ class SharepointConnector(
)
logger.debug(f"Time range: {start_dt} to {end_dt}")
ctx: ClientContext | None = None
if include_permissions:
if self.msal_app and self.sp_tenant_domain:
msal_app = self.msal_app
sp_tenant_domain = self.sp_tenant_domain
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
else:
raise RuntimeError("MSAL app or tenant domain is not set")
# At this point current_drive_name should be set from popleft()
current_drive_name = checkpoint.current_drive_name
if current_drive_name is None:
@@ -1781,6 +1808,10 @@ class SharepointConnector(
)
try:
ctx: ClientContext | None = None
if include_permissions:
ctx = self._create_rest_client_context(site_descriptor.url)
doc = _convert_driveitem_to_document_with_permissions(
driveitem,
current_drive_name,
@@ -1846,20 +1877,13 @@ class SharepointConnector(
site_pages = self._fetch_site_pages(
site_descriptor, start=start_dt, end=end_dt
)
client_ctx: ClientContext | None = None
if include_permissions:
if self.msal_app and self.sp_tenant_domain:
msal_app = self.msal_app
sp_tenant_domain = self.sp_tenant_domain
client_ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
else:
raise RuntimeError("MSAL app or tenant domain is not set")
for site_page in site_pages:
logger.debug(
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
)
client_ctx: ClientContext | None = None
if include_permissions:
client_ctx = self._create_rest_client_context(site_descriptor.url)
yield (
_convert_sitepage_to_document(
site_page,

View File

@@ -308,6 +308,18 @@ def default_msg_filter(message: MessageType) -> SlackMessageFilterReason | None:
return None
def _bot_inclusive_msg_filter(
message: MessageType,
) -> SlackMessageFilterReason | None:
"""Like default_msg_filter but allows bot/app messages through.
Only filters out disallowed subtypes (channel_join, channel_leave, etc.).
"""
if message.get("subtype", "") in _DISALLOWED_MSG_SUBTYPES:
return SlackMessageFilterReason.DISALLOWED
return None
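A quick illustrative check of the inclusive filter above; the message dicts are simplified assumptions about the Slack payload shape:

# Disallowed subtypes such as channel_join are still filtered out.
assert (
    _bot_inclusive_msg_filter({"subtype": "channel_join"})
    == SlackMessageFilterReason.DISALLOWED
)
# Bot/app messages pass through instead of being dropped.
assert _bot_inclusive_msg_filter({"subtype": "bot_message", "bot_id": "B123"}) is None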
def filter_channels(
all_channels: list[ChannelType],
channels_to_connect: list[str] | None,
@@ -654,12 +666,18 @@ class SlackConnector(
# if specified, will treat the specified channel strings as
# regexes, and will only index channels that fully match the regexes
channel_regex_enabled: bool = False,
# if True, messages from bots/apps will be indexed instead of filtered out
include_bot_messages: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
num_threads: int = SLACK_NUM_THREADS,
use_redis: bool = True,
) -> None:
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
self.include_bot_messages = include_bot_messages
self.msg_filter_func = (
_bot_inclusive_msg_filter if include_bot_messages else default_msg_filter
)
self.batch_size = batch_size
self.num_threads = num_threads
self.client: WebClient | None = None
@@ -839,6 +857,7 @@ class SlackConnector(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
msg_filter_func=self.msg_filter_func,
callback=callback,
workspace_url=self._workspace_url,
)
@@ -926,6 +945,7 @@ class SlackConnector(
try:
num_bot_filtered_messages = 0
num_other_filtered_messages = 0
oldest = str(start) if start else None
latest = str(end)
@@ -984,6 +1004,7 @@ class SlackConnector(
user_cache=self.user_cache,
seen_thread_ts=seen_thread_ts,
channel_access=checkpoint.current_channel_access,
msg_filter_func=self.msg_filter_func,
)
)
@@ -1003,7 +1024,13 @@ class SlackConnector(
seen_thread_ts.add(thread_or_message_ts)
elif processed_slack_message.filter_reason:
num_bot_filtered_messages += 1
if (
processed_slack_message.filter_reason
== SlackMessageFilterReason.BOT
):
num_bot_filtered_messages += 1
else:
num_other_filtered_messages += 1
elif failure:
yield failure
@@ -1023,10 +1050,14 @@ class SlackConnector(
range_total = 1
range_percent_complete = range_complete / range_total * 100.0
logger.info(
num_filtered = num_bot_filtered_messages + num_other_filtered_messages
log_func = logger.warning if num_bot_filtered_messages > 0 else logger.info
log_func(
f"Message processing stats: "
f"batch_len={len(message_batch)} "
f"batch_yielded={num_threads_processed} "
f"filtered={num_filtered} "
f"(bot={num_bot_filtered_messages} other={num_other_filtered_messages}) "
f"total_threads_seen={len(seen_thread_ts)}"
)
@@ -1040,7 +1071,8 @@ class SlackConnector(
checkpoint.seen_thread_ts = list(seen_thread_ts)
checkpoint.channel_completion_map[channel["id"]] = new_oldest
# bypass channels where the first set of messages seen are all bots
# bypass channels where the first set of messages seen are all
# filtered (bots + disallowed subtypes like channel_join)
# check at least MIN_BOT_MESSAGE_THRESHOLD messages are in the batch
# we shouldn't skip based on a small sampling of messages
if (
@@ -1048,7 +1080,7 @@ class SlackConnector(
and len(message_batch) > SlackConnector.BOT_CHANNEL_MIN_BATCH_SIZE
):
if (
num_bot_filtered_messages
num_filtered
> SlackConnector.BOT_CHANNEL_PERCENTAGE_THRESHOLD
* len(message_batch)
):

View File

@@ -20,7 +20,7 @@ from onyx.onyxbot.slack.models import ChannelType
from onyx.prompts.federated_search import SLACK_DATE_EXTRACTION_PROMPT
from onyx.prompts.federated_search import SLACK_QUERY_EXPANSION_PROMPT
from onyx.tracing.llm_utils import llm_generation_span
from onyx.tracing.llm_utils import record_llm_span_output
from onyx.tracing.llm_utils import record_llm_response
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -201,8 +201,8 @@ def extract_date_range_from_query(
llm=llm, flow="slack_date_extraction", input_messages=[prompt_msg]
) as span_generation:
llm_response = llm.invoke(prompt_msg)
record_llm_response(span_generation, llm_response)
response = llm_response_to_string(llm_response)
record_llm_span_output(span_generation, response, llm_response.usage)
response_clean = _parse_llm_code_block_response(response)
@@ -606,8 +606,8 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
llm=llm, flow="slack_query_expansion", input_messages=[prompt]
) as span_generation:
llm_response = llm.invoke(prompt)
record_llm_response(span_generation, llm_response)
response = llm_response_to_string(llm_response)
record_llm_span_output(span_generation, response, llm_response.usage)
response_clean = _parse_llm_code_block_response(response)

View File

@@ -6,7 +6,6 @@ from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from onyx.configs.constants import DocumentSource
from onyx.db.models import SearchSettings
@@ -97,21 +96,6 @@ class IndexFilters(BaseFilters, UserFileFilters, AssistantKnowledgeFilters):
tenant_id: str | None = None
class ChunkContext(BaseModel):
# If not specified (None), picked up from Persona settings if there is space
# if specified (even if 0), it always uses the specified number of chunks above and below
chunks_above: int | None = None
chunks_below: int | None = None
full_doc: bool = False
@field_validator("chunks_above", "chunks_below")
@classmethod
def check_non_negative(cls, value: int, field: Any) -> int:
if value is not None and value < 0:
raise ValueError(f"{field.name} must be non-negative")
return value
class BasicChunkRequest(BaseModel):
query: str
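The ChunkContext lines above encode a None-versus-explicit-zero rule: None falls back to Persona settings, while any explicit value (even 0) is used as-is and must be non-negative. Here is a self-contained Pydantic v2 sketch of that validation pattern; the class name is illustrative, not the project's.

from pydantic import BaseModel, ValidationError, field_validator

class ChunkContextSketch(BaseModel):
    # None -> fall back to Persona settings; an explicit value (even 0) is used as-is.
    chunks_above: int | None = None
    chunks_below: int | None = None
    full_doc: bool = False

    @field_validator("chunks_above", "chunks_below")
    @classmethod
    def check_non_negative(cls, value: int | None) -> int | None:
        if value is not None and value < 0:
            raise ValueError("chunk context values must be non-negative")
        return value

ChunkContextSketch(chunks_above=0, chunks_below=2)      # explicit zero is allowed
try:
    ChunkContextSketch(chunks_above=-1)
except ValidationError as e:
    print(e.errors()[0]["msg"])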

View File

@@ -1,32 +0,0 @@
# Note: this file and everything related to SavedSearchSettings is not used in live code paths (at least at the time of this comment)
# Kept around as it may be useful in the future
from typing import cast
from onyx.configs.constants import KV_SEARCH_SETTINGS
from onyx.context.search.models import SavedSearchSettings
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_kv_search_settings() -> SavedSearchSettings | None:
"""Get all user configured search settings which affect the search pipeline
Note: KV store is used in this case since there is no need to rollback the value or any need to audit past values
Note: for now we can't cache this value because if the API server is scaled, the cache could be out of sync
if the value is updated by another process/instance of the API server. If this reads from an in memory cache like
reddis then it will be ok. Until then this has some performance implications (though minor)
"""
kv_store = get_kv_store()
try:
return SavedSearchSettings(**cast(dict, kv_store.load(KV_SEARCH_SETTINGS)))
except KvKeyNotFoundError:
return None
except Exception as e:
logger.error(f"Error loading search settings: {e}")
# Wiping it so that next server startup, it can load the defaults
# or the user can set it via the API/UI
kv_store.delete(KV_SEARCH_SETTINGS)
return None
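The deleted helper above follows a load-or-wipe pattern: a missing key falls back to defaults, while a corrupt value is deleted so the next startup can repopulate it. A generic, runnable sketch of that pattern, using a plain dict as a stand-in for the real key-value store:

from typing import Any

_kv_store: dict[str, Any] = {"search_settings": {"multilingual": True}}

def load_search_settings(key: str = "search_settings") -> dict[str, Any] | None:
    try:
        value = _kv_store[key]
        if not isinstance(value, dict):
            raise ValueError(f"corrupt value for {key!r}: {value!r}")
        return value
    except KeyError:
        return None                    # never set: caller uses defaults
    except Exception:
        _kv_store.pop(key, None)       # wipe so defaults can load on the next startup
        return None

print(load_search_settings())          # {'multilingual': True}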

View File

@@ -19,7 +19,6 @@ from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.chat.models import DocumentRelevance
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
@@ -672,27 +671,6 @@ def set_as_latest_chat_message(
db_session.commit()
def update_search_docs_table_with_relevance(
db_session: Session,
reference_db_search_docs: list[DBSearchDoc],
relevance_summary: DocumentRelevance,
) -> None:
for search_doc in reference_db_search_docs:
relevance_data = relevance_summary.relevance_summaries.get(
search_doc.document_id
)
if relevance_data is not None:
db_session.execute(
update(DBSearchDoc)
.where(DBSearchDoc.id == search_doc.id)
.values(
is_relevant=relevance_data.relevant,
relevance_explanation=relevance_data.content,
)
)
db_session.commit()
def _sanitize_for_postgres(value: str) -> str:
"""Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
sanitized = value.replace("\x00", "")

View File

@@ -6,6 +6,8 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
@@ -226,6 +228,50 @@ def get_documents_by_ids(
return list(documents)
def get_documents_by_source(
db_session: Session,
source: DocumentSource,
creator_id: UUID | None = None,
) -> list[DbDocument]:
"""Get all documents associated with a specific source type.
This queries through the connector relationship to find all documents
that were indexed by connectors of the given source type.
Args:
db_session: Database session
source: The document source type to filter by
creator_id: If provided, only return documents from connectors
created by this user. Filters via ConnectorCredentialPair.
"""
stmt = (
select(DbDocument)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.join(
ConnectorCredentialPair,
and_(
DocumentByConnectorCredentialPair.connector_id
== ConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id
== ConnectorCredentialPair.credential_id,
),
)
.join(
Connector,
ConnectorCredentialPair.connector_id == Connector.id,
)
.where(Connector.source == source)
)
if creator_id is not None:
stmt = stmt.where(ConnectorCredentialPair.creator_id == creator_id)
stmt = stmt.distinct()
documents = db_session.execute(stmt).scalars().all()
return list(documents)
def _apply_last_updated_cursor_filter(
stmt: Select,
cursor_last_modified: datetime | None,
@@ -1527,3 +1573,40 @@ def get_document_kg_entities_and_relationships(
def get_num_chunks_for_document(db_session: Session, document_id: str) -> int:
stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id)
return db_session.execute(stmt).scalar_one_or_none() or 0
def update_document_metadata__no_commit(
db_session: Session,
document_id: str,
doc_metadata: dict[str, Any],
) -> None:
"""Update the doc_metadata field for a document.
Note: Does not commit. Caller is responsible for committing.
Args:
db_session: Database session
document_id: The ID of the document to update
doc_metadata: The new metadata dictionary to set
"""
stmt = (
update(DbDocument)
.where(DbDocument.id == document_id)
.values(doc_metadata=doc_metadata)
)
db_session.execute(stmt)
def delete_document_by_id__no_commit(
db_session: Session,
document_id: str,
) -> None:
"""Delete a single document and its connector credential pair relationships.
Note: Does not commit. Caller is responsible for committing.
This uses delete_documents_complete__no_commit which handles
all foreign key relationships (KG entities, relationships, chunk stats,
cc pair associations, feedback, tags).
"""
delete_documents_complete__no_commit(db_session, [document_id])

View File

@@ -60,7 +60,8 @@ class ProcessingMode(str, PyEnum):
"""Determines how documents are processed after fetching."""
REGULAR = "REGULAR" # Full pipeline: chunk → embed → Vespa
FILE_SYSTEM = "FILE_SYSTEM" # Write to file system only
FILE_SYSTEM = "FILE_SYSTEM" # Write to file system only (JSON documents)
RAW_BINARY = "RAW_BINARY" # Write raw binary to S3 (no text extraction)
class SyncType(str, PyEnum):
@@ -197,6 +198,12 @@ class ThemePreference(str, PyEnum):
SYSTEM = "system"
class DefaultAppMode(str, PyEnum):
AUTO = "AUTO"
CHAT = "CHAT"
SEARCH = "SEARCH"
class SwitchoverType(str, PyEnum):
REINDEX = "reindex"
ACTIVE_ONLY = "active_only"
@@ -289,4 +296,4 @@ class HierarchyNodeType(str, PyEnum):
class LLMModelFlowType(str, PyEnum):
CHAT = "chat"
VISION = "vision"
EMBEDDINGS = "embeddings"
CONTEXTUAL_RAG = "contextual_rag"

View File

@@ -231,10 +231,11 @@ def upsert_llm_provider(
# Set to None if the dict is empty after filtering
custom_config = custom_config or None
api_base = llm_provider_upsert_request.api_base or None
existing_llm_provider.provider = llm_provider_upsert_request.provider
# EncryptedString accepts str for writes, returns SensitiveValue for reads
existing_llm_provider.api_key = llm_provider_upsert_request.api_key # type: ignore[assignment]
existing_llm_provider.api_base = llm_provider_upsert_request.api_base
existing_llm_provider.api_base = api_base
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
existing_llm_provider.custom_config = custom_config
# TODO: Remove default model name on api change
@@ -508,6 +509,12 @@ def fetch_default_vision_model(db_session: Session) -> ModelConfiguration | None
return fetch_default_model(db_session, LLMModelFlowType.VISION)
def fetch_default_contextual_rag_model(
db_session: Session,
) -> ModelConfiguration | None:
return fetch_default_model(db_session, LLMModelFlowType.CONTEXTUAL_RAG)
def fetch_default_model(
db_session: Session,
flow_type: LLMModelFlowType,
@@ -645,6 +652,73 @@ def update_default_vision_provider(
)
def update_no_default_contextual_rag_provider(
db_session: Session,
) -> None:
db_session.execute(
update(LLMModelFlow)
.where(
LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CONTEXTUAL_RAG,
LLMModelFlow.is_default == True, # noqa: E712
)
.values(is_default=False)
)
db_session.commit()
def update_default_contextual_model(
db_session: Session,
enable_contextual_rag: bool,
contextual_rag_llm_provider: str | None,
contextual_rag_llm_name: str | None,
) -> None:
"""Sets or clears the default contextual RAG model.
Should be called whenever the PRESENT search settings change
(e.g. inline update or FUTURE → PRESENT swap).
"""
if (
not enable_contextual_rag
or not contextual_rag_llm_name
or not contextual_rag_llm_provider
):
update_no_default_contextual_rag_provider(db_session=db_session)
return
provider = fetch_existing_llm_provider(
name=contextual_rag_llm_provider, db_session=db_session
)
if not provider:
raise ValueError(f"Provider '{contextual_rag_llm_provider}' not found")
model_config = next(
(
mc
for mc in provider.model_configurations
if mc.name == contextual_rag_llm_name
),
None,
)
if not model_config:
raise ValueError(
f"Model '{contextual_rag_llm_name}' not found for provider '{contextual_rag_llm_provider}'"
)
add_model_to_flow(
db_session=db_session,
model_configuration_id=model_config.id,
flow_type=LLMModelFlowType.CONTEXTUAL_RAG,
)
_update_default_model(
db_session=db_session,
provider_id=provider.id,
model=contextual_rag_llm_name,
flow_type=LLMModelFlowType.CONTEXTUAL_RAG,
)
return
def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]:
"""Fetch all LLM providers that are in Auto mode."""
query = (
@@ -759,9 +833,18 @@ def create_new_flow_mapping__no_commit(
)
flow = result.scalar()
if not flow:
# Row already exists — fetch it
flow = db_session.scalar(
select(LLMModelFlow).where(
LLMModelFlow.model_configuration_id == model_configuration_id,
LLMModelFlow.llm_model_flow_type == flow_type,
)
)
if not flow:
raise ValueError(
f"Failed to create new flow mapping for model_configuration_id={model_configuration_id} and flow_type={flow_type}"
f"Failed to create or find flow mapping for "
f"model_configuration_id={model_configuration_id} and flow_type={flow_type}"
)
return flow
@@ -899,3 +982,18 @@ def _update_default_model(
model_config.is_visible = True
db_session.commit()
def add_model_to_flow(
db_session: Session,
model_configuration_id: int,
flow_type: LLMModelFlowType,
) -> None:
# Function does nothing on conflict
create_new_flow_mapping__no_commit(
db_session=db_session,
model_configuration_id=model_configuration_id,
flow_type=flow_type,
)
db_session.commit()
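update_default_contextual_model clears the default whenever contextual RAG is disabled or either the provider or model name is missing, and only then resolves and promotes the requested model. A condensed, stand-alone sketch of that set-or-clear decision; the helper name, return shape, and example values are illustrative.

def resolve_contextual_default(
    enable: bool, provider: str | None, model: str | None
) -> tuple[str, str] | None:
    """Return (provider, model) to promote as default, or None to clear it."""
    if not enable or not provider or not model:
        return None
    return provider, model

assert resolve_contextual_default(True, "my-provider", "my-model") == ("my-provider", "my-model")
assert resolve_contextual_default(True, "my-provider", None) is None
assert resolve_contextual_default(False, "my-provider", "my-model") is None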

View File

@@ -1,3 +1,5 @@
from uuid import UUID
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy import select
@@ -5,10 +7,8 @@ from sqlalchemy.orm import Session
from onyx.db.models import Memory
from onyx.db.models import User
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
from onyx.prompts.user_info import USER_ROLE_PROMPT
MAX_MEMORIES_PER_USER = 10
class UserInfo(BaseModel):
@@ -27,10 +27,20 @@ class UserInfo(BaseModel):
class UserMemoryContext(BaseModel):
model_config = ConfigDict(frozen=True)
user_id: UUID | None = None
user_info: UserInfo
user_preferences: str | None = None
memories: tuple[str, ...] = ()
def without_memories(self) -> "UserMemoryContext":
"""Return a copy with memories cleared but user info/preferences intact."""
return UserMemoryContext(
user_id=self.user_id,
user_info=self.user_info,
user_preferences=self.user_preferences,
memories=(),
)
def as_formatted_list(self) -> list[str]:
"""Returns combined list of user info, preferences, and memories."""
result = []
@@ -45,50 +55,8 @@ class UserMemoryContext(BaseModel):
result.extend(self.memories)
return result
def as_formatted_prompt(self) -> str:
"""Returns structured prompt sections for the system prompt."""
has_basic_info = (
self.user_info.name or self.user_info.email or self.user_info.role
)
if not has_basic_info and not self.user_preferences and not self.memories:
return ""
sections: list[str] = []
if has_basic_info:
role_line = (
USER_ROLE_PROMPT.format(user_role=self.user_info.role).strip()
if self.user_info.role
else ""
)
if role_line:
role_line = "\n" + role_line
sections.append(
BASIC_INFORMATION_PROMPT.format(
user_name=self.user_info.name or "",
user_email=self.user_info.email or "",
user_role=role_line,
)
)
if self.user_preferences:
sections.append(
USER_PREFERENCES_PROMPT.format(user_preferences=self.user_preferences)
)
if self.memories:
formatted_memories = "\n".join(f"- {memory}" for memory in self.memories)
sections.append(
USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
)
return "".join(sections)
def get_memories(user: User, db_session: Session) -> UserMemoryContext:
if not user.use_memories:
return UserMemoryContext(user_info=UserInfo())
user_info = UserInfo(
name=user.personal_name,
role=user.personal_role,
@@ -105,7 +73,57 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
memories = tuple(memory.memory_text for memory in memory_rows if memory.memory_text)
return UserMemoryContext(
user_id=user.id,
user_info=user_info,
user_preferences=user_preferences,
memories=memories,
)
def add_memory(
user_id: UUID,
memory_text: str,
db_session: Session,
) -> Memory:
"""Insert a new Memory row for the given user.
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
one (lowest id) is deleted before inserting the new one.
"""
existing = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
if len(existing) >= MAX_MEMORIES_PER_USER:
db_session.delete(existing[0])
memory = Memory(
user_id=user_id,
memory_text=memory_text,
)
db_session.add(memory)
db_session.commit()
return memory
def update_memory_at_index(
user_id: UUID,
index: int,
new_text: str,
db_session: Session,
) -> Memory | None:
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
Returns the updated Memory row, or None if the index is out of range.
"""
memory_rows = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
if index < 0 or index >= len(memory_rows):
return None
memory = memory_rows[index]
memory.memory_text = new_text
db_session.commit()
return memory
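add_memory caps each user at MAX_MEMORIES_PER_USER rows by evicting the oldest (lowest id) before inserting. A self-contained, list-based sketch of that behaviour; the cap of 10 mirrors the constant above, and the list is only a stand-in for the Memory table.

MAX_MEMORIES = 10  # mirrors MAX_MEMORIES_PER_USER above

def add_memory_capped(memories: list[str], new_text: str) -> list[str]:
    if len(memories) >= MAX_MEMORIES:
        memories = memories[1:]          # evict the oldest entry (lowest id)
    return memories + [new_text]

mems = [f"memory {i}" for i in range(10)]
mems = add_memory_capped(mems, "newest memory")
assert len(mems) == 10 and mems[0] == "memory 1" and mems[-1] == "newest memory"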

View File

@@ -75,6 +75,7 @@ from onyx.db.enums import (
MCPServerStatus,
LLMModelFlowType,
ThemePreference,
DefaultAppMode,
SwitchoverType,
)
from onyx.configs.constants import NotificationType
@@ -247,10 +248,18 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
default=None,
)
chat_background: Mapped[str | None] = mapped_column(String, nullable=True)
default_app_mode: Mapped[DefaultAppMode] = mapped_column(
Enum(DefaultAppMode, native_enum=False),
nullable=False,
default=DefaultAppMode.CHAT,
)
# personalization fields are exposed via the chat user settings "Personalization" tab
personal_name: Mapped[str | None] = mapped_column(String, nullable=True)
personal_role: Mapped[str | None] = mapped_column(String, nullable=True)
use_memories: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
enable_memory_tool: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=True
)
user_preferences: Mapped[str | None] = mapped_column(Text, nullable=True)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
@@ -312,6 +321,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
back_populates="user",
cascade="all, delete-orphan",
lazy="selectin",
order_by="desc(Memory.id)",
)
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
"OAuthUserToken",
@@ -1027,6 +1037,31 @@ class OpenSearchTenantMigrationRecord(Base):
onupdate=func.now(),
nullable=False,
)
# Opaque continuation token from Vespa's Visit API.
# NULL means "not started" or "visit completed".
vespa_visit_continuation_token: Mapped[str | None] = mapped_column(
Text, nullable=True
)
total_chunks_migrated: Mapped[int] = mapped_column(
Integer, default=0, nullable=False
)
total_chunks_errored: Mapped[int] = mapped_column(
Integer, default=0, nullable=False
)
total_chunks_in_vespa: Mapped[int] = mapped_column(
Integer, default=0, nullable=False
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
migration_completed_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
enable_opensearch_retrieval: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
class KGEntityType(Base):
@@ -4842,3 +4877,90 @@ class BuildMessage(Base):
"ix_build_message_session_turn", "session_id", "turn_index", "created_at"
),
)
"""
SCIM 2.0 Provisioning Models (Enterprise Edition only)
Used for automated user/group provisioning from identity providers (Okta, Azure AD).
"""
class ScimToken(Base):
"""Bearer tokens for IdP SCIM authentication."""
__tablename__ = "scim_token"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, nullable=False)
hashed_token: Mapped[str] = mapped_column(
String(64), unique=True, nullable=False
) # SHA256 = 64 hex chars
token_display: Mapped[str] = mapped_column(
String, nullable=False
) # Last 4 chars for UI identification
created_by_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=False
)
is_active: Mapped[bool] = mapped_column(
Boolean, server_default=text("true"), nullable=False
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
last_used_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
created_by: Mapped[User] = relationship("User", foreign_keys=[created_by_id])
class ScimUserMapping(Base):
"""Maps SCIM externalId from the IdP to an Onyx User."""
__tablename__ = "scim_user_mapping"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
user: Mapped[User] = relationship("User", foreign_keys=[user_id])
class ScimGroupMapping(Base):
"""Maps SCIM externalId from the IdP to an Onyx UserGroup."""
__tablename__ = "scim_group_mapping"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id", ondelete="CASCADE"), unique=True, nullable=False
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
user_group: Mapped[UserGroup] = relationship(
"UserGroup", foreign_keys=[user_group_id]
)
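Per the ScimToken columns above, only a SHA-256 hash (64 hex chars) and the last four characters are persisted. The sketch below shows how issuing and verifying such a bearer token could work; it is an assumption about the flow, not the project's actual SCIM token code.

import hashlib
import secrets

def issue_scim_token() -> tuple[str, str, str]:
    raw = secrets.token_urlsafe(32)                    # shown to the admin exactly once
    hashed = hashlib.sha256(raw.encode()).hexdigest()  # what hashed_token would store (64 hex chars)
    display = raw[-4:]                                 # what token_display would store
    return raw, hashed, display

raw, hashed, display = issue_scim_token()
assert len(hashed) == 64
# Verifying an incoming bearer token: hash it and compare against the stored hash.
assert hashlib.sha256(raw.encode()).hexdigest() == hashed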

View File

@@ -4,6 +4,9 @@ This module provides functions to track the progress of migrating documents
from Vespa to OpenSearch.
"""
from datetime import datetime
from datetime import timezone
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.dialects.postgresql import insert
@@ -12,10 +15,14 @@ from sqlalchemy.orm import Session
from onyx.background.celery.tasks.opensearch_migration.constants import (
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
)
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
from onyx.db.enums import OpenSearchDocumentMigrationStatus
from onyx.db.models import Document
from onyx.db.models import OpenSearchDocumentMigrationRecord
from onyx.db.models import OpenSearchTenantMigrationRecord
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -176,7 +183,7 @@ def try_insert_opensearch_tenant_migration_record_with_commit(
) -> None:
"""Tries to insert the singleton row on OpenSearchTenantMigrationRecord.
If the row already exists, does nothing.
Does nothing if the row already exists.
"""
stmt = insert(OpenSearchTenantMigrationRecord).on_conflict_do_nothing(
index_elements=[text("(true)")]
@@ -190,25 +197,14 @@ def increment_num_times_observed_no_additional_docs_to_migrate_with_commit(
) -> None:
"""Increments the number of times observed no additional docs to migrate.
Tries to insert the singleton row on OpenSearchTenantMigrationRecord with a
starting count, and if the row already exists, increments the count.
Requires the OpenSearchTenantMigrationRecord to exist.
Used to track when to stop the migration task.
"""
stmt = (
insert(OpenSearchTenantMigrationRecord)
.values(num_times_observed_no_additional_docs_to_migrate=1)
.on_conflict_do_update(
index_elements=[text("(true)")],
set_={
"num_times_observed_no_additional_docs_to_migrate": (
OpenSearchTenantMigrationRecord.num_times_observed_no_additional_docs_to_migrate
+ 1
)
},
)
)
db_session.execute(stmt)
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
record.num_times_observed_no_additional_docs_to_migrate += 1
db_session.commit()
@@ -219,25 +215,14 @@ def increment_num_times_observed_no_additional_docs_to_populate_migration_table_
Increments the number of times observed no additional docs to populate the
migration table.
Tries to insert the singleton row on OpenSearchTenantMigrationRecord with a
starting count, and if the row already exists, increments the count.
Requires the OpenSearchTenantMigrationRecord to exist.
Used to track when to stop the migration check task.
"""
stmt = (
insert(OpenSearchTenantMigrationRecord)
.values(num_times_observed_no_additional_docs_to_populate_migration_table=1)
.on_conflict_do_update(
index_elements=[text("(true)")],
set_={
"num_times_observed_no_additional_docs_to_populate_migration_table": (
OpenSearchTenantMigrationRecord.num_times_observed_no_additional_docs_to_populate_migration_table
+ 1
)
},
)
)
db_session.execute(stmt)
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
record.num_times_observed_no_additional_docs_to_populate_migration_table += 1
db_session.commit()
@@ -254,3 +239,167 @@ def should_document_migration_be_permanently_failed(
>= TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE
)
)
def get_vespa_visit_state(
db_session: Session,
) -> tuple[str | None, int]:
"""Gets the current Vespa migration state from the tenant migration record.
Requires the OpenSearchTenantMigrationRecord to exist.
Returns:
Tuple of (continuation_token, total_chunks_migrated). continuation_token
is None if not started or completed.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
return (
record.vespa_visit_continuation_token,
record.total_chunks_migrated,
)
def update_vespa_visit_progress_with_commit(
db_session: Session,
continuation_token: str | None,
chunks_processed: int,
chunks_errored: int,
) -> None:
"""Updates the Vespa migration progress and commits.
Requires the OpenSearchTenantMigrationRecord to exist.
Args:
db_session: SQLAlchemy session.
continuation_token: The new continuation token. None means the visit
is complete.
chunks_processed: Number of chunks processed in this batch (added to
the running total).
chunks_errored: Number of chunks errored in this batch (added to the
running errored total).
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
record.vespa_visit_continuation_token = continuation_token
record.total_chunks_migrated += chunks_processed
record.total_chunks_errored += chunks_errored
db_session.commit()
def mark_migration_completed_time_if_not_set_with_commit(
db_session: Session,
) -> None:
"""Marks the migration completed time if not set.
Requires the OpenSearchTenantMigrationRecord to exist.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
if record.migration_completed_at is not None:
return
record.migration_completed_at = datetime.now(timezone.utc)
db_session.commit()
def build_sanitized_to_original_doc_id_mapping(
db_session: Session,
) -> dict[str, str]:
"""Pre-computes a mapping of sanitized -> original document IDs.
Only includes documents whose ID contains single quotes (the only character
that gets sanitized by replace_invalid_doc_id_characters). For all other
documents, sanitized == original and no mapping entry is needed.
Scans over all documents.
Checks whether the sanitized ID already exists as a genuine, separate document
in the Document table. If so, raises, since there is no way to resolve the
conflict during the migration; the user will need to reindex.
Args:
db_session: SQLAlchemy session.
Returns:
Dict mapping sanitized_id -> original_id, only for documents where
the IDs differ. Empty dict means no documents have single quotes
in their IDs.
"""
# Find all documents with single quotes in their ID.
stmt = select(Document.id).where(Document.id.contains("'"))
ids_with_quotes = list(db_session.scalars(stmt).all())
result: dict[str, str] = {}
for original_id in ids_with_quotes:
sanitized_id = replace_invalid_doc_id_characters(original_id)
if sanitized_id != original_id:
result[sanitized_id] = original_id
# See if there are any documents whose ID is a sanitized ID of another
# document. If there is even one match, we cannot proceed.
stmt = select(Document.id).where(Document.id.in_(result.keys()))
ids_with_matches = list(db_session.scalars(stmt).all())
if ids_with_matches:
raise RuntimeError(
f"Documents with IDs {ids_with_matches} have sanitized IDs that match other documents. "
"This is not supported and the user will need to reindex."
)
return result
def get_opensearch_migration_state(
db_session: Session,
) -> tuple[int, datetime | None, datetime | None]:
"""Returns the state of the Vespa to OpenSearch migration.
If the tenant migration record is not found, returns defaults of 0, None,
None.
Args:
db_session: SQLAlchemy session.
Returns:
Tuple of (total_chunks_migrated, created_at, migration_completed_at).
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
return 0, None, None
return (
record.total_chunks_migrated,
record.created_at,
record.migration_completed_at,
)
def get_opensearch_retrieval_state(
db_session: Session,
) -> bool:
"""Returns the state of the OpenSearch retrieval.
If the tenant migration record is not found, defaults to
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
return ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
return record.enable_opensearch_retrieval
def set_enable_opensearch_retrieval_with_commit(
db_session: Session,
enable: bool,
) -> None:
"""Sets the enable_opensearch_retrieval flag on the singleton record.
Creates the record if it doesn't exist yet.
"""
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
record.enable_opensearch_retrieval = enable
db_session.commit()
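build_sanitized_to_original_doc_id_mapping only maps IDs that contain single quotes and aborts if a sanitized ID collides with a real document. A stand-alone sketch of that logic, assuming (per the docstring) that sanitization simply strips single quotes; the real replace_invalid_doc_id_characters may behave differently.

def sanitize(doc_id: str) -> str:
    # Assumption: sanitization only strips single quotes, per the docstring above.
    return doc_id.replace("'", "")

def build_mapping(all_ids: list[str]) -> dict[str, str]:
    mapping = {sanitize(i): i for i in all_ids if "'" in i}
    conflicts = [s for s in mapping if s in set(all_ids)]
    if conflicts:
        raise RuntimeError(
            f"Sanitized IDs collide with existing documents: {conflicts}; reindexing is required."
        )
    return mapping

assert build_mapping(["doc'1", "doc2"]) == {"doc1": "doc'1"}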

View File

@@ -15,6 +15,8 @@ from onyx.db.index_attempt import (
count_unique_active_cc_pairs_with_successful_index_attempts,
)
from onyx.db.index_attempt import count_unique_cc_pairs_with_successful_index_attempts
from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import update_no_default_contextual_rag_provider
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
@@ -80,6 +82,24 @@ def _perform_index_swap(
db_session=db_session,
)
# Update the default contextual model to match the newly promoted settings
try:
update_default_contextual_model(
db_session=db_session,
enable_contextual_rag=new_search_settings.enable_contextual_rag,
contextual_rag_llm_provider=new_search_settings.contextual_rag_llm_provider,
contextual_rag_llm_name=new_search_settings.contextual_rag_llm_name,
)
except ValueError as e:
logger.error(f"Model not found, defaulting to no contextual model: {e}")
update_no_default_contextual_rag_provider(
db_session=db_session,
)
new_search_settings.enable_contextual_rag = False
new_search_settings.contextual_rag_llm_provider = None
new_search_settings.contextual_rag_llm_name = None
db_session.commit()
# This flow is for checking and possibly creating an index so we get all
# indices.
document_indices = get_all_document_indices(new_search_settings, None, None)

View File

@@ -55,6 +55,8 @@ def get_tools(
# To avoid showing rows that have JSON literal `null` stored in the column to the user.
# Tools from MCP servers have no OpenAPI schema, but the column holds JSON null for them, so we need to exclude them.
func.jsonb_typeof(Tool.openapi_schema) == "object",
# Exclude built-in tools that happen to have an openapi_schema
Tool.in_code_tool_id.is_(None),
)
return list(db_session.scalars(query).all())

View File

@@ -9,11 +9,13 @@ from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.db.enums import DefaultAppMode
from onyx.db.enums import ThemePreference
from onyx.db.models import AccessToken
from onyx.db.models import Assistant__UserSpecificConfig
from onyx.db.models import Memory
from onyx.db.models import User
from onyx.server.manage.models import MemoryItem
from onyx.server.manage.models import UserSpecificAssistantPreference
from onyx.utils.logger import setup_logger
@@ -153,13 +155,28 @@ def update_user_chat_background(
db_session.commit()
def update_user_default_app_mode(
user_id: UUID,
default_app_mode: DefaultAppMode,
db_session: Session,
) -> None:
"""Update user's default app mode setting."""
db_session.execute(
update(User)
.where(User.id == user_id) # type: ignore
.values(default_app_mode=default_app_mode)
)
db_session.commit()
def update_user_personalization(
user_id: UUID,
*,
personal_name: str | None,
personal_role: str | None,
use_memories: bool,
memories: list[str],
enable_memory_tool: bool,
memories: list[MemoryItem],
user_preferences: str | None,
db_session: Session,
) -> None:
@@ -170,15 +187,39 @@ def update_user_personalization(
personal_name=personal_name,
personal_role=personal_role,
use_memories=use_memories,
enable_memory_tool=enable_memory_tool,
user_preferences=user_preferences,
)
)
db_session.execute(delete(Memory).where(Memory.user_id == user_id))
# ID-based upsert: use real DB IDs from the frontend to match memories.
incoming_ids = {m.id for m in memories if m.id is not None}
if memories:
# Delete existing rows not in the incoming set (scoped to user_id)
existing_memories = list(
db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all()
)
existing_ids = {mem.id for mem in existing_memories}
ids_to_delete = existing_ids - incoming_ids
if ids_to_delete:
db_session.execute(
delete(Memory).where(
Memory.id.in_(ids_to_delete),
Memory.user_id == user_id,
)
)
# Update existing rows whose IDs match
existing_by_id = {mem.id: mem for mem in existing_memories}
for item in memories:
if item.id is not None and item.id in existing_by_id:
existing_by_id[item.id].memory_text = item.content
# Create new rows for items without an ID
new_items = [m for m in memories if m.id is None]
if new_items:
db_session.add_all(
[Memory(user_id=user_id, memory_text=memory) for memory in memories]
[Memory(user_id=user_id, memory_text=item.content) for item in new_items]
)
db_session.commit()
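The ID-based upsert above deletes rows missing from the incoming payload, updates rows whose IDs match, and creates rows for items without an ID. Below is a self-contained sketch of that reconciliation; real rows get their IDs from the database, so the next_id bookkeeping here exists only for the illustration.

from dataclasses import dataclass

@dataclass
class MemRow:
    id: int
    text: str

def reconcile(existing: list[MemRow], incoming: list[tuple[int | None, str]]) -> list[MemRow]:
    incoming_ids = {i for i, _ in incoming if i is not None}
    kept = [r for r in existing if r.id in incoming_ids]          # delete the rest
    by_id = {r.id: r for r in kept}
    next_id = max((r.id for r in existing), default=0) + 1
    for item_id, text in incoming:
        if item_id is not None and item_id in by_id:
            by_id[item_id].text = text                             # update in place
        elif item_id is None:
            kept.append(MemRow(id=next_id, text=text))             # create a new row
            next_id += 1
    return kept

rows = reconcile([MemRow(1, "a"), MemRow(2, "b")], [(1, "a2"), (None, "c")])
assert [(r.id, r.text) for r in rows] == [(1, "a2"), (3, "c")]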

View File

@@ -17,6 +17,7 @@ from onyx.chat.llm_loop import construct_message_history
from onyx.chat.llm_step import run_llm_step
from onyx.chat.llm_step import run_llm_step_pkt_generator
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import LlmStepResult
from onyx.chat.models import ToolCallSimple
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
@@ -109,6 +110,7 @@ def generate_final_report(
user_identity: LLMUserIdentity | None,
saved_reasoning: str | None = None,
pre_answer_processing_time: float | None = None,
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
) -> bool:
"""Generate the final research report.
@@ -130,7 +132,7 @@ def generate_final_report(
reminder_message = ChatMessageSimple(
message=final_reminder,
token_count=token_counter(final_reminder),
message_type=MessageType.USER,
message_type=MessageType.USER_REMINDER,
)
final_report_history = construct_message_history(
system_prompt=system_prompt,
@@ -139,6 +141,7 @@ def generate_final_report(
reminder_message=reminder_message,
project_files=None,
available_tokens=llm.config.max_input_tokens,
all_injected_file_metadata=all_injected_file_metadata,
)
citation_processor = DynamicCitationProcessor()
@@ -194,6 +197,7 @@ def run_deep_research_llm_loop(
skip_clarification: bool = False,
user_identity: LLMUserIdentity | None = None,
chat_session_id: str | None = None,
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
) -> None:
with trace(
"run_deep_research_llm_loop",
@@ -256,6 +260,7 @@ def run_deep_research_llm_loop(
project_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
all_injected_file_metadata=all_injected_file_metadata,
)
# Calculate tool processing duration for clarification step
@@ -304,6 +309,8 @@ def run_deep_research_llm_loop(
token_count=300,
message_type=MessageType.SYSTEM,
)
# Note: it is fine to use a USER message type here as it can just be interpreted as a
# user's message directly to the LLM.
reminder_message = ChatMessageSimple(
message=RESEARCH_PLAN_REMINDER,
token_count=100,
@@ -317,6 +324,7 @@ def run_deep_research_llm_loop(
project_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
all_injected_file_metadata=all_injected_file_metadata,
)
research_plan_generator = run_llm_step_pkt_generator(
@@ -442,6 +450,7 @@ def run_deep_research_llm_loop(
citation_mapping=citation_mapping,
user_identity=user_identity,
pre_answer_processing_time=elapsed_seconds,
all_injected_file_metadata=all_injected_file_metadata,
)
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
break
@@ -450,11 +459,9 @@ def run_deep_research_llm_loop(
first_cycle_reminder_message = ChatMessageSimple(
message=FIRST_CYCLE_REMINDER,
token_count=FIRST_CYCLE_REMINDER_TOKENS,
message_type=MessageType.USER,
message_type=MessageType.USER_REMINDER,
)
first_cycle_tokens = FIRST_CYCLE_REMINDER_TOKENS
else:
first_cycle_tokens = 0
first_cycle_reminder_message = None
research_agent_calls: list[ToolCallKickoff] = []
@@ -477,15 +484,13 @@ def run_deep_research_llm_loop(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
reminder_message=first_cycle_reminder_message,
project_files=None,
available_tokens=available_tokens - first_cycle_tokens,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
all_injected_file_metadata=all_injected_file_metadata,
)
if first_cycle_reminder_message is not None:
truncated_message_history.append(first_cycle_reminder_message)
# Use think tool processor for non-reasoning models to convert
# think_tool calls to reasoning content
custom_processor = (
@@ -549,6 +554,7 @@ def run_deep_research_llm_loop(
user_identity=user_identity,
pre_answer_processing_time=time.monotonic()
- processing_start_time,
all_injected_file_metadata=all_injected_file_metadata,
)
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
break
@@ -572,6 +578,7 @@ def run_deep_research_llm_loop(
saved_reasoning=most_recent_reasoning,
pre_answer_processing_time=time.monotonic()
- processing_start_time,
all_injected_file_metadata=all_injected_file_metadata,
)
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
break
@@ -644,6 +651,7 @@ def run_deep_research_llm_loop(
user_identity=user_identity,
pre_answer_processing_time=time.monotonic()
- processing_start_time,
all_injected_file_metadata=all_injected_file_metadata,
)
final_turn_index = report_turn_index + (
1 if report_reasoned else 0

View File

@@ -0,0 +1,151 @@
"""A DocumentIndex implementation that raises on every operation.
Used as a safety net when DISABLE_VECTOR_DB is True. Any code path that
accidentally reaches the vector DB layer will fail loudly instead of timing
out against a nonexistent Vespa/OpenSearch instance.
"""
from typing import Any
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import QueryExpansionType
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
from onyx.document_index.interfaces import IndexBatchParams
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.indexing.models import DocMetadataAwareIndexChunk
from shared_configs.model_server_models import Embedding
VECTOR_DB_DISABLED_ERROR = (
"Vector DB is disabled (DISABLE_VECTOR_DB=true). "
"This operation requires a vector database."
)
class DisabledDocumentIndex(DocumentIndex):
"""A DocumentIndex where every method raises RuntimeError.
Returned by the factory when DISABLE_VECTOR_DB is True so that any
accidental vector-DB call surfaces immediately.
"""
def __init__(
self,
index_name: str = "disabled",
secondary_index_name: str | None = None,
*args: Any, # noqa: ARG002
**kwargs: Any, # noqa: ARG002
) -> None:
self.index_name = index_name
self.secondary_index_name = secondary_index_name
# ------------------------------------------------------------------
# Verifiable
# ------------------------------------------------------------------
def ensure_indices_exist(
self,
primary_embedding_dim: int, # noqa: ARG002
primary_embedding_precision: EmbeddingPrecision, # noqa: ARG002
secondary_index_embedding_dim: int | None, # noqa: ARG002
secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002
) -> None:
# No-op: there are no indices to create when the vector DB is disabled.
pass
@staticmethod
def register_multitenant_indices(
indices: list[str], # noqa: ARG002, ARG004
embedding_dims: list[int], # noqa: ARG002, ARG004
embedding_precisions: list[EmbeddingPrecision], # noqa: ARG002, ARG004
) -> None:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
# ------------------------------------------------------------------
# Indexable
# ------------------------------------------------------------------
def index(
self,
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
index_batch_params: IndexBatchParams, # noqa: ARG002
) -> set[DocumentInsertionRecord]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
# ------------------------------------------------------------------
# Deletable
# ------------------------------------------------------------------
def delete_single(
self,
doc_id: str, # noqa: ARG002
*,
tenant_id: str, # noqa: ARG002
chunk_count: int | None, # noqa: ARG002
) -> int:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
# ------------------------------------------------------------------
# Updatable
# ------------------------------------------------------------------
def update_single(
self,
doc_id: str, # noqa: ARG002
*,
tenant_id: str, # noqa: ARG002
chunk_count: int | None, # noqa: ARG002
fields: VespaDocumentFields | None, # noqa: ARG002
user_fields: VespaDocumentUserFields | None, # noqa: ARG002
) -> None:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
# ------------------------------------------------------------------
# IdRetrievalCapable
# ------------------------------------------------------------------
def id_based_retrieval(
self,
chunk_requests: list[VespaChunkRequest], # noqa: ARG002
filters: IndexFilters, # noqa: ARG002
batch_retrieval: bool = False, # noqa: ARG002
) -> list[InferenceChunk]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
# ------------------------------------------------------------------
# HybridCapable
# ------------------------------------------------------------------
def hybrid_retrieval(
self,
query: str, # noqa: ARG002
query_embedding: Embedding, # noqa: ARG002
final_keywords: list[str] | None, # noqa: ARG002
filters: IndexFilters, # noqa: ARG002
hybrid_alpha: float, # noqa: ARG002
time_decay_multiplier: float, # noqa: ARG002
num_to_retrieve: int, # noqa: ARG002
ranking_profile_type: QueryExpansionType, # noqa: ARG002
title_content_ratio: float | None = None, # noqa: ARG002
) -> list[InferenceChunk]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
# ------------------------------------------------------------------
# AdminCapable
# ------------------------------------------------------------------
def admin_retrieval(
self,
query: str, # noqa: ARG002
query_embedding: Embedding, # noqa: ARG002
filters: IndexFilters, # noqa: ARG002
num_to_retrieve: int = 10, # noqa: ARG002
) -> list[InferenceChunk]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
# ------------------------------------------------------------------
# RandomCapable
# ------------------------------------------------------------------
def random_retrieval(
self,
filters: IndexFilters, # noqa: ARG002
num_to_retrieve: int = 10, # noqa: ARG002
) -> list[InferenceChunk]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
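The class above enumerates every interface method and raises from each, which keeps the explicit interface and static type checking intact. For comparison, a generic __getattr__-based stub achieves the same fail-loud behaviour with less code but loses those guarantees; the sketch below is illustrative only.

class DisabledBackend:
    """Generic fail-loud stub: any method call raises immediately."""

    def __init__(self, reason: str) -> None:
        self._reason = reason

    def __getattr__(self, name: str):
        def _raise(*args: object, **kwargs: object) -> None:
            raise RuntimeError(f"{self._reason} (attempted call: {name})")
        return _raise

backend = DisabledBackend("Vector DB is disabled (DISABLE_VECTOR_DB=true)")
try:
    backend.hybrid_retrieval(query="test")
except RuntimeError as exc:
    print(exc)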

View File

@@ -1,8 +1,11 @@
import httpx
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
from onyx.db.models import SearchSettings
from onyx.db.opensearch_migration import get_opensearch_retrieval_state
from onyx.document_index.disabled import DisabledDocumentIndex
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchOldDocumentIndex,
@@ -14,6 +17,7 @@ from shared_configs.configs import MULTI_TENANT
def get_default_document_index(
search_settings: SearchSettings,
secondary_search_settings: SearchSettings | None,
db_session: Session,
httpx_client: httpx.Client | None = None,
) -> DocumentIndex:
"""Gets the default document index from env vars.
@@ -27,13 +31,24 @@ def get_default_document_index(
index is for when both the currently used index and the upcoming index both
need to be updated, updates are applied to both indices.
"""
if DISABLE_VECTOR_DB:
return DisabledDocumentIndex(
index_name=search_settings.index_name,
secondary_index_name=(
secondary_search_settings.index_name
if secondary_search_settings
else None
),
)
secondary_index_name: str | None = None
secondary_large_chunks_enabled: bool | None = None
if secondary_search_settings:
secondary_index_name = secondary_search_settings.index_name
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
if ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX:
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
if opensearch_retrieval_enabled:
return OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
secondary_index_name=secondary_index_name,
@@ -69,7 +84,24 @@ def get_all_document_indices(
Large chunks and secondary indices are not currently supported so we
hardcode appropriate values.
NOTE: Make sure the Vespa index object is returned first. In the rare event
that there is some conflict between indexing and the migration task, it is
assumed that the state of Vespa is more up-to-date than the state of
OpenSearch.
"""
if DISABLE_VECTOR_DB:
return [
DisabledDocumentIndex(
index_name=search_settings.index_name,
secondary_index_name=(
secondary_search_settings.index_name
if secondary_search_settings
else None
),
)
]
vespa_document_index = VespaIndex(
index_name=search_settings.index_name,
secondary_index_name=(

View File

@@ -9,6 +9,7 @@ from opensearchpy import TransportError
from opensearchpy.helpers import bulk
from pydantic import BaseModel
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_ADMIN_PASSWORD
from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME
from onyx.configs.app_configs import OPENSEARCH_HOST
@@ -21,6 +22,9 @@ from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
CLIENT_THRESHOLD_TO_LOG_SLOW_SEARCH_MS = 2000
logger = setup_logger(__name__)
# Set the logging level to WARNING to ignore INFO and DEBUG logs from
# opensearch. By default it emits INFO-level logs for every request.
@@ -52,6 +56,30 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
match_highlights: dict[str, list[str]] = {}
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
"""Recursively replaces vectors in the body with their length.
TODO(andrei): Do better.
Args:
body: The body in which to replace vectors.
Returns:
A copy of body with each vector replaced by its length.
"""
new_body: dict[str, Any] = {}
for k, v in body.items():
if k == "vector":
new_body[k] = len(v)
elif isinstance(v, dict):
new_body[k] = get_new_body_without_vectors(v)
elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
new_body[k] = [get_new_body_without_vectors(item) for item in v]
else:
new_body[k] = v
return new_body
class OpenSearchClient:
"""Client for interacting with OpenSearch.
@@ -74,10 +102,11 @@ class OpenSearchClient:
use_ssl: bool = True,
verify_certs: bool = False,
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
):
self._index_name = index_name
logger.debug(
f"Creating OpenSearch client for index {index_name} with host {host} and port {port}."
f"Creating OpenSearch client for index {index_name} with host {host} and port {port} and timeout {timeout} seconds."
)
self._client = OpenSearch(
hosts=[{"host": host, "port": port}],
@@ -85,6 +114,13 @@ class OpenSearchClient:
use_ssl=use_ssl,
verify_certs=verify_certs,
ssl_show_warn=ssl_show_warn,
# NOTE: This timeout applies to all requests the client makes,
# including bulk indexing. When exceeded, the client will raise a
# ConnectionTimeout and return no useful results. The OpenSearch
# server will log that the client cancelled the request. To get
# partial results from OpenSearch, pass in a timeout parameter to
# your request body that is less than this value.
timeout=timeout,
)
logger.debug(
f"OpenSearch client created successfully for index {self._index_name}."
@@ -635,14 +671,31 @@ class OpenSearchClient:
f"Trying to search index {self._index_name} with search pipeline {search_pipeline_id}."
)
result: dict[str, Any]
params = {"phase_took": "true"}
if search_pipeline_id:
result = self._client.search(
index=self._index_name, search_pipeline=search_pipeline_id, body=body
index=self._index_name,
search_pipeline=search_pipeline_id,
body=body,
params=params,
)
else:
result = self._client.search(index=self._index_name, body=body)
result = self._client.search(
index=self._index_name, body=body, params=params
)
hits = self._get_hits_from_search_result(result)
hits, time_took, timed_out, phase_took, profile = (
self._get_hits_and_profile_from_search_result(result)
)
self._log_search_result_perf(
time_took=time_took,
timed_out=timed_out,
phase_took=phase_took,
profile=profile,
body=body,
search_pipeline_id=search_pipeline_id,
raise_on_timeout=True,
)
search_hits: list[SearchHit[DocumentChunk]] = []
for hit in hits:
@@ -698,9 +751,22 @@ class OpenSearchClient:
'"_source": False. This query will therefore be inefficient.'
)
result: dict[str, Any] = self._client.search(index=self._index_name, body=body)
params = {"phase_took": "true"}
result: dict[str, Any] = self._client.search(
index=self._index_name, body=body, params=params
)
hits = self._get_hits_from_search_result(result)
hits, time_took, timed_out, phase_took, profile = (
self._get_hits_and_profile_from_search_result(result)
)
self._log_search_result_perf(
time_took=time_took,
timed_out=timed_out,
phase_took=phase_took,
profile=profile,
body=body,
raise_on_timeout=True,
)
# TODO(andrei): Implement scroll/point in time for results so that we
# can return arbitrarily-many IDs.
@@ -737,34 +803,24 @@ class OpenSearchClient:
self._client.indices.refresh(index=self._index_name)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def set_cluster_auto_create_index_setting(self, enabled: bool) -> bool:
"""Sets the cluster auto create index setting.
By default, when you index a document to a non-existent index,
OpenSearch will automatically create the index. This behavior is
undesirable so this function exposes the ability to disable it.
See
https://docs.opensearch.org/latest/install-and-configure/configuring-opensearch/index/#updating-cluster-settings-using-the-api
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
"""Puts cluster settings.
Args:
enabled: Whether to enable the auto create index setting.
settings: The settings to put.
Raises:
Exception: There was an error putting the cluster settings.
Returns:
True if the setting was updated successfully, False otherwise. Does
not raise.
True if the settings were put successfully, False otherwise.
"""
try:
body = {"persistent": {"action.auto_create_index": enabled}}
response = self._client.cluster.put_settings(body=body)
if response.get("acknowledged", False):
logger.info(f"Successfully set action.auto_create_index to {enabled}.")
return True
else:
logger.error(f"Failed to update setting: {response}.")
return False
except Exception:
logger.exception("Error setting auto_create_index.")
response = self._client.cluster.put_settings(body=settings)
if response.get("acknowledged", False):
logger.info("Successfully put cluster settings.")
return True
else:
logger.error(f"Failed to put cluster settings: {response}.")
return False
@log_function_time(print_only=True, debug_only=True)
@@ -788,28 +844,78 @@ class OpenSearchClient:
"""
self._client.close()
def _get_hits_from_search_result(self, result: dict[str, Any]) -> list[Any]:
"""Extracts the hits from a search result.
def _get_hits_and_profile_from_search_result(
self, result: dict[str, Any]
) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]:
"""Extracts the hits and profiling information from a search result.
Args:
result: The search result to extract the hits from.
Raises:
Exception: There was an error extracting the hits from the search
result. This includes the case where the search timed out.
result.
Returns:
The hits from the search result.
A tuple containing the hits from the search result, the time taken
to execute the search in milliseconds, whether the search timed
out, the time taken to execute each phase of the search, and the
profile.
"""
if result.get("timed_out", False):
raise RuntimeError(f"Search timed out for index {self._index_name}.")
time_took: int | None = result.get("took")
timed_out: bool | None = result.get("timed_out")
phase_took: dict[str, Any] = result.get("phase_took", {})
profile: dict[str, Any] = result.get("profile", {})
hits_first_layer: dict[str, Any] = result.get("hits", {})
if not hits_first_layer:
raise RuntimeError(
f"Hits field missing from response when trying to search index {self._index_name}."
)
hits_second_layer: list[Any] = hits_first_layer.get("hits", [])
return hits_second_layer
return hits_second_layer, time_took, timed_out, phase_took, profile
def _log_search_result_perf(
self,
time_took: int | None,
timed_out: bool | None,
phase_took: dict[str, Any],
profile: dict[str, Any],
body: dict[str, Any],
search_pipeline_id: str | None = None,
raise_on_timeout: bool = False,
) -> None:
"""Logs the performance of a search result.
Args:
time_took: The time taken to execute the search in milliseconds.
timed_out: Whether the search timed out.
phase_took: The time taken to execute each phase of the search.
profile: The profile for the search.
body: The body of the search request for logging.
search_pipeline_id: The ID of the search pipeline used for the
search, if any, for logging. Defaults to None.
raise_on_timeout: Whether to raise an exception if the search timed
out. Note that the result may still contain useful partial
results. Defaults to False.
Raises:
Exception: If raise_on_timeout is True and the search timed out.
"""
if time_took and time_took > CLIENT_THRESHOLD_TO_LOG_SLOW_SEARCH_MS:
logger.warning(
f"OpenSearch client warning: Search for index {self._index_name} took {time_took} milliseconds.\n"
f"Body: {get_new_body_without_vectors(body)}\n"
f"Search pipeline ID: {search_pipeline_id}\n"
f"Phase took: {phase_took}\n"
f"Profile: {profile}\n"
)
if timed_out:
error_str = f"OpenSearch client error: Search timed out for index {self._index_name}."
logger.error(error_str)
if raise_on_timeout:
raise RuntimeError(error_str)
def wait_for_opensearch_with_timeout(

Some files were not shown because too many files have changed in this diff.