Compare commits

...

142 Commits

Author SHA1 Message Date
Evan Lohn
462537e2ee initial confluence hierarchy impl 2026-01-27 23:08:38 -08:00
Evan Lohn
e3a9b2a021 hierarchyfetching task implementation 2026-01-27 22:11:31 -08:00
Evan Lohn
7080b3d966 feat(filesys): creation of hierarchyfetching job (#7555) 2026-01-28 06:03:15 +00:00
Wenxi
adc3c86b16 feat(craft): allow closing LLM setup modal (#7925) 2026-01-28 05:58:09 +00:00
roshan
b110621b13 fix(craft): install script for craft-latest image (#7918) 2026-01-27 21:40:30 -08:00
Evan Lohn
a2dc752d14 feat(filesys): implement hierarchy injection into vector db chunks (#7548) 2026-01-28 05:29:15 +00:00
Wenxi
f7d47a6ca9 refactor: build/v1 to craft/v1 (#7924) 2026-01-28 05:07:50 +00:00
roshan
9cc71b71ee fix(craft): allow more lenient tag names (for versioning) (#7921) 2026-01-27 21:13:35 -08:00
Wenxi
f2bafd113a refactor: packet type processing and path sanitization (#7920) 2026-01-28 04:33:54 +00:00
roshan
bb00ebd4a8 fix(craft): block opencode.json read (#7846) 2026-01-28 04:29:07 +00:00
Evan Lohn
fda04aa6d2 feat(filesys): opensearch integration with hierarchy (#7429) 2026-01-28 04:04:30 +00:00
Yuhong Sun
285aae6f2f chore: README (#7919) 2026-01-27 19:45:13 -08:00
Yuhong Sun
b75b1019f3 chore: kg stuff in celery (#7908) 2026-01-28 03:36:31 +00:00
Evan Lohn
bbba32b989 feat(filesys): connect hierarchynode and assistant (#7428) 2026-01-28 03:28:47 +00:00
joachim-danswer
f06bf69956 fix(craft): new demo data & change of eng IC demo persona (#7917) 2026-01-28 03:10:54 +00:00
roshan
7d4fe480cc fix(craft): files directory works locally + kube (#7913) 2026-01-27 19:01:08 -08:00
Chris Weaver
7f5b512856 feat: craft ui improvements (#7916) 2026-01-28 02:52:39 +00:00
Wenxi
844a01f751 fix(craft): allow initializing non-visible models (#7915) 2026-01-28 02:49:51 +00:00
Evan Lohn
d64be385db feat(filesys): Connectors know about hierarchynodes (#7404) 2026-01-28 02:39:43 +00:00
roshan
d0518388d6 feat(craft): update github action for craft latest (#7910)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-01-27 18:45:44 -08:00
Justin Tahara
a7f6d5f535 chore(tracing): Adding more explicit Tracing to our callsites (#7911) 2026-01-28 01:44:09 +00:00
Wenxi
059e2869e6 feat: md preview scrollbar (#7909) 2026-01-28 01:35:43 +00:00
Chris Weaver
04d90fd496 fix: improve session recovery (#7912) 2026-01-28 01:30:49 +00:00
Nikolas Garza
7cd29f4892 feat(ee): improve license enforcement middleware (#7853) 2026-01-28 01:26:02 +00:00
roshan
c2b86efebf fix(craft): delete session ui (#7847) 2026-01-27 17:30:35 -08:00
Nikolas Garza
bc5835967e feat(ee): Add unified billing API (#7857) 2026-01-27 17:02:08 -08:00
Evan Lohn
c2b11cae01 feat(filesys): data models and migration (#7402) 2026-01-28 00:03:52 +00:00
Chris Weaver
cf17ba6a1c fix: db connection closed for craft (#7905) 2026-01-27 15:46:46 -08:00
Jamison Lahman
b03634ecaa chore(mypy): fix mypy cache issues switching between HEAD and release (#7732) 2026-01-27 23:29:51 +00:00
Wenxi
9a7e92464f fix: demo data toggle race condition (#7902) 2026-01-27 23:06:17 +00:00
Wenxi
09b2a69c82 chore: remove pyproject config for pypandoc mypy (#7894) 2026-01-27 22:31:41 +00:00
Jamison Lahman
c5c027c168 fix: sidebar items are title case (#7893) 2026-01-27 22:05:06 +00:00
Wenxi
882163a4ea feat: md rendering, docx conversion and download, output panel refresh refactor for all artifacts (#7892) 2026-01-27 21:58:06 +00:00
roshan
de83a9a6f0 feat(craft): better output formats (#7889) 2026-01-27 21:48:08 +00:00
Jamison Lahman
f73ce0632f fix(citations): enable citation sidebar w/ web_search-only assistants (#7888) 2026-01-27 20:55:12 +00:00
Justin Tahara
0b10b11af3 fix(redis): Adding more TTLs (#7886) 2026-01-27 20:31:54 +00:00
roshan
d9e3b657d0 fix(craft): only include org_info/ when demo data enabled (#7845) 2026-01-27 19:48:48 +00:00
Justin Tahara
f6e9928dc1 fix(llm): Hide private models from Agent Creation (#7873) 2026-01-27 19:44:13 +00:00
Justin Tahara
ca3179ad8d chore(pr): Add Cherry-pick check (#7805) 2026-01-27 19:31:10 +00:00
Nikolas Garza
5529829ff5 feat(ee): update api to claim license via cloud proxy (#7840) 2026-01-27 18:46:39 +00:00
Chris Weaver
bdc7f6c100 chore: specify sandbox version (#7870) 2026-01-27 10:49:39 -08:00
Wenxi
90f8656afa fix: connector details back button should nav back (#7869) 2026-01-27 18:36:41 +00:00
Wenxi
3c7d35a6e8 fix: remove posthog debug logs and adjust gitignore (#7868) 2026-01-27 18:36:14 +00:00
Nikolas Garza
40d58a37e3 feat(ee): enforce seat limits on user operations (#7504) 2026-01-27 18:12:09 +00:00
Justin Tahara
be3ecd9640 fix(helm): Updating Ingress Templates (#7864) 2026-01-27 17:21:01 +00:00
Chris Weaver
a6da511490 fix: pass in correct region to allow IRSA usage (#7865) 2026-01-27 17:20:25 +00:00
roshan
c7577ebe58 fix(craft): only insert onyx user context when demo data not enabled (#7841) 2026-01-27 17:13:33 +00:00
SubashMohan
b87078a4f5 feat(chat): Search over chats and projects (#7788) 2026-01-27 16:57:00 +00:00
Yuhong Sun
8a408e7023 fix: Project Creation (#7851) 2026-01-27 05:27:19 +00:00
Nikolas Garza
4c7b73a355 feat(ee): add proxy endpoints for self-hosted billing operations (#7819) 2026-01-27 03:57:04 +00:00
Wenxi
8e9cb94d4f fix: processing mode enum (#7849) 2026-01-26 19:09:04 -08:00
Wenxi
a21af4b906 fix: type ignore unrelated mypy for onyx craft head (#7843) 2026-01-26 18:26:53 -08:00
Chris Weaver
7f0ce0531f feat: Onyx Craft (#7484)
Co-authored-by: Wenxi <wenxi@onyx.app>
Co-authored by: joachim-danswer <joachim@danswer.ai>
Co-authored-by: rohoswagger <roshan@onyx.app>
2026-01-26 17:12:42 -08:00
acaprau
b631bfa656 feat(opensearch): Add separate index settings for AWS-managed OpenSearch; Add function for disabling index auto-creation (#7814) 2026-01-27 00:40:46 +00:00
Nikolas Garza
eca6b6bef2 feat(ee): add license public key file and improve signature verification (#7806) 2026-01-26 23:44:16 +00:00
Wenxi
51ef28305d fix: user count check (#7811) 2026-01-26 13:21:33 -08:00
Jamison Lahman
144030c5ca chore(vscode): add non-clean seeded db restore (#7795) 2026-01-26 08:55:19 -08:00
SubashMohan
a557d76041 feat(ui): add new icons and enhance FadeDiv, Modal, Tabs, ExpandableTextDisplay (#7563)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 10:26:09 +00:00
SubashMohan
605e808158 fix(layout): adjust footer margin and prevent page refresh on chatsession drop (#7759) 2026-01-26 04:45:40 +00:00
roshan
8fec88c90d chore(deployment): remove no auth option from setup script (#7784) 2026-01-26 04:42:45 +00:00
Yuhong Sun
e54969a693 fix: LiteLLM Azure models don't stream (#7761) 2026-01-25 07:46:51 +00:00
Raunak Bhagat
1da2b2f28f fix: Some new fixes that were discovered by AI reviewers during 2.9-hotfixing (#7757)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-25 04:44:30 +00:00
Nikolas Garza
eb7b91e08e fix(tests): use crawler-friendly search query in Exa integration test (#7746) 2026-01-24 20:58:02 +00:00
Yuhong Sun
3339000968 fix: Spacing issue on Feedback (#7747) 2026-01-24 12:59:00 -08:00
Nikolas Garza
d9db849e94 fix(chat): prevent streaming text from appearing in bursts after citations (#7745) 2026-01-24 11:48:34 -08:00
Yuhong Sun
046408359c fix: Azure OpenAI Tool Calls (#7727) 2026-01-24 01:47:03 +00:00
acaprau
4b8cca190f feat(opensearch): Implement complete retrieval filtering (#7691) 2026-01-23 23:27:42 +00:00
Justin Tahara
52a312a63b feat: onyx discord bot - supervisord and kube deployment (#7706) 2026-01-23 20:55:06 +00:00
Danelegend
0594fd17de chore(tests): add more packet tests (#7677) 2026-01-23 19:49:41 +00:00
Jamison Lahman
fded81dc28 chore(extensions): pull in chrome extension (#7703) 2026-01-23 10:17:05 -08:00
Danelegend
31db112de9 feat(url): Open url around snippet (#7488) 2026-01-23 17:02:38 +00:00
Jamison Lahman
a3e2da2c51 chore(vscode): add useful database operations (#7702) 2026-01-23 08:49:59 -08:00
Evan Lohn
f4d33bcc0d feat: basic user MCP action attaching (#7681) 2026-01-23 05:50:49 +00:00
Jamison Lahman
464d957494 chore(devtools): upgrade ods v0.4.0; vscode to restore seeded db (#7696) 2026-01-23 05:21:46 +00:00
Jamison Lahman
be12de9a44 chore(devtools): ods db restore --fetch-seeded (#7689) 2026-01-22 20:41:28 -08:00
Yuhong Sun
3e4a1f8a09 feat: Maintain correct docs on replay (#7683) 2026-01-22 19:24:10 -08:00
Raunak Bhagat
af9b7826ab fix: Remove cursor pointer from view-only field (#7688)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-01-23 02:47:08 +00:00
Danelegend
cb16eb13fc chore(tests): Mock LLM (#7590) 2026-01-23 01:48:54 +00:00
Jamison Lahman
20a73bdd2e chore(desktop): make artifact filename version-agnostic (#7679) 2026-01-22 15:15:52 -08:00
Justin Tahara
85cc2b99b7 fix(fastapi): Resolve CVE-2025-68481 (#7661) 2026-01-22 20:07:25 +00:00
Jamison Lahman
1208a3ee2b chore(fe): disable blur when there is not a custom background (#7673)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-01-22 11:26:16 -08:00
Justin Tahara
900fcef9dd feat(desktop): Domain Configuration (#7655) 2026-01-22 18:15:44 +00:00
Justin Tahara
d4ed25753b fix(ui): Coda Logo (#7656) 2026-01-22 10:10:02 -08:00
Justin Tahara
0ee58333b4 fix(ui): User Groups Connectors Fix (#7658) 2026-01-22 17:59:12 +00:00
Justin Tahara
11b7e0d571 fix(ui): First Connector Result (#7657) 2026-01-22 17:52:02 +00:00
acaprau
a35831f328 fix(opensearch): Release Onyx Helm Charts was failing (#7672) 2026-01-22 17:41:47 +00:00
Justin Tahara
048a6d5259 fix(ui): Fix Token Rate Limits Page (#7659) 2026-01-22 17:20:21 +00:00
Ciaran Sweet
e4bdb15910 docs: enhance send-chat-message docs to also show ChatFullResponse (#7430) 2026-01-22 16:48:26 +00:00
Jamison Lahman
3517d59286 chore(fe): add custom backgrounds to the settings page (#7668) 2026-01-21 21:32:56 -08:00
Jamison Lahman
4bc08e5d88 chore(fe): remove Text pseudo-element padding (#7665) 2026-01-21 19:50:42 -08:00
Yuhong Sun
4bd080cf62 chore: Redirect user to create account (#7654) 2026-01-22 02:44:58 +00:00
Raunak Bhagat
b0a8625ffc feat: Add confirmation modal for connector disconnect (#7637)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 02:08:19 +00:00
Yuhong Sun
f94baf6143 fix: DR Language Tuning (#7660) 2026-01-21 17:36:50 -08:00
Wenxi
9e1867638a feat: onyx discord bot - frontend (#7497) 2026-01-22 00:00:12 +00:00
Yuhong Sun
5b6d7c9f0d chore: Onboarding Image Generation (#7653) 2026-01-21 15:49:15 -08:00
Danelegend
e5dcf31f10 fix(image): Emit error to user (#7644) 2026-01-21 22:50:12 +00:00
Nikolas Garza
8ca06ef3e7 fix: deflake chat user journey test (#7646) 2026-01-21 22:33:30 +00:00
Justin Tahara
6897dbd610 feat(desktop): Properly Sign Mac App (#7608) 2026-01-21 22:17:45 +00:00
Evan Lohn
7f3cb77466 chore: remove prompt caching from chat history (#7636) 2026-01-21 21:35:11 +00:00
acaprau
267042a5aa fix(opensearch): Use the same method for getting title that the title embedding logic uses; small cleanup for content embedding (#7638) 2026-01-21 21:34:38 +00:00
Yuhong Sun
d02b3ae6ac chore: Remove default prompt shortcuts (#7639) 2026-01-21 21:28:53 +00:00
Yuhong Sun
683c3f7a7e fix: color mode and memories (#7642) 2026-01-21 13:29:33 -08:00
Nikolas Garza
008b4d2288 fix(slack): Extract person names and filter garbage in query expansion (#7632) 2026-01-21 21:09:50 +00:00
Jamison Lahman
8be261405a chore(deployments): fix region (#7640) 2026-01-21 13:14:42 -08:00
acaprau
61f2c48ebc feat(opensearch): Add helm charts (#7606) 2026-01-21 19:34:18 +00:00
acaprau
dbde2e6d6d chore(opensearch): Create OpenSearch docker compose, enabling test_opensearch_client.py to run in CI (#7611) 2026-01-21 18:41:23 +00:00
Raunak Bhagat
2860136214 feat: Refreshed user settings page (#7455)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 16:41:56 +00:00
Raunak Bhagat
49ec5994d3 refactor: Improve refresh-components with cleanup and truncation (#7622)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 00:29:25 -08:00
Raunak Bhagat
8d5fb67f0f feat: improve prompt shortcuts with uniqueness constraints and enhancements (#7619)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 07:31:35 +00:00
Raunak Bhagat
15d02f6e3c fix: Prevent description duplication in Modal header (#7609)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 04:32:22 +00:00
Jamison Lahman
e58974c419 chore(fe): move chatpage footer inside background element (#7618) 2026-01-21 04:21:49 +00:00
Yuhong Sun
6b66c07952 chore: Delete multilingual docker compose file (#7616) 2026-01-20 19:50:01 -08:00
Jamison Lahman
cae058a3ac chore(extensions): simplify and de-dupe NRFPage (#7607) 2026-01-21 03:42:19 +00:00
Nikolas Garza
aa3b21a191 fix: scroll to bottom when loading existing conversations (#7614) 2026-01-20 19:19:18 -08:00
Raunak Bhagat
7a07a78696 fix: Set width to fit for rightChildren section in LineItem (#7604)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 01:55:03 +00:00
Nikolas Garza
a8db236e37 feat(billing): fetch Stripe publishable key from S3 (#7595) 2026-01-21 01:32:57 +00:00
Raunak Bhagat
8a2e4ed36f fix: Fix flashing in progress-circle icon (#7605)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-21 01:03:52 +00:00
Evan Lohn
216f2c95a7 chore: add dialog description to modal (#7603) 2026-01-21 00:41:35 +00:00
Evan Lohn
67081efe08 fix: modal header in index attempt errors (#7601) 2026-01-21 00:37:23 +00:00
Yuhong Sun
9d40b8336f feat: Allow no system prompt (#7600) 2026-01-20 16:16:39 -08:00
Evan Lohn
23f0033302 chore: bg services launch.json (#7597) 2026-01-21 00:05:20 +00:00
Raunak Bhagat
9011b76eb0 refactor: Add new layout component (#7588)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 23:36:18 +00:00
Yuhong Sun
45e436bafc fix: prompt tunings (#7594) 2026-01-20 15:13:05 -08:00
Justin Tahara
010bc36d61 Revert "chore(deps): Bump fastapi-users from 14.0.1 to 15.0.2 in /backend/requirements" (#7593) 2026-01-20 14:44:21 -08:00
dependabot[bot]
468e488bdb chore(deps): bump docker/setup-buildx-action from 3.11.1 to 3.12.0 (#7527)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-20 22:36:39 +00:00
dependabot[bot]
9104c0ffce chore(deps): Bump fastapi-users from 14.0.1 to 15.0.2 in /backend/requirements (#6897)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: justin-tahara <justintahara@gmail.com>
2026-01-20 22:31:02 +00:00
Jamison Lahman
d36a6bd0b4 feat(fe): custom chat backgrounds (#7486)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-01-20 14:29:06 -08:00
Jamison Lahman
a3603c498c chore(deployments): fetch secrets from AWS (#7584) 2026-01-20 22:10:19 +00:00
Jamison Lahman
8f274e34c9 chore(blame): unignore checked in .vscode/ files (#7592) 2026-01-20 14:07:27 -08:00
Justin Tahara
5c256760ff fix(vertex ai): Extra Args for Opus 4.5 (#7586) 2026-01-20 14:07:14 -08:00
Nikolas Garza
258e1372b3 fix(billing): remove grandfathered pricing option when subscription lapses (#7583) 2026-01-20 21:55:37 +00:00
Yuhong Sun
83a543a265 chore: NLTK and stopwords (#7587) 2026-01-20 13:36:04 -08:00
Evan Lohn
f9719d199d fix: drive connector creation ui (#7578) 2026-01-20 21:10:06 +00:00
Raunak Bhagat
1c7bb6e56a fix: Input variant refactor (#7579)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 13:04:16 -08:00
acaprau
982ad7d329 feat(opensearch): Add dual document indices (#7539) 2026-01-20 20:53:24 +00:00
Jamison Lahman
f94292808b chore(vscode): launch.template.jsonc -> launch.json (#7440) 2026-01-20 20:32:46 +00:00
Justin Tahara
293553a2e2 fix(tests): Anthropic Prompt Caching Test (#7585) 2026-01-20 20:32:24 +00:00
Justin Tahara
ba906ae6fa chore(llm): Removing Claude Haiku 3.5 (#7577) 2026-01-20 19:06:14 +00:00
Raunak Bhagat
c84c7a354e refactor: refactor to use string-enum props instead of boolean props (#7575)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 18:59:54 +00:00
Jamison Lahman
2187b0dd82 chore(pre-commit): disallow large files (#7576) 2026-01-20 11:02:00 -08:00
acaprau
d88a417bf9 feat(opensearch): Formally disable secondary indices in the backend (#7541) 2026-01-20 18:21:47 +00:00
Jamison Lahman
f2d32b0b3b fix(fe): inline code text wraps (#7574) 2026-01-20 17:11:42 +00:00
819 changed files with 92159 additions and 7662 deletions

View File

@@ -8,4 +8,5 @@
## Additional Options
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Override Linear Check

File diff suppressed because it is too large Load Diff

View File

@@ -21,7 +21,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3

View File

@@ -21,7 +21,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3

View File

@@ -29,6 +29,7 @@ jobs:
run: |
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add opensearch https://opensearch-project.github.io/helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/

View File

@@ -94,7 +94,7 @@ jobs:
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3

View File

@@ -0,0 +1,28 @@
name: Require beta cherry-pick consideration
concurrency:
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
beta-cherrypick-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for beta cherry-pick consideration
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
echo "Cherry-pick consideration box is checked. Check passed."
exit 0
fi
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
exit 1

View File

@@ -45,6 +45,9 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -125,11 +128,13 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \
cache \
index \
opensearch \
code-interpreter
- name: Run migrations
@@ -158,7 +163,7 @@ jobs:
cd deployment/docker_compose
# Get list of running containers
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
# Collect logs from each container
for container in $containers; do

View File

@@ -88,6 +88,7 @@ jobs:
echo "=== Adding Helm repositories ==="
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add opensearch https://opensearch-project.github.io/helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
@@ -180,6 +181,11 @@ jobs:
trap cleanup EXIT
# Run the actual installation with detailed logging
# Note that opensearch.enabled is true whereas others in this install
# are false. There is some work that needs to be done to get this
# entire step working in CI, enabling opensearch here is a small step
# in that direction. If this is causing issues, disabling it in this
# step should be ok in the short term.
echo "=== Starting ct install ==="
set +e
ct install --all \
@@ -187,6 +193,8 @@ jobs:
--set=nginx.enabled=false \
--set=minio.enabled=false \
--set=vespa.enabled=false \
--set=opensearch.enabled=true \
--set=auth.opensearch.enabled=true \
--set=slackbot.enabled=false \
--set=postgresql.enabled=true \
--set=postgresql.nameOverride=cloudnative-pg \

View File

@@ -103,7 +103,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -163,7 +163,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -208,7 +208,7 @@ jobs:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit

View File

@@ -95,7 +95,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -155,7 +155,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -214,7 +214,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit

View File

@@ -85,7 +85,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -146,7 +146,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -207,7 +207,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/

View File

@@ -50,8 +50,9 @@ jobs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-
- name: Run MyPy

View File

@@ -70,7 +70,7 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f
- name: Build and load
uses: docker/bake-action@5be5f02ff8819ecd3092ea6b2e6261c31774f2b4 # ratchet:docker/bake-action@v6

3
.gitignore vendored
View File

@@ -1,5 +1,8 @@
# editors
.vscode
!/.vscode/env_template.txt
!/.vscode/launch.json
!/.vscode/tasks.template.jsonc
.zed
.cursor

View File

@@ -66,7 +66,8 @@ repos:
- id: uv-run
name: Check lazy imports
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
files: ^backend/(?!\.venv/).*\.py$
pass_filenames: true
files: ^backend/(?!\.venv/|scripts/).*\.py$
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
# - id: uv-run
# name: mypy
@@ -74,6 +75,13 @@ repos:
# pass_filenames: true
# files: ^backend/.*\.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
hooks:
- id: check-added-large-files
name: Check for added large files
args: ["--maxkb=1500"]
- repo: https://github.com/rhysd/actionlint
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
hooks:

View File

@@ -1,5 +1,3 @@
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
@@ -24,7 +22,7 @@
"Slack Bot",
"Celery primary",
"Celery light",
"Celery background",
"Celery heavy",
"Celery docfetching",
"Celery docprocessing",
"Celery beat"
@@ -151,6 +149,24 @@
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Discord Bot",
"consoleName": "Discord Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/discord/client.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Discord Bot Console"
},
{
"name": "MCP Server",
"consoleName": "MCP Server",
@@ -399,7 +415,6 @@
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docfetching@%n",
@@ -430,7 +445,6 @@
"onyx.background.celery.versioned_apps.docprocessing",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docprocessing@%n",
@@ -579,6 +593,137 @@
"group": "3"
}
},
{
"name": "Build Sandbox Templates",
"type": "debugpy",
"request": "launch",
"module": "onyx.server.features.build.sandbox.build_templates",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"console": "integratedTerminal",
"presentation": {
"group": "3"
},
"consoleTitle": "Build Sandbox Templates"
},
{
// Dummy entry used to label the group
"name": "--- Database ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "4",
"order": 0
}
},
{
"name": "Restore seeded database dump",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--fetch-seeded",
"--yes"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Clean restore seeded database dump (destructive)",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--fetch-seeded",
"--clean",
"--yes"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Create database snapshot",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"dump",
"backup.dump"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Clean restore database snapshot (destructive)",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--clean",
"--yes",
"backup.dump"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Upgrade database to head revision",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"upgrade"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
// script to generate the openapi schema
"name": "Onyx OpenAPI Schema Generator",

View File

@@ -16,3 +16,8 @@ dist/
.coverage
htmlcov/
model_server/legacy/
# Craft: demo_data directory should be unzipped at container startup, not copied
**/demo_data/
# Craft: templates/outputs/venv is created at container startup
**/templates/outputs/venv

View File

@@ -37,10 +37,6 @@ CVE-2023-50868
CVE-2023-52425
CVE-2024-28757
# sqlite, only used by NLTK library to grab word lemmatizer and stopwords
# No impact in our settings
CVE-2023-7104
# libharfbuzz0b, O(n^2) growth, worst case is denial of service
# Accept the risk
CVE-2023-25193

View File

@@ -7,6 +7,10 @@ have a contract or agreement with DanswerAI, you are not permitted to use the En
Edition features outside of personal development or testing purposes. Please reach out to \
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
# Build argument for Craft support (disabled by default)
# Use --build-arg ENABLE_CRAFT=true to include Node.js and opencode CLI
ARG ENABLE_CRAFT=false
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true" \
@@ -46,7 +50,23 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
# Conditionally install Node.js 20 for Craft (required for Next.js)
# Only installed when ENABLE_CRAFT=true
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
echo "Installing Node.js 20 for Craft support..." && \
curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
apt-get install -y nodejs && \
rm -rf /var/lib/apt/lists/*; \
fi
# Conditionally install opencode CLI for Craft agent functionality
# Only installed when ENABLE_CRAFT=true
# TODO: download a specific, versioned release of the opencode CLI
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
echo "Installing opencode CLI for Craft support..." && \
curl -fsSL https://opencode.ai/install | bash; \
fi
ENV PATH="/root/.opencode/bin:${PATH}"
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
@@ -91,8 +111,8 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('punkt_tab', quiet=True);"
nltk.download('stopwords', quiet=True); \
nltk.download('punkt_tab', quiet=True);"
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
# Pre-downloading tiktoken for setups with limited egress
@@ -119,7 +139,15 @@ COPY --chown=onyx:onyx ./static /app/static
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
RUN chmod +x /app/scripts/supervisord_entrypoint.sh
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
# Run Craft template setup at build time when ENABLE_CRAFT=true
# This pre-bakes demo data, Python venv, and npm dependencies into the image
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
echo "Running Craft template setup at build time..." && \
ENABLE_CRAFT=true /app/scripts/setup_craft_templates.sh; \
fi
# Put logo in assets
COPY --chown=onyx:onyx ./assets /app/assets

View File

@@ -0,0 +1,351 @@
"""single onyx craft migration
Consolidates all buildmode/onyx craft tables into a single migration.
Tables created:
- build_session: User build sessions with status tracking
- sandbox: User-owned containerized environments (one per user)
- artifact: Build output files (web apps, documents, images)
- snapshot: Sandbox filesystem snapshots
- build_message: Conversation messages for build sessions
Existing table modified:
- connector_credential_pair: Added processing_mode column
Revision ID: 2020d417ec84
Revises: 41fa44bef321
Create Date: 2026-01-26 14:43:54.641405
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "2020d417ec84"
down_revision = "41fa44bef321"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ==========================================================================
# ENUMS
# ==========================================================================
# Build session status enum
build_session_status_enum = sa.Enum(
"active",
"idle",
name="buildsessionstatus",
native_enum=False,
)
# Sandbox status enum
sandbox_status_enum = sa.Enum(
"provisioning",
"running",
"idle",
"sleeping",
"terminated",
"failed",
name="sandboxstatus",
native_enum=False,
)
# Artifact type enum
artifact_type_enum = sa.Enum(
"web_app",
"pptx",
"docx",
"markdown",
"excel",
"image",
name="artifacttype",
native_enum=False,
)
# ==========================================================================
# BUILD_SESSION TABLE
# ==========================================================================
op.create_table(
"build_session",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=True,
),
sa.Column("name", sa.String(), nullable=True),
sa.Column(
"status",
build_session_status_enum,
nullable=False,
server_default="active",
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"last_activity_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("nextjs_port", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_build_session_user_created",
"build_session",
["user_id", sa.text("created_at DESC")],
unique=False,
)
op.create_index(
"ix_build_session_status",
"build_session",
["status"],
unique=False,
)
# ==========================================================================
# SANDBOX TABLE (user-owned, one per user)
# ==========================================================================
op.create_table(
"sandbox",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("container_id", sa.String(), nullable=True),
sa.Column(
"status",
sandbox_status_enum,
nullable=False,
server_default="provisioning",
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("last_heartbeat", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_id", name="sandbox_user_id_key"),
)
op.create_index(
"ix_sandbox_status",
"sandbox",
["status"],
unique=False,
)
op.create_index(
"ix_sandbox_container_id",
"sandbox",
["container_id"],
unique=False,
)
# ==========================================================================
# ARTIFACT TABLE
# ==========================================================================
op.create_table(
"artifact",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"session_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("type", artifact_type_enum, nullable=False),
sa.Column("path", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_artifact_session_created",
"artifact",
["session_id", sa.text("created_at DESC")],
unique=False,
)
op.create_index(
"ix_artifact_type",
"artifact",
["type"],
unique=False,
)
# ==========================================================================
# SNAPSHOT TABLE
# ==========================================================================
op.create_table(
"snapshot",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"session_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("storage_path", sa.String(), nullable=False),
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_snapshot_session_created",
"snapshot",
["session_id", sa.text("created_at DESC")],
unique=False,
)
# ==========================================================================
# BUILD_MESSAGE TABLE
# ==========================================================================
op.create_table(
"build_message",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"session_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"turn_index",
sa.Integer(),
nullable=False,
),
sa.Column(
"type",
sa.Enum(
"SYSTEM",
"USER",
"ASSISTANT",
"DANSWER",
name="messagetype",
create_type=False,
native_enum=False,
),
nullable=False,
),
sa.Column(
"message_metadata",
postgresql.JSONB(),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_build_message_session_turn",
"build_message",
["session_id", "turn_index", sa.text("created_at ASC")],
unique=False,
)
# ==========================================================================
# CONNECTOR_CREDENTIAL_PAIR MODIFICATION
# ==========================================================================
op.add_column(
"connector_credential_pair",
sa.Column(
"processing_mode",
sa.String(),
nullable=False,
server_default="regular",
),
)
def downgrade() -> None:
# ==========================================================================
# CONNECTOR_CREDENTIAL_PAIR MODIFICATION
# ==========================================================================
op.drop_column("connector_credential_pair", "processing_mode")
# ==========================================================================
# BUILD_MESSAGE TABLE
# ==========================================================================
op.drop_index("ix_build_message_session_turn", table_name="build_message")
op.drop_table("build_message")
# ==========================================================================
# SNAPSHOT TABLE
# ==========================================================================
op.drop_index("ix_snapshot_session_created", table_name="snapshot")
op.drop_table("snapshot")
# ==========================================================================
# ARTIFACT TABLE
# ==========================================================================
op.drop_index("ix_artifact_type", table_name="artifact")
op.drop_index("ix_artifact_session_created", table_name="artifact")
op.drop_table("artifact")
sa.Enum(name="artifacttype").drop(op.get_bind(), checkfirst=True)
# ==========================================================================
# SANDBOX TABLE
# ==========================================================================
op.drop_index("ix_sandbox_container_id", table_name="sandbox")
op.drop_index("ix_sandbox_status", table_name="sandbox")
op.drop_table("sandbox")
sa.Enum(name="sandboxstatus").drop(op.get_bind(), checkfirst=True)
# ==========================================================================
# BUILD_SESSION TABLE
# ==========================================================================
op.drop_index("ix_build_session_status", table_name="build_session")
op.drop_index("ix_build_session_user_created", table_name="build_session")
op.drop_table("build_session")
sa.Enum(name="buildsessionstatus").drop(op.get_bind(), checkfirst=True)

View File

@@ -0,0 +1,42 @@
"""add_unique_constraint_to_inputprompt_prompt_user_id
Revision ID: 2c2430828bdf
Revises: fb80bdd256de
Create Date: 2026-01-20 16:01:54.314805
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "2c2430828bdf"
down_revision = "fb80bdd256de"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create unique constraint on (prompt, user_id) for user-owned prompts
# This ensures each user can only have one shortcut with a given name
op.create_unique_constraint(
"uq_inputprompt_prompt_user_id",
"inputprompt",
["prompt", "user_id"],
)
# Create partial unique index for public prompts (where user_id IS NULL)
# PostgreSQL unique constraints don't enforce uniqueness for NULL values,
# so we need a partial index to ensure public prompt names are also unique
op.execute(
"""
CREATE UNIQUE INDEX uq_inputprompt_prompt_public
ON inputprompt (prompt)
WHERE user_id IS NULL
"""
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS uq_inputprompt_prompt_public")
op.drop_constraint("uq_inputprompt_prompt_user_id", "inputprompt", type_="unique")

View File

@@ -0,0 +1,29 @@
"""remove default prompt shortcuts
Revision ID: 41fa44bef321
Revises: 2c2430828bdf
Create Date: 2025-01-21
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "41fa44bef321"
down_revision = "2c2430828bdf"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Delete any user associations for the default prompts first (foreign key constraint)
op.execute(
"DELETE FROM inputprompt__user WHERE input_prompt_id IN (SELECT id FROM inputprompt WHERE id < 0)"
)
# Delete the pre-seeded default prompt shortcuts (they have negative IDs)
op.execute("DELETE FROM inputprompt WHERE id < 0")
def downgrade() -> None:
# We don't restore the default prompts on downgrade
pass

View File

@@ -0,0 +1,45 @@
"""make processing mode default all caps
Revision ID: 72aa7de2e5cf
Revises: 2020d417ec84
Create Date: 2026-01-26 18:58:47.705253
This migration fixes the ProcessingMode enum value mismatch:
- SQLAlchemy's Enum with native_enum=False uses enum member NAMES as valid values
- The original migration stored lowercase VALUES ('regular', 'file_system')
- This converts existing data to uppercase NAMES ('REGULAR', 'FILE_SYSTEM')
- Also drops any spurious native PostgreSQL enum type that may have been auto-created
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "72aa7de2e5cf"
down_revision = "2020d417ec84"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Convert existing lowercase values to uppercase to match enum member names
op.execute(
"UPDATE connector_credential_pair SET processing_mode = 'REGULAR' "
"WHERE processing_mode = 'regular'"
)
op.execute(
"UPDATE connector_credential_pair SET processing_mode = 'FILE_SYSTEM' "
"WHERE processing_mode = 'file_system'"
)
# Update the server default to use uppercase
op.alter_column(
"connector_credential_pair",
"processing_mode",
server_default="REGULAR",
)
def downgrade() -> None:
# State prior to this was broken, so we don't want to revert back to it
pass

View File

@@ -0,0 +1,349 @@
"""hierarchy_nodes_v1
Revision ID: 81c22b1e2e78
Revises: 72aa7de2e5cf
Create Date: 2026-01-13 18:10:01.021451
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from onyx.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "81c22b1e2e78"
down_revision = "72aa7de2e5cf"
branch_labels = None
depends_on = None
# Human-readable display names for each source
SOURCE_DISPLAY_NAMES: dict[str, str] = {
"ingestion_api": "Ingestion API",
"slack": "Slack",
"web": "Web",
"google_drive": "Google Drive",
"gmail": "Gmail",
"requesttracker": "Request Tracker",
"github": "GitHub",
"gitbook": "GitBook",
"gitlab": "GitLab",
"guru": "Guru",
"bookstack": "BookStack",
"outline": "Outline",
"confluence": "Confluence",
"jira": "Jira",
"slab": "Slab",
"productboard": "Productboard",
"file": "File",
"coda": "Coda",
"notion": "Notion",
"zulip": "Zulip",
"linear": "Linear",
"hubspot": "HubSpot",
"document360": "Document360",
"gong": "Gong",
"google_sites": "Google Sites",
"zendesk": "Zendesk",
"loopio": "Loopio",
"dropbox": "Dropbox",
"sharepoint": "SharePoint",
"teams": "Teams",
"salesforce": "Salesforce",
"discourse": "Discourse",
"axero": "Axero",
"clickup": "ClickUp",
"mediawiki": "MediaWiki",
"wikipedia": "Wikipedia",
"asana": "Asana",
"s3": "S3",
"r2": "R2",
"google_cloud_storage": "Google Cloud Storage",
"oci_storage": "OCI Storage",
"xenforo": "XenForo",
"not_applicable": "Not Applicable",
"discord": "Discord",
"freshdesk": "Freshdesk",
"fireflies": "Fireflies",
"egnyte": "Egnyte",
"airtable": "Airtable",
"highspot": "Highspot",
"drupal_wiki": "Drupal Wiki",
"imap": "IMAP",
"bitbucket": "Bitbucket",
"testrail": "TestRail",
"mock_connector": "Mock Connector",
"user_file": "User File",
}
def upgrade() -> None:
# 1. Create hierarchy_node table
op.create_table(
"hierarchy_node",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("raw_node_id", sa.String(), nullable=False),
sa.Column("display_name", sa.String(), nullable=False),
sa.Column("link", sa.String(), nullable=True),
sa.Column("source", sa.String(), nullable=False),
sa.Column("node_type", sa.String(), nullable=False),
sa.Column("document_id", sa.String(), nullable=True),
sa.Column("parent_id", sa.Integer(), nullable=True),
# Permission fields - same pattern as Document table
sa.Column(
"external_user_emails",
postgresql.ARRAY(sa.String()),
nullable=True,
),
sa.Column(
"external_user_group_ids",
postgresql.ARRAY(sa.String()),
nullable=True,
),
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="false"),
sa.PrimaryKeyConstraint("id"),
# When document is deleted, just unlink (node can exist without document)
sa.ForeignKeyConstraint(["document_id"], ["document.id"], ondelete="SET NULL"),
# When parent node is deleted, orphan children (cleanup via pruning)
sa.ForeignKeyConstraint(
["parent_id"], ["hierarchy_node.id"], ondelete="SET NULL"
),
sa.UniqueConstraint(
"raw_node_id", "source", name="uq_hierarchy_node_raw_id_source"
),
)
op.create_index("ix_hierarchy_node_parent_id", "hierarchy_node", ["parent_id"])
op.create_index(
"ix_hierarchy_node_source_type", "hierarchy_node", ["source", "node_type"]
)
# Add partial unique index to ensure only one SOURCE-type node per source
# This prevents duplicate source root nodes from being created
# NOTE: node_type stores enum NAME ('SOURCE'), not value ('source')
op.execute(
sa.text(
"""
CREATE UNIQUE INDEX uq_hierarchy_node_one_source_per_type
ON hierarchy_node (source)
WHERE node_type = 'SOURCE'
"""
)
)
# 2. Create hierarchy_fetch_attempt table
op.create_table(
"hierarchy_fetch_attempt",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
sa.Column("status", sa.String(), nullable=False),
sa.Column("nodes_fetched", sa.Integer(), nullable=True, server_default="0"),
sa.Column("nodes_updated", sa.Integer(), nullable=True, server_default="0"),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("full_exception_trace", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["connector_credential_pair_id"],
["connector_credential_pair.id"],
ondelete="CASCADE",
),
)
op.create_index(
"ix_hierarchy_fetch_attempt_status", "hierarchy_fetch_attempt", ["status"]
)
op.create_index(
"ix_hierarchy_fetch_attempt_time_created",
"hierarchy_fetch_attempt",
["time_created"],
)
op.create_index(
"ix_hierarchy_fetch_attempt_cc_pair",
"hierarchy_fetch_attempt",
["connector_credential_pair_id"],
)
# 3. Insert SOURCE-type hierarchy nodes for each DocumentSource
# We insert these so every existing document can have a parent hierarchy node
# NOTE: SQLAlchemy's Enum with native_enum=False stores the enum NAME (e.g., 'GOOGLE_DRIVE'),
# not the VALUE (e.g., 'google_drive'). We must use .name for source and node_type columns.
# SOURCE nodes are always public since they're just categorical roots.
for source in DocumentSource:
source_name = (
source.name
) # e.g., 'GOOGLE_DRIVE' - what SQLAlchemy stores/expects
source_value = source.value # e.g., 'google_drive' - the raw_node_id
display_name = SOURCE_DISPLAY_NAMES.get(
source_value, source_value.replace("_", " ").title()
)
op.execute(
sa.text(
"""
INSERT INTO hierarchy_node (raw_node_id, display_name, source, node_type, parent_id, is_public)
VALUES (:raw_node_id, :display_name, :source, 'SOURCE', NULL, true)
ON CONFLICT (raw_node_id, source) DO NOTHING
"""
).bindparams(
raw_node_id=source_value, # Use .value for raw_node_id (human-readable identifier)
display_name=display_name,
source=source_name, # Use .name for source column (SQLAlchemy enum storage)
)
)
# 4. Add parent_hierarchy_node_id column to document table
op.add_column(
"document",
sa.Column("parent_hierarchy_node_id", sa.Integer(), nullable=True),
)
# When hierarchy node is deleted, just unlink the document (SET NULL)
op.create_foreign_key(
"fk_document_parent_hierarchy_node",
"document",
"hierarchy_node",
["parent_hierarchy_node_id"],
["id"],
ondelete="SET NULL",
)
op.create_index(
"ix_document_parent_hierarchy_node_id",
"document",
["parent_hierarchy_node_id"],
)
# 5. Set all existing documents' parent_hierarchy_node_id to their source's SOURCE node
# For documents with multiple connectors, we pick one source deterministically (MIN connector_id)
# NOTE: Both connector.source and hierarchy_node.source store enum NAMEs (e.g., 'GOOGLE_DRIVE')
# because SQLAlchemy Enum(native_enum=False) uses the enum name for storage.
op.execute(
sa.text(
"""
UPDATE document d
SET parent_hierarchy_node_id = hn.id
FROM (
-- Get the source for each document (pick MIN connector_id for determinism)
SELECT DISTINCT ON (dbcc.id)
dbcc.id as doc_id,
c.source as source
FROM document_by_connector_credential_pair dbcc
JOIN connector c ON dbcc.connector_id = c.id
ORDER BY dbcc.id, dbcc.connector_id
) doc_source
JOIN hierarchy_node hn ON hn.source = doc_source.source AND hn.node_type = 'SOURCE'
WHERE d.id = doc_source.doc_id
"""
)
)
# Create the persona__hierarchy_node association table
op.create_table(
"persona__hierarchy_node",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["hierarchy_node_id"],
["hierarchy_node.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("persona_id", "hierarchy_node_id"),
)
# Add index for efficient lookups
op.create_index(
"ix_persona__hierarchy_node_hierarchy_node_id",
"persona__hierarchy_node",
["hierarchy_node_id"],
)
# Create the persona__document association table for attaching individual
# documents directly to assistants
op.create_table(
"persona__document",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["document_id"],
["document.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("persona_id", "document_id"),
)
# Add index for efficient lookups by document_id
op.create_index(
"ix_persona__document_document_id",
"persona__document",
["document_id"],
)
# 6. Add last_time_hierarchy_fetch column to connector_credential_pair table
op.add_column(
"connector_credential_pair",
sa.Column(
"last_time_hierarchy_fetch", sa.DateTime(timezone=True), nullable=True
),
)
def downgrade() -> None:
# Remove last_time_hierarchy_fetch from connector_credential_pair
op.drop_column("connector_credential_pair", "last_time_hierarchy_fetch")
# Drop persona__document table
op.drop_index("ix_persona__document_document_id", table_name="persona__document")
op.drop_table("persona__document")
# Drop persona__hierarchy_node table
op.drop_index(
"ix_persona__hierarchy_node_hierarchy_node_id",
table_name="persona__hierarchy_node",
)
op.drop_table("persona__hierarchy_node")
# Remove parent_hierarchy_node_id from document
op.drop_index("ix_document_parent_hierarchy_node_id", table_name="document")
op.drop_constraint(
"fk_document_parent_hierarchy_node", "document", type_="foreignkey"
)
op.drop_column("document", "parent_hierarchy_node_id")
# Drop hierarchy_fetch_attempt table
op.drop_index(
"ix_hierarchy_fetch_attempt_cc_pair", table_name="hierarchy_fetch_attempt"
)
op.drop_index(
"ix_hierarchy_fetch_attempt_time_created", table_name="hierarchy_fetch_attempt"
)
op.drop_index(
"ix_hierarchy_fetch_attempt_status", table_name="hierarchy_fetch_attempt"
)
op.drop_table("hierarchy_fetch_attempt")
# Drop hierarchy_node table
op.drop_index("uq_hierarchy_node_one_source_per_type", table_name="hierarchy_node")
op.drop_index("ix_hierarchy_node_source_type", table_name="hierarchy_node")
op.drop_index("ix_hierarchy_node_parent_id", table_name="hierarchy_node")
op.drop_table("hierarchy_node")

View File

@@ -0,0 +1,31 @@
"""add chat_background to user
Revision ID: fb80bdd256de
Revises: 8b5ce697290e
Create Date: 2026-01-16 16:15:59.222617
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "fb80bdd256de"
down_revision = "8b5ce697290e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"chat_background",
sa.String(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("user", "chat_background")

View File

@@ -122,6 +122,9 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
POSTHOG_DEBUG_LOGS_ENABLED = (
os.environ.get("POSTHOG_DEBUG_LOGS_ENABLED", "").lower() == "true"
)
MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")
@@ -133,3 +136,9 @@ GATED_TENANTS_KEY = "gated_tenants"
LICENSE_ENFORCEMENT_ENABLED = (
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
)
# Cloud data plane URL - self-hosted instances call this to reach cloud proxy endpoints
# Used when MULTI_TENANT=false (self-hosted mode)
CLOUD_DATA_PLANE_URL = os.environ.get(
"CLOUD_DATA_PLANE_URL", "https://cloud.onyx.app/api"
)

View File

@@ -0,0 +1,73 @@
"""Constants for license enforcement.
This file is the single source of truth for:
1. Paths that bypass license enforcement (always accessible)
2. Paths that require an EE license (EE-only features)
Import these constants in both production code and tests to ensure consistency.
"""
# Paths that are ALWAYS accessible, even when license is expired/gated.
# These enable users to:
# /auth - Log in/out (users can't fix billing if locked out of auth)
# /license - Fetch, upload, or check license status
# /health - Health checks for load balancers/orchestrators
# /me - Basic user info needed for UI rendering
# /settings, /enterprise-settings - View app status and branding
# /billing - Unified billing API
# /proxy - Self-hosted proxy endpoints (have own license-based auth)
# /tenants/billing-* - Legacy billing endpoints (backwards compatibility)
# /manage/users, /users - User management (needed for seat limit resolution)
# /notifications - Needed for UI to load properly
LICENSE_ENFORCEMENT_ALLOWED_PREFIXES: frozenset[str] = frozenset(
{
"/auth",
"/license",
"/health",
"/me",
"/settings",
"/enterprise-settings",
# Billing endpoints (unified API for both MT and self-hosted)
"/billing",
"/admin/billing",
# Proxy endpoints for self-hosted billing (no tenant context)
"/proxy",
# Legacy tenant billing endpoints (kept for backwards compatibility)
"/tenants/billing-information",
"/tenants/create-customer-portal-session",
"/tenants/create-subscription-session",
# User management - needed to remove users when seat limit exceeded
"/manage/users",
"/manage/admin/users",
"/manage/admin/valid-domains",
"/manage/admin/deactivate-user",
"/manage/admin/delete-user",
"/users",
# Notifications - needed for UI to load properly
"/notifications",
}
)
# EE-only paths that require a valid license.
# Users without a license (community edition) cannot access these.
# These are blocked even when user has never subscribed (no license).
EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
{
# User groups and access control
"/manage/admin/user-group",
# Analytics and reporting
"/analytics",
# Query history (admin chat session endpoints)
"/admin/chat-sessions",
"/admin/chat-session-history",
"/admin/query-history",
# Usage reporting/export
"/admin/usage-report",
# Standard answers (canned responses)
"/manage/admin/standard-answer",
# Token rate limits
"/admin/token-rate-limits",
# Evals
"/evals",
}
)

View File

@@ -1,6 +1,7 @@
"""Database and cache operations for the license table."""
from datetime import datetime
from typing import NamedTuple
from sqlalchemy import func
from sqlalchemy import select
@@ -9,6 +10,7 @@ from sqlalchemy.orm import Session
from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from onyx.auth.schemas import UserRole
from onyx.db.models import License
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
@@ -23,6 +25,13 @@ LICENSE_METADATA_KEY = "license:metadata"
LICENSE_CACHE_TTL_SECONDS = 86400 # 24 hours
class SeatAvailabilityResult(NamedTuple):
"""Result of a seat availability check."""
available: bool
error_message: str | None = None
# -----------------------------------------------------------------------------
# Database CRUD Operations
# -----------------------------------------------------------------------------
@@ -95,23 +104,30 @@ def delete_license(db_session: Session) -> bool:
def get_used_seats(tenant_id: str | None = None) -> int:
"""
Get current seat usage.
Get current seat usage directly from database.
For multi-tenant: counts users in UserTenantMapping for this tenant.
For self-hosted: counts all active users (includes both Onyx UI users
and Slack users who have been converted to Onyx users).
For self-hosted: counts all active users (excludes EXT_PERM_USER role).
TODO: Exclude API key dummy users from seat counting. API keys create
users with emails like `__DANSWER_API_KEY_*` that should not count toward
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
"""
if MULTI_TENANT:
from ee.onyx.server.tenants.user_mapping import get_tenant_count
return get_tenant_count(tenant_id or get_current_tenant_id())
else:
# Self-hosted: count all active users (Onyx + converted Slack users)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as db_session:
result = db_session.execute(
select(func.count()).select_from(User).where(User.is_active) # type: ignore
select(func.count())
.select_from(User)
.where(
User.is_active == True, # type: ignore # noqa: E712
User.role != UserRole.EXT_PERM_USER,
)
)
return result.scalar() or 0
@@ -276,3 +292,43 @@ def get_license_metadata(
# Refresh from database
return refresh_license_cache(db_session, tenant_id)
def check_seat_availability(
db_session: Session,
seats_needed: int = 1,
tenant_id: str | None = None,
) -> SeatAvailabilityResult:
"""
Check if there are enough seats available to add users.
Args:
db_session: Database session
seats_needed: Number of seats needed (default 1)
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
SeatAvailabilityResult with available=True if seats are available,
or available=False with error_message if limit would be exceeded.
Returns available=True if no license exists (self-hosted = unlimited).
"""
metadata = get_license_metadata(db_session, tenant_id)
# No license = no enforcement (self-hosted without license)
if metadata is None:
return SeatAvailabilityResult(available=True)
# Calculate current usage directly from DB (not cache) for accuracy
current_used = get_used_seats(tenant_id)
total_seats = metadata.seats
# Use > (not >=) to allow filling to exactly 100% capacity
would_exceed_limit = current_used + seats_needed > total_seats
if would_exceed_limit:
return SeatAvailabilityResult(
available=False,
error_message=f"Seat limit would be exceeded: {current_used} of {total_seats} seats used, "
f"cannot add {seats_needed} more user(s).",
)
return SeatAvailabilityResult(available=True)

View File

@@ -7,6 +7,7 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFun
from onyx.access.models import DocExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -60,6 +61,9 @@ def gmail_doc_sync(
callback.progress("gmail_doc_sync", 1)
if isinstance(slim_doc, HierarchyNode):
# TODO: handle hierarchynodes during sync
continue
if slim_doc.external_access is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue

View File

@@ -15,6 +15,7 @@ from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -195,7 +196,9 @@ def gdrive_doc_sync(
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
callback.progress("gdrive_doc_sync", 1)
if isinstance(slim_doc, HierarchyNode):
# TODO: handle hierarchynodes during sync
continue
if slim_doc.external_access is None:
raise ValueError(
f"Drive perm sync: No external access for document {slim_doc.id}"

View File

@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.models import HierarchyNode
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call
from onyx.connectors.slack.connector import SlackConnector
@@ -111,6 +112,9 @@ def _get_slack_document_access(
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if isinstance(doc_metadata, HierarchyNode):
# TODO: handle hierarchynodes during sync
continue
if doc_metadata.external_access is None:
raise ValueError(
f"No external access for document {doc_metadata.id}. "

View File

@@ -5,6 +5,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -49,6 +50,9 @@ def generic_doc_sync(
callback.progress(label, 1)
for doc in doc_batch:
if isinstance(doc, HierarchyNode):
# TODO: handle hierarchynodes during sync
continue
if not doc.external_access:
raise RuntimeError(
f"No external access found for document ID; {cc_pair.id=} {doc_source=} {doc.id=}"

View File

@@ -4,8 +4,10 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.billing.api import router as billing_router
from ee.onyx.server.documents.cc_pair import router as ee_document_cc_pair_router
from ee.onyx.server.enterprise_settings.api import (
admin_router as enterprise_settings_admin_router,
@@ -85,10 +87,11 @@ def get_application() -> FastAPI:
if MULTI_TENANT:
add_api_server_tenant_id_middleware(application, logger)
# Add license enforcement middleware (runs after tenant tracking)
# This blocks access when license is expired/gated
add_license_enforcement_middleware(application, logger)
else:
# License enforcement middleware for self-hosted deployments only
# Checks LICENSE_ENFORCEMENT_ENABLED at runtime (can be toggled without restart)
# MT deployments use control plane gating via is_tenant_gated() instead
add_license_enforcement_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
# For Google OAuth, refresh tokens are requested by:
@@ -148,6 +151,13 @@ def get_application() -> FastAPI:
# License management
include_router_with_global_prefix_prepended(application, license_router)
# Unified billing API - available when license system is enabled
# Works for both self-hosted and cloud deployments
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
# primary billing API and /tenants/* billing endpoints can be removed
if LICENSE_ENFORCEMENT_ENABLED:
include_router_with_global_prefix_prepended(application, billing_router)
if MULTI_TENANT:
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)

View File

@@ -17,7 +17,8 @@ from onyx.context.search.models import InferenceChunk
from onyx.context.search.pipeline import merge_individual_chunks
from onyx.context.search.pipeline import search_pipeline
from onyx.db.models import User
from onyx.document_index.factory import get_current_primary_default_document_index
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import DocumentIndex
from onyx.llm.factory import get_default_llm
from onyx.secondary_llm_flows.document_filter import select_sections_for_expansion
@@ -42,11 +43,13 @@ def _run_single_search(
document_index: DocumentIndex,
user: User | None,
db_session: Session,
num_hits: int | None = None,
) -> list[InferenceChunk]:
"""Execute a single search query and return chunks."""
chunk_search_request = ChunkSearchRequest(
query=query,
user_selected_filters=filters,
limit=num_hits,
)
return search_pipeline(
@@ -72,7 +75,9 @@ def stream_search_query(
Used by both streaming and non-streaming endpoints.
"""
# Get document index
document_index = get_current_primary_default_document_index(db_session)
search_settings = get_current_search_settings(db_session)
# This flow is for search so we do not get all indices.
document_index = get_default_document_index(search_settings, None)
# Determine queries to execute
original_query = request.search_query
@@ -114,6 +119,7 @@ def stream_search_query(
document_index=document_index,
user=user,
db_session=db_session,
num_hits=request.num_hits,
)
else:
# Multiple queries - run in parallel and merge with RRF
@@ -121,7 +127,14 @@ def stream_search_query(
search_functions = [
(
_run_single_search,
(query, request.filters, document_index, user, db_session),
(
query,
request.filters,
document_index,
user,
db_session,
request.num_hits,
),
)
for query in all_executed_queries
]
@@ -168,6 +181,9 @@ def stream_search_query(
# Merge chunks into sections
sections = merge_individual_chunks(chunks)
# Truncate to the requested number of hits
sections = sections[: request.num_hits]
# Apply LLM document selection if requested
# num_docs_fed_to_llm_selection specifies how many sections to feed to the LLM for selection
# The LLM will always try to select TARGET_NUM_SECTIONS_FOR_LLM_SELECTION sections from those fed to it

View File

@@ -10,6 +10,16 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
("/enterprise-settings/logo", {"GET"}),
("/enterprise-settings/logotype", {"GET"}),
("/enterprise-settings/custom-analytics-script", {"GET"}),
# Stripe publishable key is safe to expose publicly
("/tenants/stripe-publishable-key", {"GET"}),
("/admin/billing/stripe-publishable-key", {"GET"}),
# Proxy endpoints use license-based auth, not user auth
("/proxy/create-checkout-session", {"POST"}),
("/proxy/claim-license", {"POST"}),
("/proxy/create-customer-portal-session", {"POST"}),
("/proxy/billing-information", {"GET"}),
("/proxy/license/{tenant_id}", {"GET"}),
("/proxy/seats/update", {"POST"}),
]

View File

@@ -0,0 +1,264 @@
"""Unified Billing API endpoints.
These endpoints provide Stripe billing functionality for both cloud and
self-hosted deployments. The service layer routes requests appropriately:
- Self-hosted: Routes through cloud data plane proxy
Flow: Backend /admin/billing/* → Cloud DP /proxy/* → Control plane
- Cloud (MULTI_TENANT): Routes directly to control plane
Flow: Backend /admin/billing/* → Control plane
License claiming is handled separately by /license/claim endpoint (self-hosted only).
Migration Note (ENG-3533):
This /admin/billing/* API replaces the older /tenants/* billing endpoints:
- /tenants/billing-information -> /admin/billing/billing-information
- /tenants/create-customer-portal-session -> /admin/billing/create-customer-portal-session
- /tenants/create-subscription-session -> /admin/billing/create-checkout-session
- /tenants/stripe-publishable-key -> /admin/billing/stripe-publishable-key
See: https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling
"""
import asyncio
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_admin_user
from ee.onyx.db.license import get_license
from ee.onyx.server.billing.models import BillingInformationResponse
from ee.onyx.server.billing.models import CreateCheckoutSessionRequest
from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
from ee.onyx.server.billing.models import CreateCustomerPortalSessionRequest
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
from ee.onyx.server.billing.models import SeatUpdateRequest
from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.billing.models import StripePublishableKeyResponse
from ee.onyx.server.billing.models import SubscriptionStatusResponse
from ee.onyx.server.billing.service import BillingServiceError
from ee.onyx.server.billing.service import (
create_checkout_session as create_checkout_service,
)
from ee.onyx.server.billing.service import (
create_customer_portal_session as create_portal_service,
)
from ee.onyx.server.billing.service import (
get_billing_information as get_billing_service,
)
from ee.onyx.server.billing.service import update_seat_count as update_seat_service
from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.engine.sql_engine import get_session
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/admin/billing")
# Cache for Stripe publishable key to avoid hitting S3 on every request
_stripe_publishable_key_cache: str | None = None
_stripe_key_lock = asyncio.Lock()
def _get_license_data(db_session: Session) -> str | None:
"""Get license data from database if exists (self-hosted only)."""
if MULTI_TENANT:
return None
license_record = get_license(db_session)
return license_record.license_data if license_record else None
def _get_tenant_id() -> str | None:
"""Get tenant ID for cloud deployments."""
if MULTI_TENANT:
return get_current_tenant_id()
return None
@router.post("/create-checkout-session")
async def create_checkout_session(
request: CreateCheckoutSessionRequest | None = None,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> CreateCheckoutSessionResponse:
"""Create a Stripe checkout session for new subscription or renewal.
For new customers, no license/tenant is required.
For renewals, existing license (self-hosted) or tenant_id (cloud) is used.
After checkout completion:
- Self-hosted: Use /license/claim to retrieve the license
- Cloud: Subscription is automatically activated
"""
license_data = _get_license_data(db_session)
tenant_id = _get_tenant_id()
billing_period = request.billing_period if request else "monthly"
email = request.email if request else None
# Build redirect URL for after checkout completion
redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"
try:
return await create_checkout_service(
billing_period=billing_period,
email=email,
license_data=license_data,
redirect_url=redirect_url,
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
request: CreateCustomerPortalSessionRequest | None = None,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> CreateCustomerPortalSessionResponse:
"""Create a Stripe customer portal session for managing subscription.
Requires existing license (self-hosted) or active tenant (cloud).
"""
license_data = _get_license_data(db_session)
tenant_id = _get_tenant_id()
# Self-hosted requires license
if not MULTI_TENANT and not license_data:
raise HTTPException(status_code=400, detail="No license found")
return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"
try:
return await create_portal_service(
license_data=license_data,
return_url=return_url,
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.get("/billing-information")
async def get_billing_information(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> BillingInformationResponse | SubscriptionStatusResponse:
"""Get billing information for the current subscription.
Returns subscription status and details from Stripe.
"""
license_data = _get_license_data(db_session)
tenant_id = _get_tenant_id()
# Self-hosted without license = no subscription
if not MULTI_TENANT and not license_data:
return SubscriptionStatusResponse(subscribed=False)
try:
return await get_billing_service(
license_data=license_data,
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.post("/seats/update")
async def update_seats(
request: SeatUpdateRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> SeatUpdateResponse:
"""Update the seat count for the current subscription.
Handles Stripe proration and license regeneration via control plane.
"""
license_data = _get_license_data(db_session)
tenant_id = _get_tenant_id()
# Self-hosted requires license
if not MULTI_TENANT and not license_data:
raise HTTPException(status_code=400, detail="No license found")
try:
return await update_seat_service(
new_seat_count=request.new_seat_count,
license_data=license_data,
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.get("/stripe-publishable-key")
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
"""Fetch the Stripe publishable key.
Priority: env var override (for testing) > S3 bucket (production).
This endpoint is public (no auth required) since publishable keys are safe to expose.
The key is cached in memory to avoid hitting S3 on every request.
"""
global _stripe_publishable_key_cache
# Fast path: return cached value without lock
if _stripe_publishable_key_cache:
return StripePublishableKeyResponse(
publishable_key=_stripe_publishable_key_cache
)
# Use lock to prevent concurrent S3 requests
async with _stripe_key_lock:
# Double-check after acquiring lock (another request may have populated cache)
if _stripe_publishable_key_cache:
return StripePublishableKeyResponse(
publishable_key=_stripe_publishable_key_cache
)
# Check for env var override first (for local testing with pk_test_* keys)
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
# Fall back to S3 bucket
if not STRIPE_PUBLISHABLE_KEY_URL:
raise HTTPException(
status_code=500,
detail="Stripe publishable key is not configured",
)
try:
async with httpx.AsyncClient() as client:
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
response.raise_for_status()
key = response.text.strip()
# Validate key format
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
except httpx.HTTPError:
raise HTTPException(
status_code=500,
detail="Failed to fetch Stripe publishable key",
)

View File

@@ -0,0 +1,75 @@
"""Pydantic models for the billing API."""
from datetime import datetime
from typing import Literal
from pydantic import BaseModel
class CreateCheckoutSessionRequest(BaseModel):
"""Request to create a Stripe checkout session."""
billing_period: Literal["monthly", "annual"] = "monthly"
email: str | None = None
class CreateCheckoutSessionResponse(BaseModel):
"""Response containing the Stripe checkout session URL."""
stripe_checkout_url: str
class CreateCustomerPortalSessionRequest(BaseModel):
"""Request to create a Stripe customer portal session."""
return_url: str | None = None
class CreateCustomerPortalSessionResponse(BaseModel):
"""Response containing the Stripe customer portal URL."""
stripe_customer_portal_url: str
class BillingInformationResponse(BaseModel):
"""Billing information for the current subscription."""
tenant_id: str
status: str | None = None
plan_type: str | None = None
seats: int | None = None
billing_period: str | None = None
current_period_start: datetime | None = None
current_period_end: datetime | None = None
cancel_at_period_end: bool = False
canceled_at: datetime | None = None
trial_start: datetime | None = None
trial_end: datetime | None = None
payment_method_enabled: bool = False
class SubscriptionStatusResponse(BaseModel):
"""Response when no subscription exists."""
subscribed: bool = False
class SeatUpdateRequest(BaseModel):
"""Request to update seat count."""
new_seat_count: int
class SeatUpdateResponse(BaseModel):
"""Response from seat update operation."""
success: bool
current_seats: int
used_seats: int
message: str | None = None
class StripePublishableKeyResponse(BaseModel):
"""Response containing the Stripe publishable key."""
publishable_key: str

View File

@@ -0,0 +1,267 @@
"""Service layer for billing operations.
This module provides functions for billing operations that route differently
based on deployment type:
- Self-hosted (not MULTI_TENANT): Routes through cloud data plane proxy
Flow: Self-hosted backend → Cloud DP /proxy/* → Control plane
- Cloud (MULTI_TENANT): Routes directly to control plane
Flow: Cloud backend → Control plane
"""
from typing import Literal
import httpx
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
from ee.onyx.server.billing.models import BillingInformationResponse
from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.billing.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
# HTTP request timeout for billing service calls
_REQUEST_TIMEOUT = 30.0
class BillingServiceError(Exception):
"""Exception raised for billing service errors."""
def __init__(self, message: str, status_code: int = 500):
self.message = message
self.status_code = status_code
super().__init__(self.message)
def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
"""Build headers for proxy requests (self-hosted).
Self-hosted instances authenticate with their license.
"""
headers = {"Content-Type": "application/json"}
if license_data:
headers["Authorization"] = f"Bearer {license_data}"
return headers
def _get_direct_headers() -> dict[str, str]:
"""Build headers for direct control plane requests (cloud).
Cloud instances authenticate with JWT.
"""
token = generate_data_plane_token()
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {token}",
}
def _get_base_url() -> str:
"""Get the base URL based on deployment type."""
if MULTI_TENANT:
return CONTROL_PLANE_API_BASE_URL
return f"{CLOUD_DATA_PLANE_URL}/proxy"
def _get_headers(license_data: str | None) -> dict[str, str]:
"""Get appropriate headers based on deployment type."""
if MULTI_TENANT:
return _get_direct_headers()
return _get_proxy_headers(license_data)
async def _make_billing_request(
method: Literal["GET", "POST"],
path: str,
license_data: str | None = None,
body: dict | None = None,
params: dict | None = None,
error_message: str = "Billing service request failed",
) -> dict:
"""Make an HTTP request to the billing service.
Consolidates the common HTTP request pattern used by all billing operations.
Args:
method: HTTP method (GET or POST)
path: URL path (appended to base URL)
license_data: License for authentication (self-hosted)
body: Request body for POST requests
params: Query parameters for GET requests
error_message: Default error message if request fails
Returns:
Response JSON as dict
Raises:
BillingServiceError: If request fails
"""
base_url = _get_base_url()
url = f"{base_url}{path}"
headers = _get_headers(license_data)
try:
async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
if method == "GET":
response = await client.get(url, headers=headers, params=params)
else:
response = await client.post(url, headers=headers, json=body)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
detail = error_message
try:
error_data = e.response.json()
detail = error_data.get("detail", detail)
except Exception:
pass
logger.error(f"{error_message}: {e.response.status_code} - {detail}")
raise BillingServiceError(detail, e.response.status_code)
except httpx.RequestError:
logger.exception("Failed to connect to billing service")
raise BillingServiceError("Failed to connect to billing service", 502)
async def create_checkout_session(
billing_period: str = "monthly",
email: str | None = None,
license_data: str | None = None,
redirect_url: str | None = None,
tenant_id: str | None = None,
) -> CreateCheckoutSessionResponse:
"""Create a Stripe checkout session.
Args:
billing_period: "monthly" or "annual"
email: Customer email for new subscriptions
license_data: Existing license for renewals (self-hosted)
redirect_url: URL to redirect after successful checkout
tenant_id: Tenant ID (cloud only, for renewals)
Returns:
CreateCheckoutSessionResponse with checkout URL
"""
body: dict = {"billing_period": billing_period}
if email:
body["email"] = email
if redirect_url:
body["redirect_url"] = redirect_url
if tenant_id and MULTI_TENANT:
body["tenant_id"] = tenant_id
data = await _make_billing_request(
method="POST",
path="/create-checkout-session",
license_data=license_data,
body=body,
error_message="Failed to create checkout session",
)
return CreateCheckoutSessionResponse(stripe_checkout_url=data["url"])
async def create_customer_portal_session(
license_data: str | None = None,
return_url: str | None = None,
tenant_id: str | None = None,
) -> CreateCustomerPortalSessionResponse:
"""Create a Stripe customer portal session.
Args:
license_data: License blob for authentication (self-hosted)
return_url: URL to return to after portal session
tenant_id: Tenant ID (cloud only)
Returns:
CreateCustomerPortalSessionResponse with portal URL
"""
body: dict = {}
if return_url:
body["return_url"] = return_url
if tenant_id and MULTI_TENANT:
body["tenant_id"] = tenant_id
data = await _make_billing_request(
method="POST",
path="/create-customer-portal-session",
license_data=license_data,
body=body,
error_message="Failed to create customer portal session",
)
return CreateCustomerPortalSessionResponse(stripe_customer_portal_url=data["url"])
async def get_billing_information(
license_data: str | None = None,
tenant_id: str | None = None,
) -> BillingInformationResponse | SubscriptionStatusResponse:
"""Fetch billing information.
Args:
license_data: License blob for authentication (self-hosted)
tenant_id: Tenant ID (cloud only)
Returns:
BillingInformationResponse or SubscriptionStatusResponse if no subscription
"""
params = {}
if tenant_id and MULTI_TENANT:
params["tenant_id"] = tenant_id
data = await _make_billing_request(
method="GET",
path="/billing-information",
license_data=license_data,
params=params or None,
error_message="Failed to fetch billing information",
)
# Check if no subscription
if isinstance(data, dict) and data.get("subscribed") is False:
return SubscriptionStatusResponse(subscribed=False)
return BillingInformationResponse(**data)
async def update_seat_count(
new_seat_count: int,
license_data: str | None = None,
tenant_id: str | None = None,
) -> SeatUpdateResponse:
"""Update the seat count for the current subscription.
Args:
new_seat_count: New number of seats
license_data: License blob for authentication (self-hosted)
tenant_id: Tenant ID (cloud only)
Returns:
SeatUpdateResponse with updated seat information
"""
body: dict = {"new_seat_count": new_seat_count}
if tenant_id and MULTI_TENANT:
body["tenant_id"] = tenant_id
data = await _make_billing_request(
method="POST",
path="/seats/update",
license_data=license_data,
body=body,
error_message="Failed to update seat count",
)
return SeatUpdateResponse(
success=data.get("success", False),
current_seats=data.get("current_seats", 0),
used_seats=data.get("used_seats", 0),
message=data.get("message"),
)

View File

@@ -1,4 +1,14 @@
"""License API endpoints."""
"""License API endpoints for self-hosted deployments.
These endpoints allow self-hosted Onyx instances to:
1. Claim a license after Stripe checkout (via cloud data plane proxy)
2. Upload a license file manually (for air-gapped deployments)
3. View license status and seat usage
4. Refresh/delete the local license
NOTE: Cloud (MULTI_TENANT) deployments do NOT use these endpoints.
Cloud licensing is managed via the control plane and gated_tenants Redis key.
"""
import requests
from fastapi import APIRouter
@@ -9,6 +19,7 @@ from fastapi import UploadFile
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
from ee.onyx.db.license import delete_license as db_delete_license
from ee.onyx.db.license import get_license_metadata
from ee.onyx.db.license import invalidate_license_cache
@@ -20,13 +31,11 @@ from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.license.models import LicenseStatusResponse
from ee.onyx.server.license.models import LicenseUploadResponse
from ee.onyx.server.license.models import SeatUsageResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.utils.license import verify_license_signature
from onyx.auth.users import User
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.db.engine.sql_engine import get_session
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -79,81 +88,80 @@ async def get_seat_usage(
)
@router.post("/fetch")
async def fetch_license(
@router.post("/claim")
async def claim_license(
session_id: str,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseResponse:
"""
Fetch license from control plane.
Used after Stripe checkout completion to retrieve the new license.
"""
tenant_id = get_current_tenant_id()
Claim a license after Stripe checkout (self-hosted only).
try:
token = generate_data_plane_token()
except ValueError as e:
logger.error(f"Failed to generate data plane token: {e}")
After a user completes Stripe checkout, they're redirected back with a
session_id. This endpoint exchanges that session_id for a signed license
via the cloud data plane proxy.
Flow:
1. Self-hosted frontend redirects to Stripe checkout (via cloud proxy)
2. User completes payment
3. Stripe redirects back to self-hosted instance with session_id
4. Frontend calls this endpoint with session_id
5. We call cloud data plane /proxy/claim-license to get the signed license
6. License is stored locally and cached
"""
if MULTI_TENANT:
raise HTTPException(
status_code=500, detail="Authentication configuration error"
status_code=400,
detail="License claiming is only available for self-hosted deployments",
)
try:
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
response = requests.get(url, headers=headers, timeout=10)
# Call cloud data plane to claim the license
url = f"{CLOUD_DATA_PLANE_URL}/proxy/claim-license"
response = requests.post(
url,
json={"session_id": session_id},
headers={"Content-Type": "application/json"},
timeout=30,
)
response.raise_for_status()
data = response.json()
if not isinstance(data, dict) or "license" not in data:
raise HTTPException(
status_code=502, detail="Invalid response from control plane"
)
license_data = data.get("license")
license_data = data["license"]
if not license_data:
raise HTTPException(status_code=404, detail="No license found")
raise HTTPException(status_code=404, detail="No license in response")
# Verify signature before persisting
payload = verify_license_signature(license_data)
# Verify the fetched license is for this tenant
if payload.tenant_id != tenant_id:
logger.error(
f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
)
raise HTTPException(
status_code=400,
detail="License tenant ID mismatch - control plane returned wrong license",
)
# Persist to DB and update cache atomically
# Store in DB
upsert_license(db_session, license_data)
try:
update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
except Exception as cache_error:
# Log but don't fail - DB is source of truth, cache will refresh on next read
logger.warning(f"Failed to update license cache: {cache_error}")
logger.info(
f"License claimed: seats={payload.seats}, expires={payload.expires_at.date()}"
)
return LicenseResponse(success=True, license=payload)
except requests.HTTPError as e:
status_code = e.response.status_code if e.response is not None else 502
logger.error(f"Control plane returned error: {status_code}")
raise HTTPException(
status_code=status_code,
detail="Failed to fetch license from control plane",
)
detail = "Failed to claim license"
try:
error_data = e.response.json() if e.response is not None else {}
detail = error_data.get("detail", detail)
except Exception:
pass
raise HTTPException(status_code=status_code, detail=detail)
except ValueError as e:
logger.error(f"License verification failed: {type(e).__name__}")
raise HTTPException(status_code=400, detail=str(e))
except requests.RequestException:
logger.exception("Failed to fetch license from control plane")
raise HTTPException(
status_code=502, detail="Failed to connect to control plane"
status_code=502, detail="Failed to connect to license server"
)
@@ -164,33 +172,36 @@ async def upload_license(
db_session: Session = Depends(get_session),
) -> LicenseUploadResponse:
"""
Upload a license file manually.
Used for air-gapped deployments where control plane is not accessible.
Upload a license file manually (self-hosted only).
Used for air-gapped deployments where the cloud data plane is not accessible.
The license file must be cryptographically signed by Onyx.
"""
if MULTI_TENANT:
raise HTTPException(
status_code=400,
detail="License upload is only available for self-hosted deployments",
)
try:
content = await license_file.read()
license_data = content.decode("utf-8").strip()
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="Invalid license file format")
# Verify cryptographic signature - this is the only validation needed
# The license's tenant_id identifies the customer in control plane, not locally
try:
payload = verify_license_signature(license_data)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
tenant_id = get_current_tenant_id()
if payload.tenant_id != tenant_id:
raise HTTPException(
status_code=400,
detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
)
# Persist to DB and update cache
upsert_license(db_session, license_data)
try:
update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
except Exception as cache_error:
# Log but don't fail - DB is source of truth, cache will refresh on next read
logger.warning(f"Failed to update license cache: {cache_error}")
return LicenseUploadResponse(
@@ -205,8 +216,10 @@ async def refresh_license_cache_endpoint(
db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
"""
Force refresh the license cache from the database.
Force refresh the license cache from the local database.
Useful after manual database changes or to verify license validity.
Does NOT fetch from control plane - use /claim for that.
"""
metadata = refresh_license_cache(db_session)
@@ -233,9 +246,15 @@ async def delete_license(
) -> dict[str, bool]:
"""
Delete the current license.
Admin only - removes license and invalidates cache.
Admin only - removes license from database and invalidates cache.
"""
# Invalidate cache first - if DB delete fails, stale cache is worse than no cache
if MULTI_TENANT:
raise HTTPException(
status_code=400,
detail="License deletion is only available for self-hosted deployments",
)
try:
invalidate_license_cache()
except Exception as cache_error:

View File

@@ -1,4 +1,42 @@
"""Middleware to enforce license status application-wide."""
"""Middleware to enforce license status for SELF-HOSTED deployments only.
NOTE: This middleware is NOT used for multi-tenant (cloud) deployments.
Multi-tenant gating is handled separately by the control plane via the
/tenants/product-gating endpoint and is_tenant_gated() checks.
IMPORTANT: Mutual Exclusivity with ENTERPRISE_EDITION_ENABLED
============================================================
This middleware is controlled by LICENSE_ENFORCEMENT_ENABLED env var.
It works alongside the legacy ENTERPRISE_EDITION_ENABLED system:
- LICENSE_ENFORCEMENT_ENABLED=false (default):
Middleware is disabled. EE features are controlled solely by
ENTERPRISE_EDITION_ENABLED. This preserves legacy behavior.
- LICENSE_ENFORCEMENT_ENABLED=true:
Middleware actively enforces license status. EE features require
a valid license, regardless of ENTERPRISE_EDITION_ENABLED.
Eventually, ENTERPRISE_EDITION_ENABLED will be removed and license
enforcement will be the only mechanism for gating EE features.
License Enforcement States (when enabled)
=========================================
For self-hosted deployments:
1. No license (never subscribed):
- Allow community features (basic connectors, search, chat)
- Block EE-only features (analytics, user groups, etc.)
2. GATED_ACCESS (fully expired):
- Block all routes except billing/auth/license
- User must renew subscription to continue
3. Valid license (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER):
- Full access to all EE features
- Seat limits enforced
- GRACE_PERIOD/PAYMENT_REMINDER are for notifications only, not blocking
"""
import logging
from collections.abc import Awaitable
@@ -9,38 +47,30 @@ from fastapi import Request
from fastapi import Response
from fastapi.responses import JSONResponse
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.configs.license_enforcement_config import EE_ONLY_PATH_PREFIXES
from ee.onyx.configs.license_enforcement_config import (
LICENSE_ENFORCEMENT_ALLOWED_PREFIXES,
)
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.server.tenants.product_gating import is_tenant_gated
from ee.onyx.db.license import refresh_license_cache
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
# Paths that are ALWAYS accessible, even when license is expired/gated.
# These enable users to:
# /auth - Log in/out (users can't fix billing if locked out of auth)
# /license - Fetch, upload, or check license status
# /health - Health checks for load balancers/orchestrators
# /me - Basic user info needed for UI rendering
# /settings, /enterprise-settings - View app status and branding
# /tenants/billing-* - Manage subscription to resolve gating
ALLOWED_PATH_PREFIXES = {
"/auth",
"/license",
"/health",
"/me",
"/settings",
"/enterprise-settings",
"/tenants/billing-information",
"/tenants/create-customer-portal-session",
"/tenants/create-subscription-session",
}
def _is_path_allowed(path: str) -> bool:
"""Check if path is in allowlist (prefix match)."""
return any(path.startswith(prefix) for prefix in ALLOWED_PATH_PREFIXES)
return any(
path.startswith(prefix) for prefix in LICENSE_ENFORCEMENT_ALLOWED_PREFIXES
)
def _is_ee_only_path(path: str) -> bool:
"""Check if path requires EE license (prefix match)."""
return any(path.startswith(prefix) for prefix in EE_ONLY_PATH_PREFIXES)
def add_license_enforcement_middleware(
@@ -66,29 +96,84 @@ def add_license_enforcement_middleware(
is_gated = False
tenant_id = get_current_tenant_id()
if MULTI_TENANT:
try:
is_gated = is_tenant_gated(tenant_id)
except RedisError as e:
logger.warning(f"Failed to check tenant gating status: {e}")
# Fail open - don't block users due to Redis connectivity issues
is_gated = False
else:
try:
metadata = get_cached_license_metadata(tenant_id)
if metadata:
if metadata.status == ApplicationStatus.GATED_ACCESS:
is_gated = True
else:
# No license metadata = gated for self-hosted EE
try:
metadata = get_cached_license_metadata(tenant_id)
# If no cached metadata, check database (cache may have been cleared)
if not metadata:
logger.debug(
"[license_enforcement] No cached license, checking database..."
)
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
if metadata:
logger.info(
"[license_enforcement] Loaded license from database"
)
except SQLAlchemyError as db_error:
logger.warning(
f"[license_enforcement] Failed to check database for license: {db_error}"
)
if metadata:
# User HAS a license (current or expired)
if metadata.status == ApplicationStatus.GATED_ACCESS:
# License fully expired - gate the user
# Note: GRACE_PERIOD and PAYMENT_REMINDER are for notifications only,
# they don't block access
is_gated = True
except RedisError as e:
logger.warning(f"Failed to check license metadata: {e}")
# Fail open - don't block users due to Redis connectivity issues
else:
# License is active - check seat limit
# used_seats in cache is kept accurate via invalidation
# when users are added/removed
if metadata.used_seats > metadata.seats:
logger.info(
f"[license_enforcement] Blocking request: "
f"seat limit exceeded ({metadata.used_seats}/{metadata.seats})"
)
return JSONResponse(
status_code=402,
content={
"detail": {
"error": "seat_limit_exceeded",
"message": f"Seat limit exceeded: {metadata.used_seats} of {metadata.seats} seats used.",
"used_seats": metadata.used_seats,
"seats": metadata.seats,
}
},
)
else:
# No license in cache OR database = never subscribed
# Allow community features, but block EE-only features
if _is_ee_only_path(path):
logger.info(
f"[license_enforcement] Blocking EE-only path (no license): {path}"
)
return JSONResponse(
status_code=402,
content={
"detail": {
"error": "enterprise_license_required",
"message": "This feature requires an Enterprise license. "
"Please upgrade to access this functionality.",
}
},
)
logger.debug(
"[license_enforcement] No license, allowing community features"
)
is_gated = False
except RedisError as e:
logger.warning(f"Failed to check license metadata: {e}")
# Fail open - don't block users due to Redis connectivity issues
is_gated = False
if is_gated:
logger.info(f"Blocking request for gated tenant: {tenant_id}, path={path}")
logger.info(
f"[license_enforcement] Blocking request (license expired): {path}"
)
return JSONResponse(
status_code=402,
content={

View File

@@ -32,6 +32,7 @@ class SendSearchQueryRequest(BaseModel):
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 50
include_content: bool = False
stream: bool = False

View File

@@ -12,21 +12,51 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# Statuses that indicate a billing/license problem - propagate these to settings
_GATED_STATUSES = frozenset(
{
ApplicationStatus.GATED_ACCESS,
ApplicationStatus.GRACE_PERIOD,
ApplicationStatus.PAYMENT_REMINDER,
}
)
# Only GATED_ACCESS actually blocks access - other statuses are for notifications
_BLOCKING_STATUS = ApplicationStatus.GATED_ACCESS
def check_ee_features_enabled() -> bool:
"""EE version: checks if EE features should be available.
Returns True if:
- LICENSE_ENFORCEMENT_ENABLED is False (legacy/rollout mode)
- Cloud mode (MULTI_TENANT) - cloud handles its own gating
- Self-hosted with a valid (non-expired) license
Returns False if:
- Self-hosted with no license (never subscribed)
- Self-hosted with expired license
"""
if not LICENSE_ENFORCEMENT_ENABLED:
# License enforcement disabled - allow EE features (legacy behavior)
return True
if MULTI_TENANT:
# Cloud mode - EE features always available (gating handled by is_tenant_gated)
return True
# Self-hosted with enforcement - check for valid license
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if metadata and metadata.status != _BLOCKING_STATUS:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
return True
except RedisError as e:
logger.warning(f"Failed to check license for EE features: {e}")
# Fail closed - if Redis is down, other things will break anyway
return False
# No license or GATED_ACCESS - no EE features
return False
def apply_license_status_to_settings(settings: Settings) -> Settings:
"""EE version: checks license status for self-hosted deployments.
For self-hosted, looks up license metadata and overrides application_status
if the license is missing or indicates a problem (expired, grace period, etc.).
if the license indicates GATED_ACCESS (fully expired).
For multi-tenant (cloud), the settings already have the correct status
from the control plane, so no override is needed.
@@ -43,11 +73,10 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if metadata and metadata.status in _GATED_STATUSES:
if metadata and metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
elif not metadata:
# No license = gated access for self-hosted EE
settings.application_status = ApplicationStatus.GATED_ACCESS
# No license = user hasn't purchased yet, allow access for upgrade flow
# GRACE_PERIOD/PAYMENT_REMINDER don't block - they're for notifications
except RedisError as e:
logger.warning(f"Failed to check license metadata for settings: {e}")

View File

@@ -3,6 +3,7 @@ from fastapi import APIRouter
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.proxy import router as proxy_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
@@ -22,3 +23,4 @@ router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)
router.include_router(proxy_router)

View File

@@ -1,3 +1,24 @@
"""Billing API endpoints for cloud multi-tenant deployments.
DEPRECATED: These /tenants/* billing endpoints are being replaced by /admin/billing/*
which provides a unified API for both self-hosted and cloud deployments.
TODO(ENG-3533): Migrate frontend to use /admin/billing/* endpoints and remove this file.
https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling
Current endpoints to migrate:
- GET /tenants/billing-information -> GET /admin/billing/information
- POST /tenants/create-customer-portal-session -> POST /admin/billing/portal-session
- POST /tenants/create-subscription-session -> POST /admin/billing/checkout-session
- GET /tenants/stripe-publishable-key -> (keep as-is, shared endpoint)
Note: /tenants/product-gating/* endpoints are control-plane-to-data-plane calls
and are NOT part of this migration - they stay here.
"""
import asyncio
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
@@ -12,11 +33,14 @@ from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import StripePublishableKeyResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -26,6 +50,10 @@ logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Cache for Stripe publishable key to avoid hitting S3 on every request
_stripe_publishable_key_cache: str | None = None
_stripe_key_lock = asyncio.Lock()
@router.post("/product-gating")
def gate_product(
@@ -80,11 +108,7 @@ async def billing_information(
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
"""
Create a Stripe customer portal session via the control plane.
NOTE: This is currently only used for multi-tenant (cloud) deployments.
Self-hosted proxy endpoints will be added in a future phase.
"""
"""Create a Stripe customer portal session via the control plane."""
tenant_id = get_current_tenant_id()
return_url = f"{WEB_DOMAIN}/admin/billing"
@@ -113,3 +137,67 @@ async def create_subscription_session(
except Exception as e:
logger.exception("Failed to create subscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/stripe-publishable-key")
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
"""
Fetch the Stripe publishable key.
Priority: env var override (for testing) > S3 bucket (production).
This endpoint is public (no auth required) since publishable keys are safe to expose.
The key is cached in memory to avoid hitting S3 on every request.
"""
global _stripe_publishable_key_cache
# Fast path: return cached value without lock
if _stripe_publishable_key_cache:
return StripePublishableKeyResponse(
publishable_key=_stripe_publishable_key_cache
)
# Use lock to prevent concurrent S3 requests
async with _stripe_key_lock:
# Double-check after acquiring lock (another request may have populated cache)
if _stripe_publishable_key_cache:
return StripePublishableKeyResponse(
publishable_key=_stripe_publishable_key_cache
)
# Check for env var override first (for local testing with pk_test_* keys)
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
# Fall back to S3 bucket
if not STRIPE_PUBLISHABLE_KEY_URL:
raise HTTPException(
status_code=500,
detail="Stripe publishable key is not configured",
)
try:
async with httpx.AsyncClient() as client:
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
response.raise_for_status()
key = response.text.strip()
# Validate key format
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
except httpx.HTTPError:
raise HTTPException(
status_code=500,
detail="Failed to fetch Stripe publishable key",
)

View File

@@ -105,3 +105,7 @@ class PendingUserSnapshot(BaseModel):
class ApproveUserRequest(BaseModel):
email: str
class StripePublishableKeyResponse(BaseModel):
publishable_key: str

View File

@@ -0,0 +1,485 @@
"""Proxy endpoints for billing operations.
These endpoints run on the CLOUD DATA PLANE (cloud.onyx.app) and serve as a proxy
for self-hosted instances to reach the control plane.
Flow:
Self-hosted backend → Cloud DP /proxy/* (license auth) → Control plane (JWT auth)
Self-hosted instances call these endpoints with their license in the Authorization
header. The cloud data plane validates the license signature and forwards the
request to the control plane using JWT authentication.
Auth levels by endpoint:
- /create-checkout-session: No auth (new customer) or expired license OK (renewal)
- /claim-license: Session ID based (one-time after Stripe payment)
- /create-customer-portal-session: Expired license OK (need portal to fix payment)
- /billing-information: Valid license required
- /license/{tenant_id}: Valid license required
- /seats/update: Valid license required
"""
from typing import Literal
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Header
from fastapi import HTTPException
from pydantic import BaseModel
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import update_license_cache
from ee.onyx.db.license import upsert_license
from ee.onyx.server.billing.models import SeatUpdateRequest
from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.utils.license import is_license_valid
from ee.onyx.utils.license import verify_license_signature
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/proxy")
def _check_license_enforcement_enabled() -> None:
"""Ensure LICENSE_ENFORCEMENT_ENABLED is true (proxy endpoints only work on cloud DP)."""
if not LICENSE_ENFORCEMENT_ENABLED:
raise HTTPException(
status_code=501,
detail="Proxy endpoints are only available on cloud data plane",
)
def _extract_license_from_header(
authorization: str | None,
required: bool = True,
) -> str | None:
"""Extract license data from Authorization header.
Self-hosted instances authenticate to these proxy endpoints by sending their
license as a Bearer token: `Authorization: Bearer <base64-encoded-license>`.
We use the Bearer scheme (RFC 6750) because:
1. It's the standard HTTP auth scheme for token-based authentication
2. The license blob is cryptographically signed (RSA), so it's self-validating
3. No other auth schemes (Basic, Digest, etc.) are supported for license auth
The license data is the base64-encoded signed blob that contains tenant_id,
seats, expiration, etc. We verify the signature to authenticate the caller.
Args:
authorization: The Authorization header value (e.g., "Bearer <license>")
required: If True, raise 401 when header is missing/invalid
Returns:
License data string (base64-encoded), or None if not required and missing
Raises:
HTTPException: 401 if required and header is missing/invalid
"""
if not authorization or not authorization.startswith("Bearer "):
if required:
raise HTTPException(
status_code=401, detail="Missing or invalid authorization header"
)
return None
return authorization.split(" ", 1)[1]
def verify_license_auth(
license_data: str,
allow_expired: bool = False,
) -> LicensePayload:
"""Verify license signature and optionally check expiry.
Args:
license_data: Base64-encoded signed license blob
allow_expired: If True, accept expired licenses (for renewal flows)
Returns:
LicensePayload if valid
Raises:
HTTPException: If license is invalid or expired (when not allowed)
"""
_check_license_enforcement_enabled()
try:
payload = verify_license_signature(license_data)
except ValueError as e:
raise HTTPException(status_code=401, detail=f"Invalid license: {e}")
if not allow_expired and not is_license_valid(payload):
raise HTTPException(status_code=401, detail="License has expired")
return payload
async def get_license_payload(
authorization: str | None = Header(None, alias="Authorization"),
) -> LicensePayload:
"""Dependency: Require valid (non-expired) license.
Used for endpoints that require an active subscription.
"""
license_data = _extract_license_from_header(authorization, required=True)
# license_data is guaranteed non-None when required=True
assert license_data is not None
return verify_license_auth(license_data, allow_expired=False)
async def get_license_payload_allow_expired(
authorization: str | None = Header(None, alias="Authorization"),
) -> LicensePayload:
"""Dependency: Require license with valid signature, expired OK.
Used for endpoints needed to fix payment issues (portal, renewal checkout).
"""
license_data = _extract_license_from_header(authorization, required=True)
# license_data is guaranteed non-None when required=True
assert license_data is not None
return verify_license_auth(license_data, allow_expired=True)
async def get_optional_license_payload(
authorization: str | None = Header(None, alias="Authorization"),
) -> LicensePayload | None:
"""Dependency: Optional license auth (for checkout - new customers have none).
Returns None if no license provided, otherwise validates and returns payload.
Expired licenses are allowed for renewal flows.
"""
_check_license_enforcement_enabled()
license_data = _extract_license_from_header(authorization, required=False)
if license_data is None:
return None
return verify_license_auth(license_data, allow_expired=True)
async def forward_to_control_plane(
method: str,
path: str,
body: dict | None = None,
params: dict | None = None,
) -> dict:
"""Forward a request to the control plane with proper authentication."""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}{path}"
try:
async with httpx.AsyncClient(timeout=30.0) as client:
if method == "GET":
response = await client.get(url, headers=headers, params=params)
elif method == "POST":
response = await client.post(url, headers=headers, json=body)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
status_code = e.response.status_code
detail = "Control plane request failed"
try:
error_data = e.response.json()
detail = error_data.get("detail", detail)
except Exception:
pass
logger.error(f"Control plane returned {status_code}: {detail}")
raise HTTPException(status_code=status_code, detail=detail)
except httpx.RequestError:
logger.exception("Failed to connect to control plane")
raise HTTPException(
status_code=502, detail="Failed to connect to control plane"
)
def fetch_and_store_license(tenant_id: str, license_data: str) -> None:
"""Store license in database and update Redis cache.
Args:
tenant_id: The tenant ID
license_data: Base64-encoded signed license blob
"""
try:
# Verify before storing
payload = verify_license_signature(license_data)
# Store in database using the specific tenant's schema
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
upsert_license(db_session, license_data)
# Update Redis cache
update_license_cache(
payload,
source=LicenseSource.AUTO_FETCH,
tenant_id=tenant_id,
)
except ValueError as e:
logger.error(f"Failed to verify license: {e}")
raise
except Exception:
logger.exception("Failed to store license")
raise
# -----------------------------------------------------------------------------
# Endpoints
# -----------------------------------------------------------------------------
class CreateCheckoutSessionRequest(BaseModel):
billing_period: Literal["monthly", "annual"] = "monthly"
email: str | None = None
# Redirect URL after successful checkout - self-hosted passes their instance URL
redirect_url: str | None = None
# Cancel URL when user exits checkout - returns to upgrade page
cancel_url: str | None = None
class CreateCheckoutSessionResponse(BaseModel):
url: str
@router.post("/create-checkout-session")
async def proxy_create_checkout_session(
request_body: CreateCheckoutSessionRequest,
license_payload: LicensePayload | None = Depends(get_optional_license_payload),
) -> CreateCheckoutSessionResponse:
"""Proxy checkout session creation to control plane.
Auth: Optional license (new customers don't have one yet).
If license provided, expired is OK (for renewals).
"""
# license_payload is None for new customers who don't have a license yet.
# In that case, tenant_id is omitted from the request body and the control
# plane will create a new tenant during checkout completion.
tenant_id = license_payload.tenant_id if license_payload else None
body: dict = {
"billing_period": request_body.billing_period,
}
if tenant_id:
body["tenant_id"] = tenant_id
if request_body.email:
body["email"] = request_body.email
if request_body.redirect_url:
body["redirect_url"] = request_body.redirect_url
if request_body.cancel_url:
body["cancel_url"] = request_body.cancel_url
result = await forward_to_control_plane(
"POST", "/create-checkout-session", body=body
)
return CreateCheckoutSessionResponse(url=result["url"])
class ClaimLicenseRequest(BaseModel):
session_id: str
class ClaimLicenseResponse(BaseModel):
tenant_id: str
license: str
message: str | None = None
@router.post("/claim-license")
async def proxy_claim_license(
request_body: ClaimLicenseRequest,
) -> ClaimLicenseResponse:
"""Claim a license after successful Stripe checkout.
Auth: Session ID based (one-time use after payment).
The control plane verifies the session_id is valid and unclaimed.
Returns the license to the caller. For self-hosted instances, they will
store the license locally. The cloud DP doesn't need to store it.
"""
_check_license_enforcement_enabled()
result = await forward_to_control_plane(
"POST",
"/claim-license",
body={"session_id": request_body.session_id},
)
tenant_id = result.get("tenant_id")
license_data = result.get("license")
if not tenant_id or not license_data:
logger.error(f"Control plane returned incomplete claim response: {result}")
raise HTTPException(
status_code=502,
detail="Control plane returned incomplete license data",
)
return ClaimLicenseResponse(
tenant_id=tenant_id,
license=license_data,
message="License claimed successfully",
)
class CreateCustomerPortalSessionRequest(BaseModel):
return_url: str | None = None
class CreateCustomerPortalSessionResponse(BaseModel):
url: str
@router.post("/create-customer-portal-session")
async def proxy_create_customer_portal_session(
request_body: CreateCustomerPortalSessionRequest | None = None,
license_payload: LicensePayload = Depends(get_license_payload_allow_expired),
) -> CreateCustomerPortalSessionResponse:
"""Proxy customer portal session creation to control plane.
Auth: License required, expired OK (need portal to fix payment issues).
"""
# tenant_id is a required field in LicensePayload (Pydantic validates this),
# but we check explicitly for defense in depth
if not license_payload.tenant_id:
raise HTTPException(status_code=401, detail="License missing tenant_id")
tenant_id = license_payload.tenant_id
body: dict = {"tenant_id": tenant_id}
if request_body and request_body.return_url:
body["return_url"] = request_body.return_url
result = await forward_to_control_plane(
"POST", "/create-customer-portal-session", body=body
)
return CreateCustomerPortalSessionResponse(url=result["url"])
class BillingInformationResponse(BaseModel):
tenant_id: str
status: str | None = None
plan_type: str | None = None
seats: int | None = None
billing_period: str | None = None
current_period_start: str | None = None
current_period_end: str | None = None
cancel_at_period_end: bool = False
canceled_at: str | None = None
trial_start: str | None = None
trial_end: str | None = None
payment_method_enabled: bool = False
stripe_subscription_id: str | None = None
@router.get("/billing-information")
async def proxy_billing_information(
license_payload: LicensePayload = Depends(get_license_payload),
) -> BillingInformationResponse:
"""Proxy billing information request to control plane.
Auth: Valid (non-expired) license required.
"""
# tenant_id is a required field in LicensePayload (Pydantic validates this),
# but we check explicitly for defense in depth
if not license_payload.tenant_id:
raise HTTPException(status_code=401, detail="License missing tenant_id")
tenant_id = license_payload.tenant_id
result = await forward_to_control_plane(
"GET", "/billing-information", params={"tenant_id": tenant_id}
)
# Add tenant_id from license if not in response (control plane may not include it)
if "tenant_id" not in result:
result["tenant_id"] = tenant_id
return BillingInformationResponse(**result)
class LicenseFetchResponse(BaseModel):
license: str
tenant_id: str
@router.get("/license/{tenant_id}")
async def proxy_license_fetch(
tenant_id: str,
license_payload: LicensePayload = Depends(get_license_payload),
) -> LicenseFetchResponse:
"""Proxy license fetch to control plane.
Auth: Valid license required.
The tenant_id in path must match the authenticated tenant.
"""
# tenant_id is a required field in LicensePayload (Pydantic validates this),
# but we check explicitly for defense in depth
if not license_payload.tenant_id:
raise HTTPException(status_code=401, detail="License missing tenant_id")
if tenant_id != license_payload.tenant_id:
raise HTTPException(
status_code=403,
detail="Cannot fetch license for a different tenant",
)
result = await forward_to_control_plane("GET", f"/license/{tenant_id}")
# Auto-store the refreshed license
license_data = result.get("license")
if not license_data:
logger.error(f"Control plane returned incomplete license response: {result}")
raise HTTPException(
status_code=502,
detail="Control plane returned incomplete license data",
)
fetch_and_store_license(tenant_id, license_data)
return LicenseFetchResponse(license=license_data, tenant_id=tenant_id)
@router.post("/seats/update")
async def proxy_seat_update(
request_body: SeatUpdateRequest,
license_payload: LicensePayload = Depends(get_license_payload),
) -> SeatUpdateResponse:
"""Proxy seat update to control plane.
Auth: Valid (non-expired) license required.
Handles Stripe proration and license regeneration.
"""
if not license_payload.tenant_id:
raise HTTPException(status_code=401, detail="License missing tenant_id")
tenant_id = license_payload.tenant_id
result = await forward_to_control_plane(
"POST",
"/seats/update",
body={
"tenant_id": tenant_id,
"new_seat_count": request_body.new_seat_count,
},
)
return SeatUpdateResponse(
success=result.get("success", False),
current_seats=result.get("current_seats", 0),
used_seats=result.get("used_seats", 0),
message=result.get("message"),
)

View File

@@ -1,6 +1,7 @@
from fastapi_users import exceptions
from sqlalchemy import select
from ee.onyx.db.license import invalidate_license_cache
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
@@ -47,6 +48,8 @@ def get_tenant_id_for_email(email: str) -> str:
mapping.active = True
db_session.commit()
tenant_id = mapping.tenant_id
# Invalidate license cache so used_seats reflects the new count
invalidate_license_cache(tenant_id)
except Exception as e:
logger.exception(f"Error getting tenant id for email {email}: {e}")
raise exceptions.UserNotExists()
@@ -70,49 +73,104 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
If a user already has an active mapping to a different tenant, they receive
an inactive mapping (invitation) to this tenant. They can accept the
invitation later to switch tenants.
Raises:
HTTPException: 402 if adding active users would exceed seat limit
"""
from fastapi import HTTPException
from ee.onyx.db.license import check_seat_availability
from onyx.db.engine.sql_engine import get_session_with_tenant as get_tenant_session
unique_emails = set(emails)
if not unique_emails:
return
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
# Start a transaction
db_session.begin()
for email in emails:
# Check if the user already has a mapping to this tenant
existing_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
)
.with_for_update()
.first()
# Batch query 1: Get all existing mappings for these emails to this tenant
# Lock rows to prevent concurrent modifications
existing_mappings = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(unique_emails),
UserTenantMapping.tenant_id == tenant_id,
)
.with_for_update()
.all()
)
emails_with_mapping = {m.email for m in existing_mappings}
# If user already has an active mapping, add this one as inactive
if not existing_mapping:
# Check if the user already has an active mapping to any tenant
has_active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
# Batch query 2: Get all active mappings for these emails (any tenant)
active_mappings = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(unique_emails),
UserTenantMapping.active == True, # noqa: E712
)
.all()
)
emails_with_active_mapping = {m.email for m in active_mappings}
db_session.add(
UserTenantMapping(
email=email,
tenant_id=tenant_id,
active=False if has_active_mapping else True,
)
# Determine which users will consume a new seat.
# Users with active mappings elsewhere get INACTIVE mappings (invitations)
# and don't consume seats until they accept. Only users without any active
# mapping will get an ACTIVE mapping and consume a seat immediately.
emails_consuming_seats = {
email
for email in unique_emails
if email not in emails_with_mapping
and email not in emails_with_active_mapping
}
# Check seat availability inside the transaction to prevent race conditions.
# Note: ALL users in unique_emails still get added below - this check only
# validates we have capacity for users who will consume seats immediately.
if emails_consuming_seats:
with get_tenant_session(tenant_id=tenant_id) as tenant_session:
result = check_seat_availability(
tenant_session,
seats_needed=len(emails_consuming_seats),
tenant_id=tenant_id,
)
if not result.available:
raise HTTPException(
status_code=402,
detail=result.error_message or "Seat limit exceeded",
)
# Add mappings for emails that don't already have one to this tenant
for email in unique_emails:
if email in emails_with_mapping:
continue
# Create mapping: inactive if user belongs to another tenant (invitation),
# active otherwise
db_session.add(
UserTenantMapping(
email=email,
tenant_id=tenant_id,
active=email not in emails_with_active_mapping,
)
)
# Commit the transaction
db_session.commit()
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
# Invalidate license cache so used_seats reflects the new count
invalidate_license_cache(tenant_id)
except HTTPException:
db_session.rollback()
raise
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.rollback()
@@ -135,6 +193,9 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
db_session.delete(mapping)
db_session.commit()
# Invalidate license cache so used_seats reflects the new count
invalidate_license_cache(tenant_id)
except Exception as e:
logger.exception(
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
@@ -149,6 +210,9 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
).delete()
db_session.commit()
# Invalidate license cache so used_seats reflects the new count
invalidate_license_cache(tenant_id)
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -177,6 +241,9 @@ def approve_user_invite(email: str, tenant_id: str) -> None:
db_session.add(new_mapping)
db_session.commit()
# Invalidate license cache so used_seats reflects the new count
invalidate_license_cache(tenant_id)
# Also remove the user from pending users list
# Remove from pending users
pending_users = get_pending_users()
@@ -195,19 +262,42 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
"""
Accept an invitation to join a tenant.
This activates the user's mapping to the tenant.
Raises:
HTTPException: 402 if accepting would exceed seat limit
"""
from fastapi import HTTPException
from ee.onyx.db.license import check_seat_availability
from onyx.db.engine.sql_engine import get_session_with_tenant
with get_session_with_shared_schema() as db_session:
try:
# First check if there's an active mapping for this user and tenant
# Lock the user's mappings first to prevent race conditions.
# This ensures no concurrent request can modify this user's mappings
# while we check seats and activate.
active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.with_for_update()
.first()
)
# Check seat availability within the same logical operation.
# Note: This queries fresh data from DB, not cache.
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
result = check_seat_availability(
tenant_session, seats_needed=1, tenant_id=tenant_id
)
if not result.available:
raise HTTPException(
status_code=402,
detail=result.error_message or "Seat limit exceeded",
)
# If an active mapping exists, delete it
if active_mapping:
db_session.delete(active_mapping)
@@ -237,6 +327,9 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
mapping.active = True
db_session.commit()
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
# Invalidate license cache so used_seats reflects the new count
invalidate_license_cache(tenant_id)
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
@@ -297,16 +390,41 @@ def deny_user_invite(email: str, tenant_id: str) -> None:
def get_tenant_count(tenant_id: str) -> int:
"""
Get the number of active users for this tenant
Get the number of active users for this tenant.
A user counts toward the seat count if:
1. They have an active mapping to this tenant (UserTenantMapping.active == True)
2. AND the User is active (User.is_active == True)
TODO: Exclude API key dummy users from seat counting. API keys create
users with emails like `__DANSWER_API_KEY_*` that should not count toward
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
"""
from onyx.db.models import User
# First get all emails with active mappings to this tenant
with get_session_with_shared_schema() as db_session:
# Count the number of active users for this tenant
user_count = (
db_session.query(UserTenantMapping)
active_mapping_emails = (
db_session.query(UserTenantMapping.email)
.filter(
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.all()
)
emails = [email for (email,) in active_mapping_emails]
if not emails:
return 0
# Now count how many of those users are actually active in the tenant's User table
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
user_count = (
db_session.query(User)
.filter(
User.email.in_(emails), # type: ignore
User.is_active == True, # type: ignore # noqa: E712
)
.count()
)

View File

@@ -5,6 +5,7 @@ import json
import os
from datetime import datetime
from datetime import timezone
from pathlib import Path
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
@@ -19,21 +20,27 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# RSA-4096 Public Key for license verification
# Load from environment variable - key is generated on the control plane
# In production, inject via Kubernetes secrets or secrets manager
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")
# Path to the license public key file
_LICENSE_PUBLIC_KEY_PATH = (
Path(__file__).parent.parent.parent.parent / "keys" / "license_public_key.pem"
)
def _get_public_key() -> RSAPublicKey:
"""Load the public key from environment variable."""
if not LICENSE_PUBLIC_KEY_PEM:
raise ValueError(
"LICENSE_PUBLIC_KEY_PEM environment variable not set. "
"License verification requires the control plane public key."
)
key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
"""Load the public key from file, with env var override."""
# Allow env var override for flexibility
key_pem = os.environ.get("LICENSE_PUBLIC_KEY_PEM")
if not key_pem:
# Read from file
if not _LICENSE_PUBLIC_KEY_PATH.exists():
raise ValueError(
f"License public key not found at {_LICENSE_PUBLIC_KEY_PATH}. "
"License verification requires the control plane public key."
)
key_pem = _LICENSE_PUBLIC_KEY_PATH.read_text()
key = serialization.load_pem_public_key(key_pem.encode())
if not isinstance(key, RSAPublicKey):
raise ValueError("Expected RSA public key")
return key
@@ -53,17 +60,21 @@ def verify_license_signature(license_data: str) -> LicensePayload:
ValueError: If license data is invalid or signature verification fails
"""
try:
# Decode the license data
decoded = json.loads(base64.b64decode(license_data))
# Parse into LicenseData to validate structure
license_obj = LicenseData(**decoded)
payload_json = json.dumps(
license_obj.payload.model_dump(mode="json"), sort_keys=True
)
# IMPORTANT: Use the ORIGINAL payload JSON for signature verification,
# not re-serialized through Pydantic. Pydantic may format fields differently
# (e.g., datetime "+00:00" vs "Z") which would break signature verification.
original_payload = decoded.get("payload", {})
payload_json = json.dumps(original_payload, sort_keys=True)
signature_bytes = base64.b64decode(license_obj.signature)
# Verify signature using PSS padding (modern standard)
public_key = _get_public_key()
public_key.verify(
signature_bytes,
payload_json.encode(),
@@ -77,16 +88,18 @@ def verify_license_signature(license_data: str) -> LicensePayload:
return license_obj.payload
except InvalidSignature:
logger.error("License signature verification failed")
logger.error("[verify_license] FAILED: Signature verification failed")
raise ValueError("Invalid license signature")
except json.JSONDecodeError:
logger.error("Failed to decode license JSON")
except json.JSONDecodeError as e:
logger.error(f"[verify_license] FAILED: JSON decode error: {e}")
raise ValueError("Invalid license format: not valid JSON")
except (ValueError, KeyError, TypeError) as e:
logger.error(f"License data validation error: {type(e).__name__}")
raise ValueError(f"Invalid license format: {type(e).__name__}")
logger.error(
f"[verify_license] FAILED: Validation error: {type(e).__name__}: {e}"
)
raise ValueError(f"Invalid license format: {type(e).__name__}: {e}")
except Exception:
logger.exception("Unexpected error during license verification")
logger.exception("[verify_license] FAILED: Unexpected error")
raise ValueError("License verification failed: unexpected error")

View File

@@ -6,6 +6,7 @@ from posthog import Posthog
from ee.onyx.configs.app_configs import MARKETING_POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_DEBUG_LOGS_ENABLED
from ee.onyx.configs.app_configs import POSTHOG_HOST
from onyx.utils.logger import setup_logger
@@ -20,7 +21,7 @@ def posthog_on_error(error: Any, items: Any) -> None:
posthog = Posthog(
project_api_key=POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=True,
debug=POSTHOG_DEBUG_LOGS_ENABLED,
on_error=posthog_on_error,
)
@@ -33,7 +34,7 @@ if MARKETING_POSTHOG_API_KEY:
marketing_posthog = Posthog(
project_api_key=MARKETING_POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=True,
debug=POSTHOG_DEBUG_LOGS_ENABLED,
on_error=posthog_on_error,
)

View File

@@ -0,0 +1,14 @@
-----BEGIN PUBLIC KEY-----
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA5DpchQujdxjCwpc4/RQP
Hej6rc3SS/5ENCXL0I8NAfMogel0fqG6PKRhonyEh/Bt3P4q18y8vYzAShwf4b6Q
aS0WwshbvnkjyWlsK0BY4HLBKPkTpes7kaz8MwmPZDeelvGJ7SNv3FvyJR4QsoSQ
GSoB5iTH7hi63TjzdxtckkXoNG+GdVd/koxVDUv2uWcAoWIFTTcbKWyuq2SS/5Sf
xdVaIArqfAhLpnNbnM9OS7lZ1xP+29ZXpHxDoeluz35tJLMNBYn9u0y+puo1kW1E
TOGizlAq5kmEMsTJ55e9ZuyIV3gZAUaUKe8CxYJPkOGt0Gj6e1jHoHZCBJmaq97Y
stKj//84HNBzajaryEZuEfRecJ94ANEjkD8u9cGmW+9VxRe5544zWguP5WMT/nv1
0Q+jkOBW2hkY5SS0Rug4cblxiB7bDymWkaX6+sC0VWd5g6WXp36EuP2T0v3mYuHU
GDEiWbD44ToREPVwE/M07ny8qhLo/HYk2l8DKFt83hXe7ePBnyQdcsrVbQWOO1na
j43OkoU5gOFyOkrk2RmmtCjA8jSnw+tGCTpRaRcshqoWC1MjZyU+8/kDteXNkmv9
/B5VxzYSyX+abl7yAu5wLiUPW8l+mOazzWu0nPkmiA160ArxnRyxbGnmp4dUIrt5
azYku4tQYLSsSabfhcpeiCsCAwEAAQ==
-----END PUBLIC KEY-----

View File

@@ -97,10 +97,14 @@ def get_access_for_documents(
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The
user should have access to a document if at least one entry in the document's ACL
matches one entry in the returned set.
"""Returns a list of ACL entries that the user has access to.
This is meant to be used downstream to filter out documents that the user
does not have access to. The user should have access to a document if at
least one entry in the document's ACL matches one entry in the returned set.
NOTE: These strings must be formatted in the same way as the output of
DocumentAccess::to_acl.
"""
if user:
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}

View File

@@ -125,9 +125,11 @@ class DocumentAccess(ExternalAccess):
)
def to_acl(self) -> set[str]:
# the acl's emitted by this function are prefixed by type
# to get the native objects, access the member variables directly
"""Converts the access state to a set of formatted ACL strings.
NOTE: When querying for documents, the supplied ACL filter strings must
be formatted in the same way as this function.
"""
acl_set: set[str] = set()
for user_email in self.user_emails:
if user_email:

View File

@@ -11,6 +11,7 @@ from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Protocol
from typing import Tuple
@@ -1456,6 +1457,9 @@ def get_default_admin_user_emails_() -> list[str]:
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
STATE_TOKEN_LIFETIME_SECONDS = 3600
CSRF_TOKEN_KEY = "csrftoken"
CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf"
class OAuth2AuthorizeResponse(BaseModel):
@@ -1463,13 +1467,19 @@ class OAuth2AuthorizeResponse(BaseModel):
def generate_state_token(
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
data: Dict[str, str],
secret: SecretType,
lifetime_seconds: int = STATE_TOKEN_LIFETIME_SECONDS,
) -> str:
data["aud"] = STATE_TOKEN_AUDIENCE
return generate_jwt(data, secret, lifetime_seconds)
def generate_csrf_token() -> str:
return secrets.token_urlsafe(32)
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_onyx_oauth_router(
oauth_client: BaseOAuth2,
@@ -1498,6 +1508,13 @@ def get_oauth_router(
redirect_url: Optional[str] = None,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
*,
csrf_token_cookie_name: str = CSRF_TOKEN_COOKIE_NAME,
csrf_token_cookie_path: str = "/",
csrf_token_cookie_domain: Optional[str] = None,
csrf_token_cookie_secure: Optional[bool] = None,
csrf_token_cookie_httponly: bool = True,
csrf_token_cookie_samesite: Optional[Literal["lax", "strict", "none"]] = "lax",
) -> APIRouter:
"""Generate a router with the OAuth routes."""
router = APIRouter()
@@ -1514,6 +1531,9 @@ def get_oauth_router(
route_name=callback_route_name,
)
if csrf_token_cookie_secure is None:
csrf_token_cookie_secure = WEB_DOMAIN.startswith("https")
@router.get(
"/authorize",
name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
@@ -1521,8 +1541,10 @@ def get_oauth_router(
)
async def authorize(
request: Request,
response: Response,
redirect: bool = Query(False),
scopes: List[str] = Query(None),
) -> OAuth2AuthorizeResponse:
) -> Response | OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)
if redirect_url is not None:
@@ -1532,9 +1554,11 @@ def get_oauth_router(
next_url = request.query_params.get("next", "/")
csrf_token = generate_csrf_token()
state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
CSRF_TOKEN_KEY: csrf_token,
}
state = generate_state_token(state_data, state_secret)
@@ -1551,6 +1575,31 @@ def get_oauth_router(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)
if redirect:
redirect_response = RedirectResponse(authorization_url, status_code=302)
redirect_response.set_cookie(
key=csrf_token_cookie_name,
value=csrf_token,
max_age=STATE_TOKEN_LIFETIME_SECONDS,
path=csrf_token_cookie_path,
domain=csrf_token_cookie_domain,
secure=csrf_token_cookie_secure,
httponly=csrf_token_cookie_httponly,
samesite=csrf_token_cookie_samesite,
)
return redirect_response
response.set_cookie(
key=csrf_token_cookie_name,
value=csrf_token,
max_age=STATE_TOKEN_LIFETIME_SECONDS,
path=csrf_token_cookie_path,
domain=csrf_token_cookie_domain,
secure=csrf_token_cookie_secure,
httponly=csrf_token_cookie_httponly,
samesite=csrf_token_cookie_samesite,
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@log_function_time(print_only=True)
@@ -1600,7 +1649,33 @@ def get_oauth_router(
try:
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
except jwt.DecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(
ErrorCode, "ACCESS_TOKEN_DECODE_ERROR", "ACCESS_TOKEN_DECODE_ERROR"
),
)
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(
ErrorCode,
"ACCESS_TOKEN_ALREADY_EXPIRED",
"ACCESS_TOKEN_ALREADY_EXPIRED",
),
)
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
if (
not cookie_csrf_token
or not state_csrf_token
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
)
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)

View File

@@ -0,0 +1,98 @@
# Overview of Onyx Background Jobs
The background jobs take care of:
1. Pulling/Indexing documents (from connectors)
2. Updating document metadata (from connectors)
3. Cleaning up checkpoints and logic around indexing work (indexing indexing checkpoints and index attempt metadata)
4. Handling user uploaded files and deletions (from the Projects feature and uploads via the Chat)
5. Reporting metrics on things like queue length for monitoring purposes
## Worker → Queue Mapping
| Worker | File | Queues |
|--------|------|--------|
| Primary | `apps/primary.py` | `celery` |
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
| Monitoring | `apps/monitoring.py` | `monitoring` |
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
## Non-Worker Apps
| App | File | Purpose |
|-----|------|---------|
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
### Shared Module
`app_base.py` provides:
- `TenantAwareTask` - Base task class that sets tenant context
- Signal handlers for logging, cleanup, and lifecycle events
- Readiness probes and health checks
## Worker Details
### Primary (Coordinator and task dispatcher)
It is the single worker which handles tasks from the default celery queue. It is a singleton worker ensured by the `PRIMARY_WORKER` Redis lock
which it touches every `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8` seconds (using Celery Bootsteps)
On startup:
- waits for redis, postgres, document index to all be healthy
- acquires the singleton lock
- cleans all the redis states associated with background jobs
- mark orphaned index attempts failed
Then it cycles through its tasks as scheduled by Celery Beat:
| Task | Frequency | Description |
|------|-----------|-------------|
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
| `kombu_message_cleanup_task` | periodic | Cleans orphaned Kombu messages from DB (Kombu being the messaging framework used by Celery) |
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
Watchdog is a separate Python process managed by supervisord which runs alongside celery workers. It checks the ONYX_CELERY_BEAT_HEARTBEAT_KEY in
Redis to ensure Celery Beat is not dead. Beat schedules the celery_beat_heartbeat for Primary to touch the key and share that it's still alive.
See supervisord.conf for watchdog config.
### Light
Fast and short living tasks that are not resource intensive. High concurrency:
Can have 24 concurrent workers, each with a prefetch of 8 for a total of 192 tasks in flight at once.
Tasks it handles:
- Syncs access/permissions, document sets, boosts, hidden state
- Deletes documents that are marked for deletion in Postgres
- Cleanup of checkpoints and index attempts
### Heavy
Long running, resource intensive tasks, handles pruning and sandbox operations. Low concurrency - max concurrency of 4 with 1 prefetch.
Does not interact with the Document Index, it handles the syncs with external systems. Large volume API calls to handle pruning and fetching permissions, etc.
Generates CSV exports which may take a long time with significant data in Postgres.
Sandbox (new feature) for running Next.js, Python virtual env, OpenCode AI Agent, and access to knowledge files
### Docprocessing, Docfetching, User File Processing
Docprocessing and Docfetching are for indexing documents:
- Docfetching runs connectors to pull documents from external APIs (Google Drive, Confluence, etc.), stores batches to file storage, and dispatches docprocessing tasks
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
User Files come from uploads directly via the input bar
### Monitoring
Observability and metrics collections:
- Queue lengths, connector success/failure, lconnector latencies
- Memory of supervisor managed processes (workers, beat, slack)
- Cloud and multitenant specific monitorings

View File

@@ -26,10 +26,13 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.document_index.opensearch.client import (
wait_for_opensearch_with_timeout,
)
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
@@ -40,6 +43,7 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.tracing.braintrust_tracing import setup_braintrust_if_creds_available
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import PlainFormatter
@@ -234,6 +238,9 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
)
# Initialize Braintrust tracing in workers if credentials are available.
setup_braintrust_if_creds_available()
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
@@ -516,15 +523,17 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""
if ENABLE_OPENSEARCH_FOR_ONYX:
# TODO(andrei): Do some similar liveness checking for OpenSearch.
return
if not wait_for_vespa_with_timeout():
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
if not wait_for_opensearch_with_timeout():
msg = "[OpenSearch] Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)
# File for validating worker liveness
class LivenessProbe(bootsteps.StartStopStep):

View File

@@ -121,7 +121,6 @@ celery_app.autodiscover_tasks(
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.kg_processing",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
@@ -134,5 +133,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)

View File

@@ -98,5 +98,8 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.pruning",
# Sandbox tasks (file sync, cleanup)
"onyx.server.features.build.sandbox.tasks",
"onyx.background.celery.tasks.hierarchyfetching",
]
)

View File

@@ -1,109 +0,0 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.kg_processing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.kg_processing",
]
)

View File

@@ -116,5 +116,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.docprocessing",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)

View File

@@ -318,12 +318,12 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.kg_processing",
"onyx.background.celery.tasks.user_file_processing",
]
)

View File

@@ -21,6 +21,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -32,10 +33,16 @@ PRUNING_CHECKPOINTED_BATCH_SIZE = 32
def document_batch_to_ids(
doc_batch: Iterator[list[Document]] | Iterator[list[SlimDocument]],
doc_batch: (
Iterator[list[Document | HierarchyNode]]
| Iterator[list[SlimDocument | HierarchyNode]]
),
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {doc.id for doc in doc_list}
yield {
doc.raw_node_id if isinstance(doc, HierarchyNode) else doc.id
for doc in doc_list
}
def extract_ids_from_runnable_connector(

View File

@@ -1,21 +0,0 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_KG_PROCESSING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = CELERY_WORKER_KG_PROCESSING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -57,24 +57,6 @@ beat_task_templates: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-kg-processing",
"task": OnyxCeleryTask.CHECK_KG_PROCESSING,
"schedule": timedelta(seconds=60),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-kg-processing-clustering-only",
"task": OnyxCeleryTask.CHECK_KG_PROCESSING_CLUSTERING_ONLY,
"schedule": timedelta(seconds=600),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
@@ -129,6 +111,15 @@ beat_task_templates: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-hierarchy-fetching",
"task": OnyxCeleryTask.CHECK_FOR_HIERARCHY_FETCHING,
"schedule": timedelta(hours=1), # Check hourly, but only fetch once per day
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
@@ -139,6 +130,27 @@ beat_task_templates: list[dict] = [
"queue": OnyxCeleryQueues.MONITORING,
},
},
# Sandbox cleanup tasks
{
"name": "cleanup-idle-sandboxes",
"task": OnyxCeleryTask.CLEANUP_IDLE_SANDBOXES,
"schedule": timedelta(minutes=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.SANDBOX,
},
},
{
"name": "cleanup-old-snapshots",
"task": OnyxCeleryTask.CLEANUP_OLD_SNAPSHOTS,
"schedule": timedelta(hours=24),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.SANDBOX,
},
},
]
if ENTERPRISE_EDITION_ENABLED:

View File

@@ -87,7 +87,7 @@ from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
@@ -1436,7 +1436,7 @@ def _docprocessing_task(
callback=callback,
)
document_index = get_default_document_index(
document_indices = get_all_document_indices(
index_attempt.search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
@@ -1473,7 +1473,7 @@ def _docprocessing_task(
# real work happens here!
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
document_index=document_index,
document_indices=document_indices,
ignore_time_skip=True, # Documents are already filtered during extraction
db_session=db_session,
tenant_id=tenant_id,

View File

@@ -0,0 +1,371 @@
"""Celery tasks for hierarchy fetching.
This module provides tasks for fetching hierarchy node information from connectors.
Hierarchy nodes represent structural elements like folders, spaces, and pages that
can be used to filter search results.
The hierarchy fetching pipeline runs once per day per connector and fetches
structural information from the connector source.
"""
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import HierarchyConnector
from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
from onyx.db.connector import mark_cc_pair_as_hierarchy_fetched
from onyx.db.connector_credential_pair import (
fetch_indexable_standard_connector_credential_pair_ids,
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Hierarchy fetching runs once per day (24 hours in seconds)
HIERARCHY_FETCH_INTERVAL_SECONDS = 24 * 60 * 60
def _is_hierarchy_fetching_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if hierarchy fetching is due for this connector.
Hierarchy fetching should run once per day for active connectors.
"""
# Skip if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
# Skip if connector has never successfully indexed
if not cc_pair.last_successful_index_time:
return False
# Check if we've fetched hierarchy recently
last_fetch = cc_pair.last_time_hierarchy_fetch
if last_fetch is None:
# Never fetched before - fetch now
return True
# Check if enough time has passed since last fetch
next_fetch_time = last_fetch + timedelta(seconds=HIERARCHY_FETCH_INTERVAL_SECONDS)
return datetime.now(timezone.utc) >= next_fetch_time
def _try_creating_hierarchy_fetching_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
tenant_id: str,
) -> str | None:
"""Try to create a hierarchy fetching task for a connector.
Returns the task ID if created, None otherwise.
"""
LOCK_TIMEOUT = 30
# Serialize task creation attempts
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + f"hierarchy_fetching_{cc_pair.id}",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
# Refresh to get latest state
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# Generate task ID
custom_task_id = f"hierarchy_fetching_{cc_pair.id}_{uuid4()}"
# Send the task
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_HIERARCHY_FETCHING_TASK,
kwargs=dict(
cc_pair_id=cc_pair.id,
tenant_id=tenant_id,
),
queue=OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
task_id=custom_task_id,
priority=OnyxCeleryPriority.LOW,
)
if not result:
raise RuntimeError("send_task for hierarchy_fetching_task failed.")
task_logger.info(
f"Created hierarchy fetching task: "
f"cc_pair={cc_pair.id} "
f"celery_task_id={custom_task_id}"
)
return custom_task_id
except Exception:
task_logger.exception(
f"Failed to create hierarchy fetching task: cc_pair={cc_pair.id}"
)
return None
finally:
if lock.owned():
lock.release()
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_HIERARCHY_FETCHING,
soft_time_limit=300,
bind=True,
)
def check_for_hierarchy_fetching(self: Task, *, tenant_id: str) -> int | None:
"""Check for connectors that need hierarchy fetching and spawn tasks.
This task runs periodically (once per day) and checks all active connectors
to see if they need hierarchy information fetched.
"""
time_start = time.monotonic()
task_logger.info("check_for_hierarchy_fetching - Starting")
tasks_created = 0
locked = False
redis_client = get_redis_client()
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CHECK_HIERARCHY_FETCHING_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# These tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
locked = True
with get_session_with_current_tenant() as db_session:
# Get all active connector credential pairs
cc_pair_ids = fetch_indexable_standard_connector_credential_pair_ids(
db_session=db_session,
active_cc_pairs_only=True,
)
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair or not _is_hierarchy_fetching_due(cc_pair):
continue
task_id = _try_creating_hierarchy_fetching_task(
celery_app=self.app,
cc_pair=cc_pair,
db_session=db_session,
r=redis_client,
tenant_id=tenant_id,
)
if task_id:
tasks_created += 1
except Exception:
task_logger.exception("check_for_hierarchy_fetching - Unexpected error")
finally:
if locked:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_hierarchy_fetching - Lock not owned on completion"
)
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"check_for_hierarchy_fetching finished: "
f"tasks_created={tasks_created} elapsed={time_elapsed:.2f}s"
)
return tasks_created
# Batch size for hierarchy node processing
HIERARCHY_NODE_BATCH_SIZE = 100
def _run_hierarchy_extraction(
db_session: Session,
cc_pair: ConnectorCredentialPair,
source: DocumentSource,
tenant_id: str,
) -> int:
"""
Run the hierarchy extraction for a connector.
Instantiates the connector and calls load_hierarchy() if the connector
implements HierarchyConnector.
Returns the total number of hierarchy nodes extracted.
"""
connector = cc_pair.connector
credential = cc_pair.credential
# Instantiate the connector using its configured input type
runnable_connector = instantiate_connector(
db_session=db_session,
source=source,
input_type=connector.input_type,
connector_specific_config=connector.connector_specific_config,
credential=credential,
)
# Check if the connector supports hierarchy fetching
if not isinstance(runnable_connector, HierarchyConnector):
task_logger.debug(
f"Connector {source} does not implement HierarchyConnector, skipping"
)
return 0
# Determine time range: start from last hierarchy fetch, end at now
last_fetch = cc_pair.last_time_hierarchy_fetch
start_time = last_fetch.timestamp() if last_fetch else 0
end_time = datetime.now(timezone.utc).timestamp()
total_nodes = 0
node_batch: list[PydanticHierarchyNode] = []
redis_client = get_redis_client(tenant_id=tenant_id)
def _process_batch() -> int:
"""Process accumulated hierarchy nodes batch."""
if not node_batch:
return 0
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=node_batch,
source=source,
commit=True,
)
# Cache in Redis for fast ancestor resolution
cache_entries = [
HierarchyNodeCacheEntry.from_db_model(node) for node in upserted_nodes
]
cache_hierarchy_nodes_batch(
redis_client=redis_client,
source=source,
entries=cache_entries,
)
count = len(node_batch)
node_batch.clear()
return count
# Fetch hierarchy nodes from the connector
for node in runnable_connector.load_hierarchy(start=start_time, end=end_time):
node_batch.append(node)
if len(node_batch) >= HIERARCHY_NODE_BATCH_SIZE:
total_nodes += _process_batch()
# Process any remaining nodes
total_nodes += _process_batch()
return total_nodes
@shared_task(
name=OnyxCeleryTask.CONNECTOR_HIERARCHY_FETCHING_TASK,
soft_time_limit=3600, # 1 hour soft limit
time_limit=3900, # 1 hour 5 min hard limit
bind=True,
)
def connector_hierarchy_fetching_task(
self: Task,
*,
cc_pair_id: int,
tenant_id: str,
) -> None:
"""Fetch hierarchy information from a connector.
This task fetches structural information (folders, spaces, pages, etc.)
from the connector source and stores it in the database.
"""
task_logger.info(
f"connector_hierarchy_fetching_task starting: "
f"cc_pair={cc_pair_id} tenant={tenant_id}"
)
try:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
task_logger.warning(
f"CC pair not found for hierarchy fetching: cc_pair={cc_pair_id}"
)
return
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
task_logger.info(
f"Skipping hierarchy fetching for deleting connector: "
f"cc_pair={cc_pair_id}"
)
return
source = cc_pair.connector.source
total_nodes = _run_hierarchy_extraction(
db_session=db_session,
cc_pair=cc_pair,
source=source,
tenant_id=tenant_id,
)
task_logger.info(
f"connector_hierarchy_fetching_task: "
f"Extracted {total_nodes} hierarchy nodes for cc_pair={cc_pair_id}"
)
# Update the last fetch time to prevent re-running until next interval
mark_cc_pair_as_hierarchy_fetched(db_session, cc_pair_id)
except Exception:
task_logger.exception(
f"connector_hierarchy_fetching_task failed: cc_pair={cc_pair_id}"
)
raise
task_logger.info(
f"connector_hierarchy_fetching_task completed: cc_pair={cc_pair_id}"
)

View File

@@ -1,77 +0,0 @@
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.apps.client import celery_app
from onyx.background.celery.tasks.kg_processing.utils import is_kg_processing_blocked
from onyx.background.celery.tasks.kg_processing.utils import (
is_kg_processing_requirements_met,
)
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
def try_creating_kg_processing_task(
tenant_id: str,
) -> bool:
"""Schedules the KG Processing for a tenant immediately. Will not schedule if
the tenant is not ready for KG processing.
"""
try:
if not is_kg_processing_requirements_met():
return False
# Send the KG processing task
result = celery_app.send_task(
OnyxCeleryTask.KG_PROCESSING,
kwargs=dict(
tenant_id=tenant_id,
),
queue=OnyxCeleryQueues.KG_PROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
)
if not result:
task_logger.error("send_task for kg processing failed.")
return bool(result)
except Exception:
task_logger.exception(
f"try_creating_kg_processing_task - Unexpected exception for tenant={tenant_id}"
)
return False
def try_creating_kg_source_reset_task(
tenant_id: str,
source_name: str | None,
index_name: str,
) -> bool:
"""Schedules the KG Source Reset for a tenant immediately. Will not do anything if
the tenant is currently being processed.
"""
try:
if is_kg_processing_blocked():
return False
# Send the KG source reset task
result = celery_app.send_task(
OnyxCeleryTask.KG_RESET_SOURCE_INDEX,
kwargs=dict(
tenant_id=tenant_id,
source_name=source_name,
index_name=index_name,
),
queue=OnyxCeleryQueues.KG_PROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
)
if not result:
task_logger.error("send_task for kg source reset failed.")
return bool(result)
except Exception:
task_logger.exception(
f"try_creating_kg_source_reset_task - Unexpected exception for tenant={tenant_id}"
)
return False

View File

@@ -1,253 +0,0 @@
import time
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.kg_processing.utils import (
is_kg_clustering_only_requirements_met,
)
from onyx.background.celery.tasks.kg_processing.utils import (
is_kg_processing_requirements_met,
)
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.search_settings import get_current_search_settings
from onyx.kg.clustering.clustering import kg_clustering
from onyx.kg.extractions.extraction_processing import kg_extraction
from onyx.kg.resets.reset_source import reset_source_kg_index
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name=OnyxCeleryTask.CHECK_KG_PROCESSING,
soft_time_limit=300,
bind=True,
)
def check_for_kg_processing(self: Task, *, tenant_id: str) -> None:
"""a lightweight task used to kick off kg processing tasks."""
time_start = time.monotonic()
task_logger.warning("check_for_kg_processing - Starting")
try:
if not is_kg_processing_requirements_met():
return
task_logger.info(
f"Found documents needing KG processing for tenant {tenant_id}"
)
self.app.send_task(
OnyxCeleryTask.KG_PROCESSING,
kwargs={
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.KG_PROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during kg processing check")
time_elapsed = time.monotonic() - time_start
task_logger.info(f"check_for_kg_processing finished: elapsed={time_elapsed:.2f}")
@shared_task(
name=OnyxCeleryTask.CHECK_KG_PROCESSING_CLUSTERING_ONLY,
soft_time_limit=300,
bind=True,
)
def check_for_kg_processing_clustering_only(self: Task, *, tenant_id: str) -> None:
"""a lightweight task used to kick off kg clustering tasks."""
time_start = time.monotonic()
task_logger.warning("check_for_kg_processing_clustering_only - Starting")
try:
if not is_kg_clustering_only_requirements_met():
return
task_logger.info(
f"Found documents needing KG clustering for tenant {tenant_id}"
)
self.app.send_task(
OnyxCeleryTask.KG_CLUSTERING_ONLY,
kwargs={
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.KG_PROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during kg clustering-only check")
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"check_for_kg_processing_clustering_only finished: elapsed={time_elapsed:.2f}"
)
@shared_task(
name=OnyxCeleryTask.KG_PROCESSING,
bind=True,
)
def kg_processing(self: Task, *, tenant_id: str) -> None:
"""a task for doing kg extraction and clustering."""
task_logger.warning(f"kg_processing - Starting for tenant {tenant_id}")
redis_client = get_redis_client()
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.KG_PROCESSING_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
try:
with get_session_with_current_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
index_str = search_settings.index_name
task_logger.info(f"KG processing set to in progress for tenant {tenant_id}")
kg_extraction(
tenant_id=tenant_id,
index_name=index_str,
lock=lock_beat,
processing_chunk_batch_size=8,
)
kg_clustering(
tenant_id=tenant_id,
index_name=index_str,
lock=lock_beat,
processing_chunk_batch_size=8,
)
except Exception:
task_logger.exception("Error during kg processing")
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"kg_processing - Lock not owned on completion: " f"tenant={tenant_id}"
)
redis_lock_dump(lock_beat, redis_client)
task_logger.debug("Completed kg processing task!")
@shared_task(
name=OnyxCeleryTask.KG_CLUSTERING_ONLY,
bind=True,
)
def kg_clustering_only(self: Task, *, tenant_id: str) -> None:
"""a task for doing kg clustering only."""
task_logger.warning(f"kg_clustering_only - Starting for tenant {tenant_id}")
redis_client = get_redis_client()
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.KG_PROCESSING_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
try:
with get_session_with_current_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
index_str = search_settings.index_name
task_logger.info(
f"KG clustering-only set to in progress for tenant {tenant_id}"
)
kg_clustering(
tenant_id=tenant_id,
index_name=index_str,
lock=lock_beat,
processing_chunk_batch_size=8,
)
except Exception:
task_logger.exception("Error during kg clustering-only")
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"kg_clustering_only - Lock not owned on completion: "
f"tenant={tenant_id}"
)
redis_lock_dump(lock_beat, redis_client)
task_logger.debug("Completed kg clustering-only task!")
@shared_task(
name=OnyxCeleryTask.KG_RESET_SOURCE_INDEX,
bind=True,
)
def kg_reset_source_index(
self: Task, *, tenant_id: str, source_name: str, index_name: str
) -> None:
"""a task for KG reset of a source."""
task_logger.warning(f"kg_reset_source_index - Starting for tenant {tenant_id}")
redis_client = get_redis_client()
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.KG_PROCESSING_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
try:
reset_source_kg_index(
source_name=source_name,
tenant_id=tenant_id,
index_name=index_name,
lock=lock_beat,
)
except Exception:
task_logger.exception("Error during kg reset")
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"kg_reset_source_index - Lock not owned on completion: "
f"tenant={tenant_id}"
)
redis_lock_dump(lock_beat, redis_client)
task_logger.debug("Completed kg reset task!")

View File

@@ -1,78 +0,0 @@
import time
from redis.lock import Lock as RedisLock
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.document import check_for_documents_needing_kg_processing
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
from onyx.db.models import KGEntityExtractionStaging
from onyx.db.models import KGRelationshipExtractionStaging
from onyx.redis.redis_pool import get_redis_client
def is_kg_processing_blocked() -> bool:
"""Checks if there are any KG tasks in progress."""
redis_client = get_redis_client()
lock_beat: RedisLock = redis_client.lock(OnyxRedisLocks.KG_PROCESSING_LOCK)
return lock_beat.locked()
def is_kg_processing_requirements_met() -> bool:
"""Checks that there are no other KG tasks in progress, KG is enabled, valid,
and there are documents that need KG processing."""
if is_kg_processing_blocked():
return False
kg_config = get_kg_config_settings()
if not is_kg_config_settings_enabled_valid(kg_config):
return False
with get_session_with_current_tenant() as db_session:
has_staging_entities = (
db_session.query(KGEntityExtractionStaging).first() is not None
)
has_staging_relationships = (
db_session.query(KGRelationshipExtractionStaging).first() is not None
)
return (
check_for_documents_needing_kg_processing(
db_session,
kg_config.KG_COVERAGE_START_DATE,
kg_config.KG_MAX_COVERAGE_DAYS,
)
or has_staging_entities
or has_staging_relationships
)
def is_kg_clustering_only_requirements_met() -> bool:
"""Checks that there are no other KG tasks in progress, KG is enabled, valid,
and there are documents that need KG clustering."""
if is_kg_processing_blocked():
return False
kg_config = get_kg_config_settings()
if not is_kg_config_settings_enabled_valid(kg_config):
return False
# Check if there are any entries in the staging tables
with get_session_with_current_tenant() as db_session:
has_staging_entities = (
db_session.query(KGEntityExtractionStaging).first() is not None
)
has_staging_relationships = (
db_session.query(KGRelationshipExtractionStaging).first() is not None
)
return has_staging_entities or has_staging_relationships
def extend_lock(lock: RedisLock, timeout: int, last_lock_time: float) -> float:
current_time = time.monotonic()
if current_time - last_lock_time >= (timeout / 4):
lock.reacquire()
last_lock_time = current_time
return last_lock_time

View File

@@ -25,7 +25,7 @@ from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.relationships import delete_document_references_from_kg
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_pool import get_redis_client
@@ -97,13 +97,17 @@ def document_by_cc_pair_cleanup_task(
action = "skip"
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
# This flow is for updates and deletion so we get all indices.
document_indices = get_all_document_indices(
active_search_settings.primary,
active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(doc_index)
retry_document_indices: list[RetryDocumentIndex] = [
RetryDocumentIndex(document_index)
for document_index in document_indices
]
count = get_document_connector_count(db_session, document_id)
if count == 1:
@@ -113,11 +117,12 @@ def document_by_cc_pair_cleanup_task(
chunk_count = fetch_chunk_count_for_document(document_id, db_session)
_ = retry_index.delete_single(
document_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
for retry_document_index in retry_document_indices:
_ = retry_document_index.delete_single(
document_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
delete_document_references_from_kg(
db_session=db_session,
@@ -155,14 +160,18 @@ def document_by_cc_pair_cleanup_task(
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
user_fields=None,
)
for retry_document_index in retry_document_indices:
# TODO(andrei): Previously there was a comment here saying
# it was ok if a doc did not exist in the document index. I
# don't agree with that claim, so keep an eye on this task
# to see if this raises.
retry_document_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
user_fields=None,
)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(

View File

@@ -27,12 +27,13 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_active_search_settings_list
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.file_store.file_store import get_default_file_store
@@ -232,7 +233,9 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
try:
for batch in connector.load_from_state():
documents.extend(batch)
documents.extend(
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
)
adapter = UserFileIndexingAdapter(
tenant_id=tenant_id,
@@ -244,7 +247,8 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
search_settings=current_search_settings,
)
document_index = get_default_document_index(
# This flow is for indexing so we get all indices.
document_indices = get_all_document_indices(
current_search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
@@ -258,7 +262,7 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
# real work happens here!
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
document_index=document_index,
document_indices=document_indices,
ignore_time_skip=True,
db_session=db_session,
tenant_id=tenant_id,
@@ -412,12 +416,16 @@ def process_single_user_file_delete(
httpx_init_vespa_pool(20)
active_search_settings = get_active_search_settings(db_session)
document_index = get_default_document_index(
# This flow is for deletion so we get all indices.
document_indices = get_all_document_indices(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(document_index)
retry_document_indices: list[RetryDocumentIndex] = [
RetryDocumentIndex(document_index)
for document_index in document_indices
]
index_name = active_search_settings.primary.index_name
selection = f"{index_name}.document_id=='{user_file_id}'"
@@ -438,11 +446,12 @@ def process_single_user_file_delete(
else:
chunk_count = user_file.chunk_count
retry_index.delete_single(
doc_id=user_file_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
for retry_document_index in retry_document_indices:
retry_document_index.delete_single(
doc_id=user_file_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
file_store = get_default_file_store()
@@ -564,12 +573,16 @@ def process_single_user_file_project_sync(
httpx_init_vespa_pool(20)
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
# This flow is for updates so we get all indices.
document_indices = get_all_document_indices(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(doc_index)
retry_document_indices: list[RetryDocumentIndex] = [
RetryDocumentIndex(document_index)
for document_index in document_indices
]
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if not user_file:
@@ -579,13 +592,14 @@ def process_single_user_file_project_sync(
return None
project_ids = [project.id for project in user_file.projects]
retry_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(user_projects=project_ids),
)
for retry_document_index in retry_document_indices:
retry_document_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(user_projects=project_ids),
)
task_logger.info(
f"process_single_user_file_project_sync - User file id={user_file_id}"

View File

@@ -21,6 +21,8 @@ from onyx.utils.logger import setup_logger
DOCUMENT_SYNC_PREFIX = "documentsync"
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
TASKSET_TTL = FENCE_TTL
logger = setup_logger()
@@ -50,7 +52,7 @@ def set_document_sync_fence(r: Redis, payload: int | None) -> None:
r.delete(DOCUMENT_SYNC_FENCE_KEY)
return
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL)
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
@@ -110,6 +112,7 @@ def generate_document_sync_tasks(
# Add to the tracking taskset in Redis BEFORE creating the celery task
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL)
# Create the Celery task
celery_app.send_task(

View File

@@ -49,7 +49,7 @@ from onyx.db.search_settings import get_active_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_document_set import RedisDocumentSet
@@ -70,6 +70,8 @@ logger = setup_logger()
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
# TODO(andrei): Rename all these kinds of functions from *vespa* to a more
# generic *document_index*.
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
ignore_result=True,
@@ -465,13 +467,17 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
try:
with get_session_with_current_tenant() as db_session:
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
# This flow is for updates so we get all indices.
document_indices = get_all_document_indices(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(doc_index)
retry_document_indices: list[RetryDocumentIndex] = [
RetryDocumentIndex(document_index)
for document_index in document_indices
]
doc = get_document(document_id, db_session)
if not doc:
@@ -500,14 +506,18 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
# aggregated_boost_factor=doc.aggregated_boost_factor,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
user_fields=None,
)
for retry_document_index in retry_document_indices:
# TODO(andrei): Previously there was a comment here saying
# it was ok if a doc did not exist in the document index. I
# don't agree with that claim, so keep an eye on this task
# to see if this raises.
retry_document_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
user_fields=None,
)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later

View File

@@ -1,18 +0,0 @@
"""Factory stub for running celery worker for knowledge graph processing.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.kg_processing import celery_app
return celery_app
app = get_app()

View File

@@ -31,17 +31,21 @@ from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.constants import CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.enums import ProcessingMode
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
@@ -53,7 +57,15 @@ from onyx.db.models import IndexAttempt
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.server.features.build.indexing.persistent_document_writer import (
get_persistent_document_writer,
)
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
@@ -367,6 +379,7 @@ def connector_document_extraction(
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
processing_mode = index_attempt.connector_credential_pair.processing_mode
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
from_beginning = index_attempt.from_beginning
@@ -534,9 +547,12 @@ def connector_document_extraction(
logger.info(
f"Running '{db_connector.source.value}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
):
for (
document_batch,
hierarchy_node_batch,
failure,
next_checkpoint,
) in connector_runner.run(checkpoint):
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
@@ -571,6 +587,33 @@ def connector_document_extraction(
if next_checkpoint:
checkpoint = next_checkpoint
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
if hierarchy_node_batch:
with get_session_with_current_tenant() as db_session:
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=hierarchy_node_batch,
source=db_connector.source,
commit=True,
)
# Cache in Redis for fast ancestor resolution during doc processing
redis_client = get_redis_client(tenant_id=tenant_id)
cache_entries = [
HierarchyNodeCacheEntry.from_db_model(node)
for node in upserted_nodes
]
cache_hierarchy_nodes_batch(
redis_client=redis_client,
source=db_connector.source,
entries=cache_entries,
)
logger.debug(
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
f"for attempt={index_attempt_id}"
)
# below is all document processing task, so if no batch we can just continue
if not document_batch:
continue
@@ -600,34 +643,103 @@ def connector_document_extraction(
logger.debug(f"Indexing batch of documents: {batch_description}")
memory_tracer.increment_and_maybe_trace()
# Store documents in storage
batch_storage.store_batch(batch_num, doc_batch_cleaned)
# cc4a
if processing_mode == ProcessingMode.FILE_SYSTEM:
# File system only - write directly to persistent storage,
# skip chunking/embedding/Vespa but still track documents in DB
# Create processing task data
processing_batch_data = {
"index_attempt_id": index_attempt_id,
"cc_pair_id": cc_pair_id,
"tenant_id": tenant_id,
"batch_num": batch_num, # 0-indexed
}
with get_session_with_current_tenant() as db_session:
# Create metadata for the batch
index_attempt_metadata = IndexAttemptMetadata(
attempt_id=index_attempt_id,
connector_id=db_connector.id,
credential_id=db_credential.id,
request_id=make_randomized_onyx_request_id("FSI"),
structured_id=f"{tenant_id}:{cc_pair_id}:{index_attempt_id}:{batch_num}",
batch_num=batch_num,
)
# Queue document processing task
app.send_task(
OnyxCeleryTask.DOCPROCESSING_TASK,
kwargs=processing_batch_data,
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=docprocessing_priority,
)
# Upsert documents to PostgreSQL (document table + cc_pair relationship)
# This is a subset of what docprocessing does - just DB tracking, no chunking/embedding
index_doc_batch_prepare(
documents=doc_batch_cleaned,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
ignore_time_skip=True, # Documents already filtered during extraction
)
batch_num += 1
total_doc_batches_queued += 1
# Mark documents as indexed for the CC pair
mark_document_as_indexed_for_cc_pair__no_commit(
connector_id=db_connector.id,
credential_id=db_credential.id,
document_ids=[doc.id for doc in doc_batch_cleaned],
db_session=db_session,
)
db_session.commit()
logger.info(
f"Queued document processing batch: "
f"batch_num={batch_num} "
f"docs={len(doc_batch_cleaned)} "
f"attempt={index_attempt_id}"
)
# Write documents to persistent file system
# Use creator_id for user-segregated storage paths (sandbox isolation)
creator_id = index_attempt.connector_credential_pair.creator_id
if creator_id is None:
raise ValueError(
f"ConnectorCredentialPair {index_attempt.connector_credential_pair.id} "
"must have a creator_id for persistent document storage"
)
user_id_str: str = str(creator_id)
writer = get_persistent_document_writer(
user_id=user_id_str,
tenant_id=tenant_id,
)
written_paths = writer.write_documents(doc_batch_cleaned)
# Update coordination directly (no docprocessing task)
with get_session_with_current_tenant() as db_session:
IndexingCoordination.update_batch_completion_and_docs(
db_session=db_session,
index_attempt_id=index_attempt_id,
total_docs_indexed=len(doc_batch_cleaned),
new_docs_indexed=len(doc_batch_cleaned),
total_chunks=0, # No chunks for file system mode
)
batch_num += 1
total_doc_batches_queued += 1
logger.info(
f"Wrote documents to file system: "
f"batch_num={batch_num} "
f"docs={len(written_paths)} "
f"attempt={index_attempt_id}"
)
else:
# REGULAR mode (default): Full pipeline - store and queue docprocessing
batch_storage.store_batch(batch_num, doc_batch_cleaned)
# Create processing task data
processing_batch_data = {
"index_attempt_id": index_attempt_id,
"cc_pair_id": cc_pair_id,
"tenant_id": tenant_id,
"batch_num": batch_num, # 0-indexed
}
# Queue document processing task
app.send_task(
OnyxCeleryTask.DOCPROCESSING_TASK,
kwargs=processing_batch_data,
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=docprocessing_priority,
)
batch_num += 1
total_doc_batches_queued += 1
logger.info(
f"Queued document processing batch: "
f"batch_num={batch_num} "
f"docs={len(doc_batch_cleaned)} "
f"attempt={index_attempt_id}"
)
# Check checkpoint size periodically
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
@@ -663,6 +775,24 @@ def connector_document_extraction(
total_batches=batch_num,
)
# Trigger file sync to user's sandbox (if running) - only for FILE_SYSTEM mode
# This syncs the newly written documents from S3 to any running sandbox pod
if processing_mode == ProcessingMode.FILE_SYSTEM:
creator_id = index_attempt.connector_credential_pair.creator_id
if creator_id:
app.send_task(
OnyxCeleryTask.SANDBOX_FILE_SYNC,
kwargs={
"user_id": str(creator_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.SANDBOX,
)
logger.info(
f"Triggered sandbox file sync for user {creator_id} "
f"after indexing complete"
)
except Exception as e:
logger.exception(
f"Document extraction failed: "

View File

@@ -7,6 +7,7 @@ from typing import Any
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
@@ -15,6 +16,11 @@ from onyx.tools.models import ToolCallInfo
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background
# Type alias for search doc deduplication key
# Simple key: just document_id (str)
# Full key: (document_id, chunk_ind, match_highlights)
SearchDocKey = str | tuple[str, int, tuple[str, ...]]
class ChatStateContainer:
"""Container for accumulating state during LLM loop execution.
@@ -40,6 +46,10 @@ class ChatStateContainer:
# True if this turn is a clarification question (deep research flow)
self.is_clarification: bool = False
# Note: LLM cost tracking is now handled in multi_llm.py
# Search doc collection - maps dedup key to SearchDoc for all docs from tool calls
self._all_search_docs: dict[SearchDocKey, SearchDoc] = {}
# Track which citation numbers were actually emitted during streaming
self._emitted_citations: set[int] = set()
def add_tool_call(self, tool_call: ToolCallInfo) -> None:
"""Add a tool call to the accumulated state."""
@@ -91,6 +101,54 @@ class ChatStateContainer:
with self._lock:
return self.is_clarification
@staticmethod
def create_search_doc_key(
search_doc: SearchDoc, use_simple_key: bool = True
) -> SearchDocKey:
"""Create a unique key for a SearchDoc for deduplication.
Args:
search_doc: The SearchDoc to create a key for
use_simple_key: If True (default), use only document_id for deduplication.
If False, include chunk_ind and match_highlights so that the same
document/chunk with different highlights are stored separately.
"""
if use_simple_key:
return search_doc.document_id
match_highlights_tuple = tuple(sorted(search_doc.match_highlights or []))
return (search_doc.document_id, search_doc.chunk_ind, match_highlights_tuple)
def add_search_docs(
self, search_docs: list[SearchDoc], use_simple_key: bool = True
) -> None:
"""Add search docs to the accumulated collection with deduplication.
Args:
search_docs: List of SearchDoc objects to add
use_simple_key: If True (default), deduplicate by document_id only.
If False, deduplicate by document_id + chunk_ind + match_highlights.
"""
with self._lock:
for doc in search_docs:
key = self.create_search_doc_key(doc, use_simple_key)
if key not in self._all_search_docs:
self._all_search_docs[key] = doc
def get_all_search_docs(self) -> dict[SearchDocKey, SearchDoc]:
"""Thread-safe getter for all accumulated search docs (returns a copy)."""
with self._lock:
return self._all_search_docs.copy()
def add_emitted_citation(self, citation_num: int) -> None:
"""Add a citation number that was actually emitted during streaming."""
with self._lock:
self._emitted_citations.add(citation_num)
def get_emitted_citations(self) -> set[int]:
"""Thread-safe getter for emitted citations (returns a copy)."""
with self._lock:
return self._emitted_citations.copy()
def run_chat_loop_with_state_containers(
func: Callable[..., None],

View File

@@ -9,12 +9,6 @@ from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from onyx.auth.users import is_user_admin
from onyx.background.celery.tasks.kg_processing.kg_indexing import (
try_creating_kg_processing_task,
)
from onyx.background.celery.tasks.kg_processing.kg_indexing import (
try_creating_kg_source_reset_task,
)
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import PersonaOverrideConfig
@@ -37,7 +31,6 @@ from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.projects import check_project_ownership
from onyx.db.search_settings import get_current_search_settings
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
@@ -359,39 +352,7 @@ def process_kg_commands(
if not is_kg_config_settings_enabled_valid(kg_config_settings):
return
# get Vespa index
search_settings = get_current_search_settings(db_session)
index_str = search_settings.index_name
if message == "kg_p":
success = try_creating_kg_processing_task(tenant_id)
if success:
raise KGException("KG processing scheduled")
else:
raise KGException(
"Cannot schedule another KG processing if one is already running "
"or there are no documents to process"
)
elif message.startswith("kg_rs_source"):
msg_split = [x for x in message.split(":")]
if len(msg_split) > 2:
raise KGException("Invalid format for a source reset command")
elif len(msg_split) == 2:
source_name = msg_split[1].strip()
elif len(msg_split) == 1:
source_name = None
else:
raise KGException("Invalid format for a source reset command")
success = try_creating_kg_source_reset_task(tenant_id, source_name, index_str)
if success:
source_name = source_name or "all"
raise KGException(f"KG index reset for source '{source_name}' scheduled")
else:
raise KGException("Cannot reset index while KG processing is running")
elif message == "kg_setup":
if message == "kg_setup":
populate_missing_default_entity_types__commit(db_session=db_session)
raise KGException("KG setup done")

View File

@@ -53,6 +53,50 @@ def update_citation_processor_from_tool_response(
citation_processor.update_citation_mapping(citation_to_doc)
def extract_citation_order_from_text(text: str) -> list[int]:
"""Extract citation numbers from text in order of first appearance.
Parses citation patterns like [1], [1, 2], [[1]], 【1】 etc. and returns
the citation numbers in the order they first appear in the text.
Args:
text: The text containing citations
Returns:
List of citation numbers in order of first appearance (no duplicates)
"""
# Same pattern used in collapse_citations and DynamicCitationProcessor
# Group 2 captures the number in double bracket format: [[1]], 【【1】】
# Group 4 captures the numbers in single bracket format: [1], [1, 2]
citation_pattern = re.compile(
r"([\[【[]{2}(\d+)[\]】]]{2})|([\[【[]([\d]+(?: *, *\d+)*)[\]】]])"
)
seen: set[int] = set()
order: list[int] = []
for match in citation_pattern.finditer(text):
# Group 2 is for double bracket single number, group 4 is for single bracket
if match.group(2):
nums_str = match.group(2)
elif match.group(4):
nums_str = match.group(4)
else:
continue
for num_str in nums_str.split(","):
num_str = num_str.strip()
if num_str:
try:
num = int(num_str)
if num not in seen:
seen.add(num)
order.append(num)
except ValueError:
continue
return order
def collapse_citations(
answer_text: str,
existing_citation_mapping: CitationMapping,

View File

@@ -45,6 +45,7 @@ from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.tools.tool_runner import run_tool_calls
from onyx.tracing.framework.create import trace
@@ -453,12 +454,16 @@ def run_llm_loop(
# The section below calculates the available tokens for history a bit more accurately
# now that project files are loaded in.
if persona and persona.replace_base_system_prompt and persona.system_prompt:
if persona and persona.replace_base_system_prompt:
# Handles the case where user has checked off the "Replace base system prompt" checkbox
system_prompt = ChatMessageSimple(
message=persona.system_prompt,
token_count=token_counter(persona.system_prompt),
message_type=MessageType.SYSTEM,
system_prompt = (
ChatMessageSimple(
message=persona.system_prompt,
token_count=token_counter(persona.system_prompt),
message_type=MessageType.SYSTEM,
)
if persona.system_prompt
else None
)
custom_agent_prompt_msg = None
else:
@@ -612,6 +617,7 @@ def run_llm_loop(
next_citation_num=citation_processor.get_next_citation_number(),
max_concurrent_tools=None,
skip_search_query_expansion=has_called_search_tool,
url_snippet_map=extract_url_snippet_map(gathered_documents or []),
)
tool_responses = parallel_tool_call_results.tool_responses
citation_mapping = parallel_tool_call_results.updated_citation_mapping
@@ -650,8 +656,15 @@ def run_llm_loop(
# Extract search_docs if this is a search tool response
search_docs = None
displayed_docs = None
if isinstance(tool_response.rich_response, SearchDocsResponse):
search_docs = tool_response.rich_response.search_docs
displayed_docs = tool_response.rich_response.displayed_docs
# Add ALL search docs to state container for DB persistence
if search_docs:
state_container.add_search_docs(search_docs)
if gathered_documents:
gathered_documents.extend(search_docs)
else:
@@ -685,7 +698,7 @@ def run_llm_loop(
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
tool_call_arguments=tool_call.tool_args,
tool_call_response=saved_response,
search_docs=search_docs,
search_docs=displayed_docs or search_docs,
generated_images=generated_images,
)
# Add to state container for partial save support

View File

@@ -14,6 +14,7 @@ from onyx.chat.emitter import Emitter
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import LlmStepResult
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import ChatFileType
@@ -432,7 +433,7 @@ def translate_history_to_llm_format(
for idx, msg in enumerate(history):
# if the message is being added to the history
if msg.message_type in [
if PROMPT_CACHE_CHAT_HISTORY and msg.message_type in [
MessageType.SYSTEM,
MessageType.USER,
MessageType.ASSISTANT,
@@ -859,6 +860,11 @@ def run_llm_step_pkt_generator(
),
obj=result,
)
# Track emitted citation for saving
if state_container:
state_container.add_emitted_citation(
result.citation_number
)
else:
# When citation_processor is None, use delta.content directly without modification
accumulated_answer += delta.content
@@ -985,6 +991,9 @@ def run_llm_step_pkt_generator(
),
obj=result,
)
# Track emitted citation for saving
if state_container:
state_container.add_emitted_citation(result.citation_number)
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)

View File

@@ -42,7 +42,6 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import CitationDocInfo
from onyx.context.search.models import SearchDoc
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
@@ -744,27 +743,16 @@ def llm_loop_completion_handle(
else:
final_answer = "The generation was stopped by the user."
# Build citation_docs_info from accumulated citations in state container
citation_docs_info: list[CitationDocInfo] = []
seen_citation_nums: set[int] = set()
for citation_num, search_doc in state_container.citation_to_doc.items():
if citation_num not in seen_citation_nums:
seen_citation_nums.add(citation_num)
citation_docs_info.append(
CitationDocInfo(
search_doc=search_doc,
citation_number=citation_num,
)
)
save_chat_turn(
message_text=final_answer,
reasoning_tokens=state_container.reasoning_tokens,
citation_docs_info=citation_docs_info,
citation_to_doc=state_container.citation_to_doc,
tool_calls=state_container.tool_calls,
all_search_docs=state_container.get_all_search_docs(),
db_session=db_session,
assistant_message=assistant_message,
is_clarification=state_container.is_clarification,
emitted_citations=state_container.get_emitted_citations(),
)

View File

@@ -2,8 +2,9 @@ import json
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_state import SearchDocKey
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import CitationDocInfo
from onyx.context.search.models import SearchDoc
from onyx.db.chat import add_search_docs_to_chat_message
from onyx.db.chat import add_search_docs_to_tool_call
@@ -19,22 +20,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def _create_search_doc_key(search_doc: SearchDoc) -> tuple[str, int, tuple[str, ...]]:
"""
Create a unique key for a SearchDoc that accounts for different versions of the same
document/chunk with different match_highlights.
Args:
search_doc: The SearchDoc pydantic model to create a key for
Returns:
A tuple of (document_id, chunk_ind, sorted match_highlights) that uniquely identifies
this specific version of the document
"""
match_highlights_tuple = tuple(sorted(search_doc.match_highlights or []))
return (search_doc.document_id, search_doc.chunk_ind, match_highlights_tuple)
def _create_and_link_tool_calls(
tool_calls: list[ToolCallInfo],
assistant_message: ChatMessage,
@@ -154,38 +139,36 @@ def save_chat_turn(
message_text: str,
reasoning_tokens: str | None,
tool_calls: list[ToolCallInfo],
citation_docs_info: list[CitationDocInfo],
citation_to_doc: dict[int, SearchDoc],
all_search_docs: dict[SearchDocKey, SearchDoc],
db_session: Session,
assistant_message: ChatMessage,
is_clarification: bool = False,
emitted_citations: set[int] | None = None,
) -> None:
"""
Save a chat turn by populating the assistant_message and creating related entities.
This function:
1. Updates the ChatMessage with text, reasoning tokens, and token count
2. Creates SearchDoc entries from ToolCall search_docs (for tool calls that returned documents)
3. Collects all unique SearchDocs from all tool calls and links them to ChatMessage
4. Builds citation mapping from citation_docs_info
5. Links all unique SearchDocs from tool calls to the ChatMessage
2. Creates DB SearchDoc entries from pre-deduplicated all_search_docs
3. Builds tool_call -> search_doc mapping for displayed docs
4. Builds citation mapping from citation_to_doc
5. Links all unique SearchDocs to the ChatMessage
6. Creates ToolCall entries and links SearchDocs to them
7. Builds the citations mapping for the ChatMessage
Deduplication Logic:
- SearchDocs are deduplicated using (document_id, chunk_ind, match_highlights) as the key
- This ensures that the same document/chunk with different match_highlights (from different
queries) are stored as separate SearchDoc entries
- Each ToolCall and ChatMessage will map to the correct version of the SearchDoc that
matches its specific query highlights
Args:
message_text: The message content to save
reasoning_tokens: Optional reasoning tokens for the message
tool_calls: List of tool call information to create ToolCall entries (may include search_docs)
citation_docs_info: List of citation document information for building citations mapping
citation_to_doc: Mapping from citation number to SearchDoc for building citations
all_search_docs: Pre-deduplicated search docs from ChatStateContainer
db_session: Database session for persistence
assistant_message: The ChatMessage object to populate (should already exist in DB)
is_clarification: Whether this assistant message is a clarification question (deep research flow)
emitted_citations: Set of citation numbers that were actually emitted during streaming.
If provided, only citations in this set will be saved; others are filtered out.
"""
# 1. Update ChatMessage with message content, reasoning tokens, and token count
assistant_message.message = message_text
@@ -200,53 +183,53 @@ def save_chat_turn(
else:
assistant_message.token_count = 0
# 2. Create SearchDoc entries from tool_calls
# Build mapping from SearchDoc to DB SearchDoc ID
# Use (document_id, chunk_ind, match_highlights) as key to avoid duplicates
# while ensuring different versions with different highlights are stored separately
search_doc_key_to_id: dict[tuple[str, int, tuple[str, ...]], int] = {}
tool_call_to_search_doc_ids: dict[str, list[int]] = {}
# 2. Create DB SearchDoc entries from pre-deduplicated all_search_docs
search_doc_key_to_id: dict[SearchDocKey, int] = {}
for key, search_doc_py in all_search_docs.items():
db_search_doc = create_db_search_doc(
server_search_doc=search_doc_py,
db_session=db_session,
commit=False,
)
search_doc_key_to_id[key] = db_search_doc.id
# Process tool calls and their search docs
# 3. Build tool_call -> search_doc mapping (for displayed docs in each tool call)
tool_call_to_search_doc_ids: dict[str, list[int]] = {}
for tool_call_info in tool_calls:
if tool_call_info.search_docs:
search_doc_ids_for_tool: list[int] = []
for search_doc_py in tool_call_info.search_docs:
# Create a unique key for this SearchDoc version
search_doc_key = _create_search_doc_key(search_doc_py)
# Check if we've already created this exact SearchDoc version
if search_doc_key in search_doc_key_to_id:
search_doc_ids_for_tool.append(search_doc_key_to_id[search_doc_key])
key = ChatStateContainer.create_search_doc_key(search_doc_py)
if key in search_doc_key_to_id:
search_doc_ids_for_tool.append(search_doc_key_to_id[key])
else:
# Create new DB SearchDoc entry
# Displayed doc not in all_search_docs - create it
# This can happen if displayed_docs contains docs not in search_docs
db_search_doc = create_db_search_doc(
server_search_doc=search_doc_py,
db_session=db_session,
commit=False,
)
search_doc_key_to_id[search_doc_key] = db_search_doc.id
search_doc_key_to_id[key] = db_search_doc.id
search_doc_ids_for_tool.append(db_search_doc.id)
tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = list(
set(search_doc_ids_for_tool)
)
# 3. Collect all unique SearchDoc IDs from all tool calls to link to ChatMessage
# Use a set to deduplicate by ID (since we've already deduplicated by key above)
all_search_doc_ids_set: set[int] = set()
for search_doc_ids in tool_call_to_search_doc_ids.values():
all_search_doc_ids_set.update(search_doc_ids)
# Collect all search doc IDs for ChatMessage linking
all_search_doc_ids_set: set[int] = set(search_doc_key_to_id.values())
# 4. Build citation mapping from citation_docs_info
# 4. Build a citation mapping from the citation number to the saved DB SearchDoc ID
# Only include citations that were actually emitted during streaming
citation_number_to_search_doc_id: dict[int, int] = {}
for citation_doc_info in citation_docs_info:
# Extract SearchDoc pydantic model
search_doc_py = citation_doc_info.search_doc
for citation_num, search_doc_py in citation_to_doc.items():
# Skip citations that weren't actually emitted (if emitted_citations is provided)
if emitted_citations is not None and citation_num not in emitted_citations:
continue
# Create the unique key for this SearchDoc version
search_doc_key = _create_search_doc_key(search_doc_py)
search_doc_key = ChatStateContainer.create_search_doc_key(search_doc_py)
# Get the search doc ID (should already exist from processing tool_calls)
if search_doc_key in search_doc_key_to_id:
@@ -283,10 +266,7 @@ def save_chat_turn(
all_search_doc_ids_set.add(db_search_doc_id)
# Build mapping from citation number to search doc ID
if citation_doc_info.citation_number is not None:
citation_number_to_search_doc_id[citation_doc_info.citation_number] = (
db_search_doc_id
)
citation_number_to_search_doc_id[citation_num] = db_search_doc_id
# 5. Link all unique SearchDocs (from both tool calls and citations) to ChatMessage
final_search_doc_ids: list[int] = list(all_search_doc_ids_set)
@@ -306,23 +286,10 @@ def save_chat_turn(
tool_call_to_search_doc_ids=tool_call_to_search_doc_ids,
)
# 7. Build citations mapping from citation_docs_info
# Any citation_doc_info with a citation_number appeared in the text and should be mapped
citations: dict[int, int] = {}
for citation_doc_info in citation_docs_info:
if citation_doc_info.citation_number is not None:
search_doc_id = citation_number_to_search_doc_id.get(
citation_doc_info.citation_number
)
if search_doc_id is not None:
citations[citation_doc_info.citation_number] = search_doc_id
else:
logger.warning(
f"Citation number {citation_doc_info.citation_number} found in citation_docs_info "
f"but no matching search doc ID in mapping"
)
assistant_message.citations = citations if citations else None
# 7. Build citations mapping - use the mapping we already built in step 4
assistant_message.citations = (
citation_number_to_search_doc_id if citation_number_to_search_doc_id else None
)
# Finally save the messages, tool calls, and docs
db_session.commit()

View File

@@ -207,9 +207,23 @@ OPENSEARCH_HOST = os.environ.get("OPENSEARCH_HOST") or "localhost"
OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 9200)
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
USING_AWS_MANAGED_OPENSEARCH = (
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
ENABLE_OPENSEARCH_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_FOR_ONYX", "").lower() == "true"
# This is the "base" config for now, the idea is that at least for our dev
# environments we always want to be dual indexing into both OpenSearch and Vespa
# to stress test the new codepaths. Only enable this if there is some instance
# of OpenSearch running for the relevant Onyx instance.
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
)
# Given that the "base" config above is true, this enables whether we want to
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
# in the event we see issues with OpenSearch retrieval in our dev environments.
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
@@ -399,7 +413,7 @@ CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)
# Consolidated background worker (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing)
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
@@ -411,10 +425,6 @@ CELERY_WORKER_HEAVY_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
)
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 2
)
CELERY_WORKER_MONITORING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_MONITORING_CONCURRENCY") or 1
)
@@ -738,6 +748,10 @@ JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
LOG_ONYX_MODEL_INTERACTIONS = (
os.environ.get("LOG_ONYX_MODEL_INTERACTIONS", "").lower() == "true"
)
PROMPT_CACHE_CHAT_HISTORY = (
os.environ.get("PROMPT_CACHE_CHAT_HISTORY", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (
@@ -1016,3 +1030,25 @@ INSTANCE_TYPE = (
## Discord Bot Configuration
DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")
## Stripe Configuration
# URL to fetch the Stripe publishable key from a public S3 bucket.
# Publishable keys are safe to expose publicly - they can only initialize
# Stripe.js and tokenize payment info, not make charges or access data.
STRIPE_PUBLISHABLE_KEY_URL = (
"https://onyx-stripe-public.s3.amazonaws.com/publishable-key.txt"
)
# Override for local testing with Stripe test keys (pk_test_*)
STRIPE_PUBLISHABLE_KEY_OVERRIDE = os.environ.get("STRIPE_PUBLISHABLE_KEY")
# Persistent Document Storage Configuration
# When enabled, indexed documents are written to local filesystem with hierarchical structure
PERSISTENT_DOCUMENT_STORAGE_ENABLED = (
os.environ.get("PERSISTENT_DOCUMENT_STORAGE_ENABLED", "").lower() == "true"
)
# Base directory path for persistent document storage (local filesystem)
# Example: /var/onyx/indexed-docs or /app/indexed-docs
PERSISTENT_DOCUMENT_STORAGE_PATH = os.environ.get(
"PERSISTENT_DOCUMENT_STORAGE_PATH", "/app/indexed-docs"
)

View File

@@ -1,6 +1,5 @@
import os
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
NUM_RETURNED_HITS = 50

View File

@@ -80,7 +80,6 @@ POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
"celery_worker_user_file_processing"
@@ -241,6 +240,7 @@ class NotificationType(str, Enum):
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
RELEASE_NOTES = "release_notes"
ASSISTANT_FILES_READY = "assistant_files_ready"
FEATURE_ANNOUNCEMENT = "feature_announcement"
class BlobType(str, Enum):
@@ -327,6 +327,7 @@ class FileOrigin(str, Enum):
PLAINTEXT_CACHE = "plaintext_cache"
OTHER = "other"
QUERY_HISTORY_CSV = "query_history_csv"
SANDBOX_SNAPSHOT = "sandbox_snapshot"
USER_FILE = "user_file"
@@ -344,6 +345,7 @@ class MilestoneRecordType(str, Enum):
MULTIPLE_ASSISTANTS = "multiple_assistants"
CREATED_ASSISTANT = "created_assistant"
CREATED_ONYX_BOT = "created_onyx_bot"
REQUESTED_CONNECTOR = "requested_connector"
class PostgresAdvisoryLocks(Enum):
@@ -367,6 +369,7 @@ class OnyxCeleryQueues:
CONNECTOR_PRUNING = "connector_pruning"
CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
CONNECTOR_HIERARCHY_FETCHING = "connector_hierarchy_fetching"
CSV_GENERATION = "csv_generation"
# User file processing queue
@@ -380,8 +383,8 @@ class OnyxCeleryQueues:
# Monitoring queue
MONITORING = "monitoring"
# KG processing queue
KG_PROCESSING = "kg_processing"
# Sandbox processing queue
SANDBOX = "sandbox"
class OnyxRedisLocks:
@@ -389,6 +392,7 @@ class OnyxRedisLocks:
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_HIERARCHY_FETCHING_BEAT_LOCK = "da_lock:check_hierarchy_fetching_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK = "da_lock:check_checkpoint_cleanup_beat"
CHECK_INDEX_ATTEMPT_CLEANUP_BEAT_LOCK = "da_lock:check_index_attempt_cleanup_beat"
@@ -417,9 +421,6 @@ class OnyxRedisLocks:
CLOUD_BEAT_TASK_GENERATOR_LOCK = "da_lock:cloud_beat_task_generator"
CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic"
# KG processing
KG_PROCESSING_LOCK = "da_lock:kg_processing"
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
@@ -431,6 +432,10 @@ class OnyxRedisLocks:
# Release notes
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
# Sandbox cleanup
CLEANUP_IDLE_SANDBOXES_BEAT_LOCK = "da_lock:cleanup_idle_sandboxes_beat"
CLEANUP_OLD_SNAPSHOTS_BEAT_LOCK = "da_lock:cleanup_old_snapshots_beat"
class OnyxRedisSignals:
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
@@ -447,9 +452,6 @@ class OnyxRedisSignals:
"signal:block_validate_connector_deletion_fences"
)
# KG processing
CHECK_KG_PROCESSING_BEAT_LOCK = "da_lock:check_kg_processing_beat"
class OnyxRedisConstants:
ACTIVE_FENCES = "active_fences"
@@ -493,6 +495,7 @@ class OnyxCeleryTask:
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
CHECK_FOR_INDEXING = "check_for_indexing"
CHECK_FOR_PRUNING = "check_for_pruning"
CHECK_FOR_HIERARCHY_FETCHING = "check_for_hierarchy_fetching"
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_AUTO_LLM_UPDATE = "check_for_auto_llm_update"
@@ -534,6 +537,7 @@ class OnyxCeleryTask:
DOCPROCESSING_TASK = "docprocessing_task"
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
CONNECTOR_HIERARCHY_FETCHING_TASK = "connector_hierarchy_fetching_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
@@ -549,12 +553,12 @@ class OnyxCeleryTask:
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
# KG processing
CHECK_KG_PROCESSING = "check_kg_processing"
KG_PROCESSING = "kg_processing"
KG_CLUSTERING_ONLY = "kg_clustering_only"
CHECK_KG_PROCESSING_CLUSTERING_ONLY = "check_kg_processing_clustering_only"
KG_RESET_SOURCE_INDEX = "kg_reset_source_index"
# Sandbox cleanup
CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes"
CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots"
# Sandbox file sync
SANDBOX_FILE_SYNC = "sandbox_file_sync"
# this needs to correspond to the matching entry in supervisord

View File

@@ -17,6 +17,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_file_text
@@ -419,7 +420,7 @@ class AirtableConnector(LoadConnector):
# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 8
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
record_documents: list[Document] = []
record_documents: list[Document | HierarchyNode] = []
# Process records in batches
for i in range(0, len(records), PARALLEL_BATCH_SIZE):

View File

@@ -10,6 +10,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
@@ -56,7 +57,7 @@ class AsanaConnector(LoadConnector, PollConnector):
workspace_gid=self.workspace_id,
team_gid=self.asana_team_id,
)
docs_batch: list[Document] = []
docs_batch: list[Document | HierarchyNode] = []
tasks = asana.get_tasks(self.project_ids_to_index, start_time)
for task in tasks:
@@ -116,5 +117,8 @@ if __name__ == "__main__":
latest_docs = connector.poll_source(one_day_ago, current)
for docs in latest_docs:
for doc in docs:
print(doc.id)
if isinstance(doc, HierarchyNode):
print("hierarchynode:", doc.display_name)
else:
print(doc.id)
logger.notice("Asana connector test completed")

View File

@@ -30,6 +30,7 @@ from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -271,9 +272,9 @@ class BitbucketConnector(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> Iterator[list[SlimDocument]]:
) -> Iterator[list[SlimDocument | HierarchyNode]]:
"""Return only document IDs for all existing pull requests."""
batch: list[SlimDocument] = []
batch: list[SlimDocument | HierarchyNode] = []
params = self._build_params(
fields=SLIM_PR_LIST_RESPONSE_FIELDS,
start=start,

View File

@@ -36,6 +36,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
@@ -377,7 +378,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
batch: list[Document] = []
batch: list[Document | HierarchyNode] = []
for page in pages:
if "Contents" not in page:
continue
@@ -616,6 +617,10 @@ if __name__ == "__main__":
for document_batch in document_batch_generator:
print("First batch of documents:")
for doc in document_batch:
if isinstance(doc, HierarchyNode):
print("hierarchynode:", doc.display_name)
continue
print(f"Document ID: {doc.id}")
print(f"Semantic Identifier: {doc.semantic_identifier}")
print(f"Source: {doc.source}")

View File

@@ -18,6 +18,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
@@ -47,7 +48,7 @@ class BookstackConnector(LoadConnector, PollConnector):
start_ind: int,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> tuple[list[Document], int]:
) -> tuple[list[Document | HierarchyNode], int]:
params = {
"count": str(batch_size),
"offset": str(start_ind),
@@ -65,7 +66,9 @@ class BookstackConnector(LoadConnector, PollConnector):
)
batch = bookstack_client.get(endpoint, params=params).get("data", [])
doc_batch = [transformer(bookstack_client, item) for item in batch]
doc_batch: list[Document | HierarchyNode] = [
transformer(bookstack_client, item) for item in batch
]
return doc_batch, len(batch)

View File

@@ -17,6 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.utils.retry_wrapper import retry_builder
@@ -80,7 +81,7 @@ class ClickupConnector(LoadConnector, PollConnector):
start: int | None = None,
end: int | None = None,
) -> GenerateDocumentsOutput:
doc_batch: list[Document] = []
doc_batch: list[Document | HierarchyNode] = []
page: int = 0
params = {
"include_markdown_description": "true",

View File

@@ -1,4 +1,5 @@
import copy
from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -46,9 +47,11 @@ from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.enums import HierarchyNodeType
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -62,6 +65,7 @@ _PAGE_EXPANSION_FIELDS = [
"space",
"metadata.labels",
"history.lastUpdated",
"ancestors", # For hierarchy node tracking
]
_ATTACHMENT_EXPANSION_FIELDS = [
"version",
@@ -133,6 +137,9 @@ class ConfluenceConnector(
self._fetched_titles: set[str] = set()
self.allow_images = False
# Track hierarchy nodes we've already yielded to avoid duplicates
self.seen_hierarchy_node_raw_ids: set[str] = set()
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
"""
@@ -183,6 +190,139 @@ class ConfluenceConnector(
logger.info(f"Setting allow_images to {value}.")
self.allow_images = value
def _yield_space_hierarchy_nodes(
self,
) -> Generator[HierarchyNode, None, None]:
"""Yield hierarchy nodes for all spaces we're indexing."""
space_keys = [self.space] if self.space else None
for space in self.confluence_client.retrieve_confluence_spaces(
space_keys=space_keys,
limit=50,
):
space_key = space.get("key")
if not space_key or space_key in self.seen_hierarchy_node_raw_ids:
continue
self.seen_hierarchy_node_raw_ids.add(space_key)
# Build space link
space_link = f"{self.wiki_base}/spaces/{space_key}"
yield HierarchyNode(
raw_node_id=space_key,
raw_parent_id=None, # Parent is SOURCE
display_name=space.get("name", space_key),
link=space_link,
node_type=HierarchyNodeType.SPACE,
)
def _yield_ancestor_hierarchy_nodes(
self,
page: dict[str, Any],
) -> Generator[HierarchyNode, None, None]:
"""Yield hierarchy nodes for all unseen ancestors of this page.
Any page that appears as an ancestor of another page IS a hierarchy node
(it has at least one child - the page we're currently processing).
This ensures parent nodes are always yielded before child documents.
"""
ancestors = page.get("ancestors", [])
space_key = page.get("space", {}).get("key")
# Ensure space is yielded first (if not already)
if space_key and space_key not in self.seen_hierarchy_node_raw_ids:
self.seen_hierarchy_node_raw_ids.add(space_key)
space = page.get("space", {})
yield HierarchyNode(
raw_node_id=space_key,
raw_parent_id=None, # Parent is SOURCE
display_name=space.get("name", space_key),
link=f"{self.wiki_base}/spaces/{space_key}",
node_type=HierarchyNodeType.SPACE,
)
# Walk through ancestors (root to immediate parent)
for i, ancestor in enumerate(ancestors):
ancestor_id = str(ancestor.get("id"))
if ancestor_id in self.seen_hierarchy_node_raw_ids:
continue
self.seen_hierarchy_node_raw_ids.add(ancestor_id)
# Determine parent of this ancestor
if i == 0:
# First ancestor - parent is the space
parent_raw_id = space_key
else:
# Parent is the previous ancestor
parent_raw_id = str(ancestors[i - 1].get("id"))
# Build link from ancestor's _links
ancestor_link = None
if "_links" in ancestor and "webui" in ancestor["_links"]:
ancestor_link = build_confluence_document_id(
self.wiki_base, ancestor["_links"]["webui"], self.is_cloud
)
yield HierarchyNode(
raw_node_id=ancestor_id,
raw_parent_id=parent_raw_id,
display_name=ancestor.get("title", f"Page {ancestor_id}"),
link=ancestor_link,
node_type=HierarchyNodeType.PAGE,
)
def _get_parent_hierarchy_raw_id(self, page: dict[str, Any]) -> str | None:
"""Get the raw hierarchy node ID of this page's parent.
Returns:
- Parent page ID if page has a parent page (last item in ancestors)
- Space key if page is at top level of space
- None if we can't determine
"""
ancestors = page.get("ancestors", [])
if ancestors:
# Last ancestor is the immediate parent page
return str(ancestors[-1].get("id"))
# Top-level page - parent is the space
return page.get("space", {}).get("key")
def _maybe_yield_page_hierarchy_node(
self, page: dict[str, Any]
) -> HierarchyNode | None:
"""Yield a hierarchy node for this page if not already yielded.
Used when a page has attachments - attachments are children of the page
in the hierarchy, so the page must be a hierarchy node.
"""
page_id = _get_page_id(page)
if page_id in self.seen_hierarchy_node_raw_ids:
return None
self.seen_hierarchy_node_raw_ids.add(page_id)
# Build page link
page_link = None
if "_links" in page and "webui" in page["_links"]:
page_link = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
)
# Get parent hierarchy ID
parent_raw_id = self._get_parent_hierarchy_raw_id(page)
return HierarchyNode(
raw_node_id=page_id,
raw_parent_id=parent_raw_id,
display_name=page.get("title", f"Page {page_id}"),
link=page_link,
node_type=HierarchyNodeType.PAGE,
document_id=page_link, # Page is also a document
)
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
@@ -354,6 +494,9 @@ class ConfluenceConnector(
BasicExpertInfo(display_name=display_name, email=email)
)
# Determine parent hierarchy node
parent_hierarchy_raw_node_id = self._get_parent_hierarchy_raw_id(page)
# Create the document
return Document(
id=page_url,
@@ -363,6 +506,7 @@ class ConfluenceConnector(
metadata=metadata,
doc_updated_at=datetime_from_string(page["version"]["when"]),
primary_owners=primary_owners if primary_owners else None,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
except Exception as e:
logger.error(f"Error converting page {page.get('id', 'unknown')}: {e}")
@@ -382,18 +526,22 @@ class ConfluenceConnector(
page: dict[str, Any],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> tuple[list[Document], list[ConnectorFailure]]:
) -> tuple[list[Document | HierarchyNode], list[ConnectorFailure]]:
"""
Inline attachments are added directly to the document as text or image sections by
this function. The returned documents/connectorfailures are for non-inline attachments
and those at the end of the page.
If there are valid attachments, the page itself is yielded as a hierarchy node
(since attachments are children of the page in the hierarchy).
"""
attachment_query = self._construct_attachment_query(
_get_page_id(page), start, end
)
attachment_failures: list[ConnectorFailure] = []
attachment_docs: list[Document] = []
attachment_docs: list[Document | HierarchyNode] = []
page_url = ""
page_hierarchy_node_yielded = False
try:
for attachment in self.confluence_client.paginated_cql_retrieval(
@@ -487,6 +635,9 @@ class ConfluenceConnector(
BasicExpertInfo(display_name=display_name, email=email)
]
# Attachments have their parent page as the hierarchy parent
attachment_parent_hierarchy_raw_id = _get_page_id(page)
attachment_doc = Document(
id=attachment_id,
sections=sections,
@@ -500,7 +651,19 @@ class ConfluenceConnector(
else None
),
primary_owners=primary_owners,
parent_hierarchy_raw_node_id=attachment_parent_hierarchy_raw_id,
)
# If this is the first valid attachment, yield the page as a
# hierarchy node (attachments are children of the page)
if not page_hierarchy_node_yielded:
page_hierarchy_node = self._maybe_yield_page_hierarchy_node(
page
)
if page_hierarchy_node:
attachment_docs.append(page_hierarchy_node)
page_hierarchy_node_yielded = True
attachment_docs.append(attachment_doc)
except Exception as e:
logger.error(
@@ -568,7 +731,8 @@ class ConfluenceConnector(
end: SecondsSinceUnixEpoch | None = None,
) -> CheckpointOutput[ConfluenceCheckpoint]:
"""
Yields batches of Documents. For each page:
Yields batches of Documents and HierarchyNodes. For each page:
- Yield hierarchy nodes for spaces and ancestor pages (parent-before-child ordering)
- Create a Document with 1 Section for the page text/comments
- Then fetch attachments. For each attachment:
- Attempt to convert it with convert_attachment_to_content(...)
@@ -576,6 +740,10 @@ class ConfluenceConnector(
"""
checkpoint = copy.deepcopy(checkpoint)
# Yield space hierarchy nodes FIRST (only once per connector run)
if not checkpoint.next_page_url:
yield from self._yield_space_hierarchy_nodes()
# use "start" when last_updated is 0 or for confluence server
start_ts = start
page_query_url = checkpoint.next_page_url or self._build_page_retrieval_url(
@@ -592,6 +760,9 @@ class ConfluenceConnector(
limit=self.batch_size,
next_page_callback=store_next_page_url,
):
# Yield hierarchy nodes for all ancestors (parent-before-child ordering)
yield from self._yield_ancestor_hierarchy_nodes(page)
# Build doc from page
doc_or_failure = self._convert_page_to_document(page)
@@ -700,7 +871,7 @@ class ConfluenceConnector(
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
doc_metadata_list: list[SlimDocument] = []
doc_metadata_list: list[SlimDocument | HierarchyNode] = []
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
space_level_access_info: dict[str, ExternalAccess] = {}

View File

@@ -14,6 +14,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.utils.logger import setup_logger
@@ -30,15 +31,16 @@ def batched_doc_ids(
batch_size: int,
) -> Generator[set[str], None, None]:
batch: set[str] = set()
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
for document, hierarchy_node, failure, next_checkpoint in CheckpointOutputWrapper[
CT
]()(checkpoint_connector_generator):
if document is not None:
batch.add(document.id)
elif (
failure and failure.failed_document and failure.failed_document.document_id
):
batch.add(failure.failed_document.document_id)
# HierarchyNodes don't have IDs that need to be batched for doc processing
if len(batch) >= batch_size:
yield batch
@@ -63,7 +65,9 @@ class CheckpointOutputWrapper(Generic[CT]):
self,
checkpoint_connector_generator: CheckpointOutput[CT],
) -> Generator[
tuple[Document | None, ConnectorFailure | None, CT | None],
tuple[
Document | None, HierarchyNode | None, ConnectorFailure | None, CT | None
],
None,
None,
]:
@@ -74,22 +78,22 @@ class CheckpointOutputWrapper(Generic[CT]):
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
if isinstance(document_or_failure, Document):
yield document_or_failure, None, None
elif isinstance(document_or_failure, ConnectorFailure):
yield None, document_or_failure, None
for item in _inner_wrapper(checkpoint_connector_generator):
if isinstance(item, Document):
yield item, None, None, None
elif isinstance(item, HierarchyNode):
yield None, item, None, None
elif isinstance(item, ConnectorFailure):
yield None, None, item, None
else:
raise ValueError(
f"Invalid document_or_failure type: {type(document_or_failure)}"
)
raise ValueError(f"Invalid connector output type: {type(item)}")
if self.next_checkpoint is None:
raise RuntimeError(
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
)
yield None, None, self.next_checkpoint
yield None, None, None, self.next_checkpoint
class ConnectorRunner(Generic[CT]):
@@ -119,13 +123,27 @@ class ConnectorRunner(Generic[CT]):
self.include_permissions = include_permissions
self.doc_batch: list[Document] = []
self.hierarchy_node_batch: list[HierarchyNode] = []
def run(self, checkpoint: CT) -> Generator[
tuple[list[Document] | None, ConnectorFailure | None, CT | None],
tuple[
list[Document] | None,
list[HierarchyNode] | None,
ConnectorFailure | None,
CT | None,
],
None,
None,
]:
"""Adds additional exception logging to the connector."""
"""
Yields batches of Documents, HierarchyNodes, failures, and checkpoints.
Returns tuples of:
- (doc_batch, None, None, None) - batch of documents
- (None, hierarchy_batch, None, None) - batch of hierarchy nodes
- (None, None, failure, None) - a connector failure
- (None, None, None, checkpoint) - new checkpoint
"""
try:
if isinstance(self.connector, CheckpointedConnector):
if self.time_range is None:
@@ -151,25 +169,47 @@ class ConnectorRunner(Generic[CT]):
)
next_checkpoint: CT | None = None
# this is guaranteed to always run at least once with next_checkpoint being non-None
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None and isinstance(document, Document):
for (
document,
hierarchy_node,
failure,
next_checkpoint,
) in CheckpointOutputWrapper[CT]()(checkpoint_connector_generator):
if document is not None:
self.doc_batch.append(document)
if failure is not None:
yield None, failure, None
if hierarchy_node is not None:
self.hierarchy_node_batch.append(hierarchy_node)
if failure is not None:
yield None, None, failure, None
# Yield hierarchy nodes batch if it reaches batch_size
# (yield nodes before docs to maintain parent-before-child invariant)
if len(self.hierarchy_node_batch) >= self.batch_size:
yield None, self.hierarchy_node_batch, None, None
self.hierarchy_node_batch = []
# Yield document batch if it reaches batch_size
# First flush any pending hierarchy nodes to ensure parents exist
if len(self.doc_batch) >= self.batch_size:
yield self.doc_batch, None, None
if len(self.hierarchy_node_batch) > 0:
yield None, self.hierarchy_node_batch, None, None
self.hierarchy_node_batch = []
yield self.doc_batch, None, None, None
self.doc_batch = []
# yield remaining hierarchy nodes first (parents before children)
if len(self.hierarchy_node_batch) > 0:
yield None, self.hierarchy_node_batch, None, None
self.hierarchy_node_batch = []
# yield remaining documents
if len(self.doc_batch) > 0:
yield self.doc_batch, None, None
yield self.doc_batch, None, None, None
self.doc_batch = []
yield None, None, next_checkpoint
yield None, None, None, next_checkpoint
logger.debug(
f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint."
@@ -183,18 +223,26 @@ class ConnectorRunner(Generic[CT]):
if self.time_range is None:
raise ValueError("time_range is required for PollConnector")
for document_batch in self.connector.poll_source(
for batch in self.connector.poll_source(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
):
yield document_batch, None, None
docs, nodes = self._separate_batch(batch)
if nodes:
yield None, nodes, None, None
if docs:
yield docs, None, None, None
yield None, None, finished_checkpoint
yield None, None, None, finished_checkpoint
elif isinstance(self.connector, LoadConnector):
for document_batch in self.connector.load_from_state():
yield document_batch, None, None
for batch in self.connector.load_from_state():
docs, nodes = self._separate_batch(batch)
if nodes:
yield None, nodes, None, None
if docs:
yield docs, None, None, None
yield None, None, finished_checkpoint
yield None, None, None, finished_checkpoint
else:
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
except Exception:
@@ -219,3 +267,16 @@ class ConnectorRunner(Generic[CT]):
f"local_vars below -> \n{local_vars_str[:1024]}"
)
raise
def _separate_batch(
self, batch: list[Document | HierarchyNode]
) -> tuple[list[Document], list[HierarchyNode]]:
"""Separate a mixed batch into Documents and HierarchyNodes."""
docs: list[Document] = []
nodes: list[HierarchyNode] = []
for item in batch:
if isinstance(item, Document):
docs.append(item)
elif isinstance(item, HierarchyNode):
nodes.append(item)
return docs, nodes

View File

@@ -21,6 +21,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
@@ -278,7 +279,7 @@ class DiscordConnector(PollConnector, LoadConnector):
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateDocumentsOutput:
doc_batch = []
doc_batch: list[Document | HierarchyNode] = []
for doc in _manage_async_retrieval(
token=self.discord_bot_token,
requested_start_date_string=self.requested_start_date_string,

View File

@@ -21,6 +21,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
@@ -193,7 +194,7 @@ class DiscourseConnector(PollConnector):
) -> GenerateDocumentsOutput:
page = 0
while topic_ids := self._get_latest_topics(start, end, page):
doc_batch: list[Document] = []
doc_batch: list[Document | HierarchyNode] = []
for topic_id in topic_ids:
doc_batch.append(self._get_doc_from_topic(topic_id))
if len(doc_batch) >= self.batch_size:

View File

@@ -19,6 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.retry_wrapper import retry_builder
@@ -119,7 +120,7 @@ class Document360Connector(LoadConnector, PollConnector):
workspace_id = self._get_workspace_id_by_name()
articles = self._get_articles_with_category(workspace_id)
doc_batch: List[Document] = []
doc_batch: List[Document | HierarchyNode] = []
for article in articles:
article_details = self._make_request(

View File

@@ -19,6 +19,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -80,7 +81,7 @@ class DropboxConnector(LoadConnector, PollConnector):
)
while True:
batch: list[Document] = []
batch: list[Document | HierarchyNode] = []
for entry in result.entries:
if isinstance(entry, FileMetadata):
modified_time = entry.client_modified

View File

@@ -30,6 +30,7 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
@@ -740,7 +741,7 @@ class DrupalWikiConnector(
Returns:
Generator yielding batches of SlimDocument objects.
"""
slim_docs: list[SlimDocument] = []
slim_docs: list[SlimDocument | HierarchyNode] = []
logger.info(
f"Starting retrieve_all_slim_docs with include_all_spaces={self.include_all_spaces}, spaces={self.spaces}"
)

View File

@@ -24,6 +24,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
@@ -278,8 +279,8 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
self,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> Generator[list[Document], None, None]:
current_batch: list[Document] = []
) -> Generator[list[Document | HierarchyNode], None, None]:
current_batch: list[Document | HierarchyNode] = []
# Iterate through yielded files and filter them
for file in self._get_files_list(self.folder_path):

Some files were not shown because too many files have changed in this diff Show More