Compare commits

..

168 Commits

Author SHA1 Message Date
rkuo-danswer
d50a17db21 add filter unit tests (#4421)
* add filter unit tests

* fix tests

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-02 20:26:25 +00:00
pablonyx
dc5a1e8fd0 add more flexible vision support check (#4429) 2025-04-02 18:11:33 +00:00
pablonyx
c0b3681650 update (#4428) 2025-04-02 18:09:44 +00:00
Chris Weaver
7ec04484d4 Another fix for Salesforce perm sync (#4432)
* Another fix for Salesforce perm sync

* typing
2025-04-02 11:08:40 -07:00
Weves
1cf966ecc1 Fix Salesforce perm sync 2025-04-02 10:47:26 -07:00
rkuo-danswer
8a8526dbbb harden join function (#4424)
* harden join function

* remove log spam

* use time.monotonic

* add pid logging

* client only celery app

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-02 01:04:00 -07:00
Weves
be20586ba1 Add retries for confluence calls 2025-04-01 23:00:37 -07:00
Weves
a314462d1e Fix migrations 2025-04-01 21:48:32 -07:00
rkuo-danswer
155f53c3d7 Revert "Add user invitation test (#4161)" (#4422)
This reverts commit 806de92feb.

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-01 19:55:04 -07:00
pablonyx
7c027df186 Fix cc pair doc deletion (#4420) 2025-04-01 18:44:15 -07:00
pablonyx
0a5db96026 update (#4415) 2025-04-02 00:42:42 +00:00
joachim-danswer
daef985b02 Simpler approach (#4414) 2025-04-01 16:52:59 -07:00
Weves
b7ece296e0 Additional logging to salesforce perm sync 2025-04-01 16:19:50 -07:00
Richard Kuo (Onyx)
d7063e0a1d expose acl link feature in onyx_vespa 2025-04-01 16:19:50 -07:00
pablonyx
ee073f6d30 Tracking things (#4352) 2025-04-01 16:19:50 -07:00
Raunak Bhagat
2e524816a0 Regen (#4409)
* Edit styling of regeneration dropdown

* Finish regeneration style changes

* Remove invalid props

* Update web/src/app/chat/input/ChatInputBar.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Remove unused variables

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-04-01 16:19:50 -07:00
pablonyx
47ef0c8658 Still delete cookies (#4404) 2025-04-01 16:19:50 -07:00
pablonyx
806de92feb Add user invitation test (#4161) 2025-04-01 16:19:50 -07:00
pablonyx
da39f32fea Validate advanced fields + proper yup assurances for lists (#4399) 2025-04-01 16:19:50 -07:00
pablonyx
2a87837ce1 Very minor auth standardization (#4400) 2025-04-01 16:19:50 -07:00
pablonyx
7491cdd0f0 Update migration (#4410) 2025-04-01 16:19:50 -07:00
SubashMohan
aabd698295 refactor tests for Highspot connector to use mocking for API key retrieval (#4346) 2025-04-01 16:19:50 -07:00
Weves
4b725e4d1a Init engine in slackbot 2025-04-01 16:19:50 -07:00
rkuo-danswer
34d2d92fa8 also set permission upsert to medium priority (#4405)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-01 16:19:50 -07:00
pablonyx
3a3b2a2f8d add user files (#4152) 2025-04-01 16:19:44 -07:00
rkuo-danswer
ccd372cc4a Bugfix/slack rate limiting (#4386)
* use slack's built in rate limit handler for the bot

* WIP

* fix the slack rate limit handler

* change default to 8

* cleanup

* try catch int conversion just in case

* linearize this logic better

* code review comments

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-31 21:00:26 +00:00
evan-danswer
ea30f1de1e minor improvement to fireflies connector (#4383)
* minor improvement to fireflies connector

* reduce time diff
2025-03-31 20:00:52 +00:00
evan-danswer
a7130681d9 ensure bedrock model contains API key (#4396)
* ensure bedrock model contains API key

* fix storing bug
2025-03-31 19:58:53 +00:00
pablonyx
04911db715 fix slashes (#4259) 2025-03-31 18:08:17 +00:00
rkuo-danswer
feae7d0cc4 disambiguate job name from ee version (#4403)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-31 11:48:28 -07:00
pablonyx
ac19c64b3c temporary fix for auth (#4402) 2025-03-31 11:10:41 -07:00
pablonyx
03d5c30fd2 fix (#4372) 2025-03-31 17:25:21 +00:00
joachim-danswer
e988c13e1d Additional logging for the path from Search Results to LLM Context (#4387)
* added logging

* nit

* nit
2025-03-31 00:38:43 +00:00
pablonyx
dc18d53133 Improve multi tenant anonymous user interaction (#3857)
* cleaner handling

* k

* k

* address nits

* fix typing
2025-03-31 00:33:32 +00:00
evan-danswer
a1cef389aa fallback to ignoring unicode chars when huggingface tokenizer fails (#4394) 2025-03-30 23:45:20 +00:00
pablonyx
db8d6ce538 formatting (#4316) 2025-03-30 23:43:17 +00:00
pablonyx
e8370dcb24 Update refresh conditional (#4375)
* update refresh conditional

* k
2025-03-30 17:28:35 -07:00
pablonyx
9951fe13ba Fix image input processing without LLMs (#4390)
* quick fix

* quick fix

* Revert "quick fix"

This reverts commit 906b29bd9b.

* nit
2025-03-30 19:28:49 +00:00
evan-danswer
56f8ab927b Contextual Retrieval (#4029)
* contextual rag implementation

* WIP

* indexing test fix

* workaround for chunking errors, WIP on fixing massive memory cost

* mypy and test fixes

* reformatting

* fixed rebase
2025-03-30 18:49:09 +00:00
rkuo-danswer
cb5bbd3812 Feature/mit integration tests (#4299)
* new mit integration test template

* edit

* fix problem with ACL type tags and MIT testing for test_connector_deletion

* fix test_connector_deletion_for_overlapping_connectors

* disable some enterprise only tests in MIT version

* disable a bunch of user group / curator tests in MIT version

* wire off more tests

* typo fix

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-30 02:41:08 +00:00
Yuhong Sun
742d29e504 Remove BETA 2025-03-29 15:38:46 -07:00
SubashMohan
ecc155d082 fix: ensure base_url ends with a trailing slash (#4388) 2025-03-29 14:34:30 -07:00
pablonyx
0857e4809d fix background color 2025-03-28 16:33:30 -07:00
Chris Weaver
22e00a1f5c Fix duplicate docs (#4378)
* Initial

* Fix duplicate docs

* Add tests

* Switch to list comprehension

* Fix test
2025-03-28 22:25:26 +00:00
Chris Weaver
0d0588a0c1 Remove OnyxContext (#4376)
* Remove OnyxContext

* Fix UT

* Fix tests v2
2025-03-28 12:39:51 -07:00
rkuo-danswer
aab777f844 Bugfix/acl prefix (#4377)
* fix acl prefixing

* increase timeout a tad

* block access to init'ing DocumentAccess directly, fix test to work with ee/MIT

* fix env var checks

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-28 05:52:35 +00:00
pablonyx
babbe7689a k (#4380) 2025-03-28 02:23:45 +00:00
evan-danswer
a123661c92 fixed shared folder issue (#4371)
* fixed shared folder issue

* fix existing tests

* default allow files shared with me for service account
2025-03-27 23:39:52 +00:00
pablonyx
c554889baf Fix actions link (#4374) 2025-03-27 16:39:35 -07:00
rkuo-danswer
f08fa878a6 refactor file extension checking and add test for blob s3 (#4369)
* refactor file extension checking and add test for blob s3

* code review

* fix checking ext

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 18:57:44 +00:00
pablonyx
d307534781 add some debug logging (#4328) 2025-03-27 11:49:32 -07:00
rkuo-danswer
6f54791910 adjust some vars in real time (#4365)
* adjust some vars in real time

* some sanity checking

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 17:30:08 +00:00
pablonyx
0d5497bb6b Add multi-tenant user invitation flow test (#4360) 2025-03-27 09:53:15 -07:00
Chris Weaver
7648627503 Save all logs + add log persistence to most Onyx-owned containers (#4368)
* Save all logs + add log persistence to most Onyx-owned containers

* Separate volumes for each container

* Small fixes
2025-03-26 22:25:39 -07:00
pablonyx
927554d5ca slight robustification (#4367) 2025-03-27 03:23:36 +00:00
pablonyx
7dcec6caf5 Fix session touching (#4363)
* fix session touching

* Revert "fix session touching"

This reverts commit c473d5c9a2.

* Revert "Revert "fix session touching""

This reverts commit 26a71d40b6.

* update

* quick nit
2025-03-27 01:18:46 +00:00
rkuo-danswer
036648146d possible fix for confluence query filter (#4280)
* possible fix for confluence query filter

* nuke the attachment filter query ... it doesn't work!

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 00:35:14 +00:00
rkuo-danswer
2aa4697ac8 permission sync runs so often that it starves out other tasks if run at high priority (#4364)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 00:22:53 +00:00
rkuo-danswer
bc9b4e4f45 use slack's built in rate limit handler for the bot (#4362)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-26 21:55:04 +00:00
evan-danswer
178a64f298 fix issue with drive connector service account indexing (#4356)
* fix issue with drive connector service account indexing

* correct checkpoint resumption

* final set of fixes

* nit

* fix typing

* logging and CW comments

* nit
2025-03-26 20:54:26 +00:00
pablonyx
c79f1edf1d add a flush (#4361) 2025-03-26 14:40:52 -07:00
pablonyx
7c8e23aa54 Fix saml conversion from ext_perm -> basic (#4343)
* fix saml conversion from ext_perm -> basic

* quick nit

* minor fix

* finalize

* update

* quick fix
2025-03-26 20:36:51 +00:00
pablonyx
d37b427d52 fix email flow (#4339) 2025-03-26 18:59:12 +00:00
pablonyx
a65fefd226 test fix 2025-03-26 12:43:38 -07:00
rkuo-danswer
bb09bde519 Bugfix/google drive size threshold 2 (#4355) 2025-03-26 12:06:36 -07:00
Tim Rosenblatt
0f6cf0fc58 Fixes docker logs helper text in run-nginx.sh (#3678)
The docker container name is slightly wrong, and this commit fixes it.
2025-03-26 09:03:35 -07:00
pablonyx
fed06b592d Auto refresh credentials (#4268)
* Auto refresh credentials

* remove dupes

* clean up + tests

* k

* quick nit

* add brief comment

* misc typing
2025-03-26 01:53:31 +00:00
pablonyx
8d92a1524e fix invitation on cloud (#4351)
* fix invitation on cloud

* k
2025-03-26 01:25:17 +00:00
pablonyx
ecfea9f5ed Email formatting devices (#4353)
* update email formatting

* k

* update

* k

* nit
2025-03-25 21:42:32 +00:00
rkuo-danswer
b269f1ba06 fix broken function call (#4354)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-25 21:07:31 +00:00
pablonyx
30c878efa5 Quick fix (#4341)
* quick fix

* Revert "quick fix"

This reverts commit f113616276.

* smaller chnage
2025-03-25 18:39:55 +00:00
pablonyx
2024776c19 Respect contextvars when parallelizing for Google Drive (#4291)
* k

* k

* fix typing
2025-03-25 17:40:12 +00:00
pablonyx
431316929c k (#4336) 2025-03-25 17:00:35 +00:00
pablonyx
c5b9c6e308 update (#4344) 2025-03-25 16:56:23 +00:00
pablonyx
73dd188b3f update (#4338) 2025-03-25 16:55:25 +00:00
evan-danswer
79b061abbc Daylight savings time handling (#4345)
* confluence timezone improvements

* confluence timezone improvements
2025-03-25 16:11:30 +00:00
rkuo-danswer
552f1ead4f use correct namespace in redis for certain keys (#4340)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-25 04:10:31 +00:00
evan-danswer
17925b49e8 typing fix (#4342)
* typing fix

* changed type hint to help future coders
2025-03-25 01:01:13 +00:00
rkuo-danswer
55fb5c3ca5 add size threshold for google drive (#4329)
* add size threshold for google drive

* greptile nits

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-24 04:09:28 +00:00
evan-danswer
99546e4a4d zendesk checkpointed connector (#4311)
* zendesk v1

* logic fix

* zendesk testing

* add unit tests

* zendesk caching

* CW comments

* fix unit tests
2025-03-23 20:43:13 +00:00
pablonyx
c25d56f4a5 Improved drive flow UX (#4331)
* wip

* k

* looking good

* clenaed up

* quick nit
2025-03-23 19:21:03 +00:00
Chris Weaver
35f3f4f120 Small slack bot fixes (#4333) 2025-03-22 23:22:17 +00:00
Weves
25b69a8aca Adjust spammy log 2025-03-22 14:52:09 -07:00
pablonyx
1b7d710b2a Fix links from file metadata (#4324)
* quick fix

* clarify comment

* fix file metadata

* k
2025-03-22 18:21:47 +00:00
pablonyx
ae3d3db3f4 Update slack bot listing endpoint (#4325)
* update slack bot listing endpoint

* nit
2025-03-22 18:21:31 +00:00
evan-danswer
fb79a9e700 Checkpointed GitHub connector (#4307)
* WIP github checkpointing

* first draft of github checkpointing

* nit

* CW comments

* github basic connector test

* connector test env var

* secrets cant start with GITHUB_

* unit tests and bug fix

* connector failures

* address CW comments

* validation fix

* validation fix

* remove prints

* fixed tests

* 100 items per page
2025-03-22 01:48:05 +00:00
rkuo-danswer
587ba11bbc alembic script logging fixes (#4322)
* log fixing

* fix typos

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-22 00:50:58 +00:00
pablonyx
fce81ebb60 Minor ux nits (#4327)
* k

* quick fix
2025-03-21 21:50:56 +00:00
Chris Weaver
61facfb0a8 Fix slack connector (#4326) 2025-03-21 21:30:03 +00:00
Chris Weaver
52b96854a2 Handle move errors (#4317)
* Handle move errors

* Make a warning
2025-03-21 11:11:12 -07:00
Chris Weaver
d123713c00 Fix GPU status request in sync flow (#4318)
* Fix GPU status request in sync flow

* tweak

* Fix test

* Fix more tests
2025-03-21 11:11:00 -07:00
Chris Weaver
775c847f82 Reduce drive retries (#4312)
* Reduce drive retries

* timestamp format fix

---------

Co-authored-by: Evan Lohn <evan@danswer.ai>
2025-03-21 00:23:55 +00:00
rkuo-danswer
6d330131fd wire off image downloading for confluence and gdrive if not enabled i… (#4305)
* wire off image downloading for confluence and gdrive if not enabled in settings

* fix partial func

* fix confluence basic test

* add test for skipping/allowing images

* review comments

* skip allow images test

* mock function using the db

* mock at the proper level

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-20 23:10:28 +00:00
Chris Weaver
0292ca2445 Add option to control # of slack threads (#4310) 2025-03-20 16:56:05 +00:00
Weves
15dd1e72ca Remove slack channel validation 2025-03-20 08:34:54 -07:00
Weves
91c9be37c0 Fix loader 2025-03-20 08:30:46 -07:00
Weves
2a01c854a0 Fix cases where the bot is disabled 2025-03-20 08:30:46 -07:00
rkuo-danswer
85ebadc8eb sanitize llm keys and handle updates properly (#4270)
* sanitize llm keys and handle updates properly

* fix llm provider testing

* fix test

* mypy

* fix default model editing

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-20 01:13:02 +00:00
Chris Weaver
5dda53eec3 Notion improvement (#4306)
* Notion connector improvements

* Enable recursive index by default

* Small tweak
2025-03-19 23:16:05 +00:00
Chris Weaver
72bf427cc2 Address invalid connector state (#4304)
* Address invalid connector state

* Fixes

* Address mypy

* Address RK comment
2025-03-19 21:15:06 +00:00
Chris Weaver
f421c6010b Checkpointed Jira connector (#4286)
* Checkpointed Jira connector

* nit

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* typing improvements and test fixes

* cleaner typing

* remove default because it is from the future

* mypy

* Address EL comments

---------

Co-authored-by: evan-danswer <evan@danswer.ai>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-03-19 20:41:01 +00:00
rkuo-danswer
0b87549f35 Feature/email whitelabeling (#4260)
* work in progress

* work in progress

* WIP

* refactor, use inline attachment for image (base64 encoding doesn't work)

* pretty sure this belongs behind a multi_tenant check

* code review / refactor

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-19 13:08:44 -07:00
evan-danswer
06624a988d Gdrive checkpointed connector (#4262)
* WIP rebased

* style

* WIP, testing theory

* fix type issue

* fixed filtering bug

* fix silliness

* correct serialization and validation of threadsafedict

* concurrent drive access

* nits

* nit

* oauth bug fix

* testing fix

* fix slim retrieval

* fix integration tests

* fix testing change

* CW comments

* nit

* guarantee completion stage existence

* fix default values
2025-03-19 18:49:35 +00:00
Chris Weaver
ae774105e3 Fix slack connector creation (#4303)
* Make it fail fast + succeed validation if rate limiting is happening

* Add logging + reduce spam
2025-03-19 18:26:49 +00:00
evan-danswer
4dafc3aa6d Update README.md 2025-03-18 21:14:05 -07:00
evan-danswer
5d7d471823 Update README.md
fix bullet points
2025-03-18 19:34:08 -07:00
Weves
61366df34c Add execute permission 2025-03-18 12:03:32 -07:00
Chris Weaver
1a444245f6 Memory tracking script (#4297)
* Add simple container-level memory tracking script
2025-03-18 12:00:09 -07:00
rkuo-danswer
c32d234491 xfail highspot connector tests (#4296)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-18 11:47:17 -07:00
pablonyx
07b68436cf use ONYX_CLOUD_CELERY_TASK_PREFIX for pre provisioning (#4293) 2025-03-18 17:34:22 +00:00
Chris Weaver
293d1a4476 Add process-level memory monitoring (#4294)
* Add process-level memory monitoring

* Switch to every 5 minutes
2025-03-17 22:39:52 -07:00
SubashMohan
ba514aaaa2 Highspot connector (#4277) 2025-03-17 08:36:02 -07:00
Arun Philip
f45798b5dd add overflow-auto to show all content in Modal (#4140) 2025-03-15 11:56:19 -07:00
Weves
64ff5df083 Fix basic auth for non-ee 2025-03-14 11:40:17 -07:00
rkuo-danswer
cf1b7e7a93 add proper boolean validation to field (#4283)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-14 03:38:25 +00:00
Chris Weaver
63692a6bd3 Fix perm sync memory usage (#4282)
* Fix slack perm sync memory usage

* Make perm syncing run in batches rather than fetching everything

* Update backend/ee/onyx/external_permissions/slack/doc_sync.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Update backend/ee/onyx/external_permissions/slack/doc_sync.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Loud error on slack doc sync missing permissions

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-03-14 02:26:22 +00:00
evan-danswer
934700b928 better drive url cleaning (#4247)
* better drive url cleaning

* nit

* address JR comments
2025-03-13 21:16:24 +00:00
Chris Weaver
b1a7cff9e0 Enable claude 3.7 (#4279) 2025-03-13 18:33:06 +00:00
joachim-danswer
463340b8a1 Reduce ranking scores for short chunks without actual information (#4098)
* remove title for slack

* initial working code

* simplification

* improvements

* name change to information_content_model

* avoid boost_score > 1.0

* nit

* EL comments and improvements

Improvements:
  - proper import of information content model from cache or HF
  - warm up for information content model

Other:
  - EL PR review comments

* nit

* requirements version update

* fixed docker file

* new home for model_server configs

* default off

* small updates

* YS comments - pt 1

* renaming to chunk_boost & chunk table def

* saving and deleting chunk stats in new table

* saving and updating chunk stats

* improved dict score update

* create columns for individual boost factors

* RK comments

* Update migration

* manual import reordering
2025-03-13 17:35:45 +00:00
rkuo-danswer
ba82888e1e change max workers to 2 for the moment (#4278)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-13 09:58:24 -07:00
rkuo-danswer
39465d3104 change default build info in dockerfile's to something more obviously source only (#4275)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-13 09:42:10 -07:00
rkuo-danswer
b4ecc870b9 safe handling for mediaType in confluence connector in all places (#4269)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-13 06:09:19 +00:00
rkuo-danswer
a2ac9f02fb unique constraint here doesn't work (#4271)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-12 16:25:27 -07:00
pablonyx
f87e559cc4 Separate out indexing-time image analysis into new phase (#4228)
* Separate out indexing-time image analysis into new phase

* looking good

* k

* k
2025-03-12 22:26:05 +00:00
pablonyx
5883336d5e Support image indexing customization (#4261)
* working well

* k

* ready to go

* k

* minor nits

* k

* quick fix

* k

* k
2025-03-12 20:03:45 +00:00
pablonyx
0153ff6b51 Improved logout flow (#4258)
* improved app provider modals

* improved logout flow

* k

* updates

* add docstring
2025-03-12 19:19:39 +00:00
pablonyx
2f8f0f01be Tenants on standby (#4218)
* add tenants on standby feature

* k

* fix alembic

* k

* k
2025-03-12 18:25:30 +00:00
pablonyx
a9e5ae2f11 Fix slash mystery (#4263) 2025-03-12 10:03:21 -07:00
Chris Weaver
997f40500d Add support for sandboxed salesforce (#4252) 2025-03-12 00:21:24 +00:00
rkuo-danswer
a918a84e7b fix oauth downloading and size limits in confluence (#4249)
* fix oauth downloading and size limits in confluence

* bump black to get past corrupt hash

* try working around another corrupt package

* fix raw_bytes

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-11 23:57:47 +00:00
rkuo-danswer
090f3fe817 handle conflicts on lowercasing emails (#4255)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-11 21:25:50 +00:00
pablonyx
4e70f99214 Fix slack links (#4254)
* fix slack links

* updates

* k

* nit improvements
2025-03-11 19:58:15 +00:00
pablonyx
ecbd4eb1ad add basic user invite flow (#4253) 2025-03-11 19:02:51 +00:00
pablonyx
f94d335d12 Do not show modals to non-multitenant users (#4256) 2025-03-11 11:53:13 -07:00
pablonyx
59a388ce0a fix tests 2025-03-11 11:12:35 -07:00
rkuo-danswer
9cd3cbb978 fix versions (#4250)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-10 23:50:07 -07:00
pablonyx
ab1b6b487e descrease model server logspam (#4166) 2025-03-10 18:29:27 +00:00
Chris Weaver
6ead9510a4 Small notion tweaks (#4244)
* Small notion tweaks

* Add comment
2025-03-10 15:51:12 +00:00
Chris Weaver
965f9e98bf Eliminate extremely long log line for large checkpointds (#4236)
* Eliminate extremely long log line for large checkpointds

* address greptile
2025-03-10 15:50:50 +00:00
rkuo-danswer
426883bbf5 Feature/agentic buffered (#4231)
* rename agent test script to prevent pytest autodiscovery

* first cut

* fix log message

* fix up typing

* add a sample test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-10 15:48:42 +00:00
rkuo-danswer
6ca400ced9 Bugfix/delete document tags slow (#4232)
* Add Missing Date and Message-ID Headers to Ensure Email Delivery

* fix issue Performance issue during connector deletion #4191

* fix ruff

* bump to rebuild PR

---------

Co-authored-by: ThomaciousD <2194608+ThomaciousD@users.noreply.github.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-10 03:07:30 +00:00
Weves
104c4b9f4d small modal improvement 2025-03-09 20:54:53 -07:00
pablonyx
8b5e8bd5b9 k (#4240) 2025-03-10 03:06:13 +00:00
Weves
7f7621d7c0 SMall gitbook tweaks 2025-03-09 14:46:44 -07:00
pablonyx
06dcc28d05 Improved login experience (#4178)
* functional initial auth modal

* k

* k

* k

* looking good

* k

* k

* k

* k

* update

* k

* k

* misc bunch

* improvements

* k

* address comments

* k

* nit

* update

* k
2025-03-09 01:06:20 +00:00
pablonyx
18df63dfd9 Fix local background jobs (#4241) 2025-03-08 14:47:56 -08:00
Chris Weaver
0d3c72acbf Add basic memory logging (#4234)
* Add basic memory logging

* Small tweaks

* Switch to monotonic
2025-03-08 03:49:47 +00:00
rkuo-danswer
9217243e3e Bugfix/query history notes (#4204)
* early work in progress

* rename utility script

* move actual data seeding to a shareable function

* add test

* make the test pass with the fix

* fix comment

* slight improvements and notes to query history and seeding

* update test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-07 19:52:30 +00:00
rkuo-danswer
61ccba82a9 light worker needs to discover some indexing tasks (#4209)
* light worker needs to discover some indexing tasks

* fix formatting

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-07 11:52:09 -08:00
Weves
9e8eba23c3 Fix frozen model issue 2025-03-07 09:05:43 -08:00
evan-danswer
0c29743538 use max_tokens to do better rate limit handling (#4224)
* use max_tokens to do better rate limit handling

* fix unti tests

* address greptile comment, thanks greptile
2025-03-06 18:12:05 -08:00
pablonyx
08b2421947 fix 2025-03-06 17:30:31 -08:00
pablonyx
ed518563db minor typing update 2025-03-06 17:02:39 -08:00
pablonyx
a32f7dc936 Fix Connector tests (confluence) (#4221) 2025-03-06 17:00:01 -08:00
rkuo-danswer
798e10c52f revert to always building model server (#4213)
* revert to always building model server

* fix just in case

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-06 23:49:45 +00:00
pablonyx
bf4983e35a Ensure consistent UX (#4222)
* ux consistent

* nit

* Update web/src/app/admin/configuration/llm/interfaces.ts

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-03-06 23:13:32 +00:00
evan-danswer
b7da91e3ae improved basic search latency (#4186)
* improved basic search latency

* address PR comments + minor cleanup
2025-03-06 22:22:59 +00:00
Weves
29382656fc Stop trying a million times for the user validity check 2025-03-06 15:35:49 -08:00
pablonyx
7d6db8d500 Comma separated list for Github repos (#4199) 2025-03-06 14:46:57 -08:00
Chris Weaver
a7a374dc81 Confluence fixes (#4220)
* Confluence fixes

* Small tweak

* Address greptile comments
2025-03-06 20:57:07 +00:00
rkuo-danswer
facc8cc2fa add scope needed for permission sync (#4198)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-06 20:03:38 +00:00
rkuo-danswer
2c0af0a0ca Feature/helm updates (#4201)
* add ingress for api and web

* helm setup docs

* add letsencrypt. close blocks

* use pathType ImplementationSpecific as Prefix is deprecated

* fix backend labels. configure nginx routes. update annotations

* fix linting

---------

Co-authored-by: Sajjad Anwar <sajjadkm@gmail.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-06 19:48:20 +00:00
pablonyx
bfbc1cd954 k (#4172) 2025-03-06 18:55:12 +00:00
pablonyx
626da583aa Fix gated tenants (#4177)
* fix

* mypy .
2025-03-06 18:07:15 +00:00
pablonyx
92faca139d Fix extra tenant mystery (#4197)
* fix extra tenant mystery

* nit
2025-03-06 18:06:49 +00:00
pablonyx
cec05c5ee9 Revert "k"
This reverts commit 687122911d.
2025-03-06 09:38:31 -08:00
Richard Kuo (Danswer)
eaf054ef06 oauth router went missing? 2025-03-05 15:50:23 -08:00
pablonyx
a7a1a24658 minor nit 2025-03-05 15:35:02 -08:00
589 changed files with 36293 additions and 14136 deletions

View File

@@ -12,29 +12,40 @@ env:
BUILDKIT_PROGRESS: plain
jobs:
# 1) Preliminary job to check if the changed files are relevant
# Bypassing this for now as the idea of not building is glitching
# releases and builds that depends on everything being tagged in docker
# 1) Preliminary job to check if the changed files are relevant
# check_model_server_changes:
# runs-on: ubuntu-latest
# outputs:
# changed: ${{ steps.check.outputs.changed }}
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
#
# - name: Check if relevant files changed
# id: check
# run: |
# # Default to "false"
# echo "changed=false" >> $GITHUB_OUTPUT
#
# # Compare the previous commit (github.event.before) to the current one (github.sha)
# # If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# # set changed=true
# if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
# | grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
# echo "changed=true" >> $GITHUB_OUTPUT
# fi
check_model_server_changes:
runs-on: ubuntu-latest
outputs:
changed: ${{ steps.check.outputs.changed }}
changed: "true"
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check if relevant files changed
id: check
run: |
# Default to "false"
echo "changed=false" >> $GITHUB_OUTPUT
# Compare the previous commit (github.event.before) to the current one (github.sha)
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# set changed=true
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
echo "changed=true" >> $GITHUB_OUTPUT
fi
- name: Bypass check and set output
run: echo "changed=true" >> $GITHUB_OUTPUT
build-amd64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
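
The interleaved old/new lines above are hard to follow, so here is a hedged reading of what the job reduces to after this change (a sketch, not the verbatim file, consistent with the "revert to always building model server (#4213)" commit in the list above): the path-based check is commented out and the job output is hardcoded, so downstream builds always run.

    check_model_server_changes:
      runs-on: ubuntu-latest
      outputs:
        # Hardcoded so every build proceeds regardless of which paths changed
        changed: "true"
      steps:
        - name: Checkout code
          uses: actions/checkout@v4
        - name: Bypass check and set output
          run: echo "changed=true" >> $GITHUB_OUTPUT

Jobs gated on needs.check_model_server_changes.outputs.changed == 'true', such as build-amd64 above, therefore always run.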

View File

@@ -0,0 +1,209 @@
name: Run MIT Integration Tests v2
concurrency:
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests-mit:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
run: |
docker pull onyxdotapp/onyx-web-server:latest
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: onyxdotapp/onyx-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
cd deployment/docker_compose
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f onyx-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v

View File

@@ -1,6 +1,7 @@
name: Connector Tests
on:
merge_group:
pull_request:
branches: [main]
schedule:
@@ -8,6 +9,10 @@ on:
- cron: "0 16 * * *"
env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
@@ -44,14 +49,21 @@ env:
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
# Notion
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
# Highspot
HIGHSPOT_KEY: ${{ secrets.HIGHSPOT_KEY }}
HIGHSPOT_SECRET: ${{ secrets.HIGHSPOT_SECRET }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
@@ -76,7 +88,7 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors

View File

@@ -6,396 +6,419 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"compounds": [
{
// Dummy entry used to label the group
"name": "--- Compound ---",
"configurations": [
"--- Individual ---"
],
"presentation": {
"group": "1",
}
},
{
"name": "Run All Onyx Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat",
"Celery monitoring",
],
"presentation": {
"group": "1",
}
},
{
"name": "Web / Model / API",
"configurations": [
"Web Server",
"Model Server",
"API Server",
],
"presentation": {
"group": "1",
}
},
{
"name": "Celery (all)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat",
"Celery monitoring",
],
"presentation": {
"group": "1",
}
}
{
// Dummy entry used to label the group
"name": "--- Compound ---",
"configurations": ["--- Individual ---"],
"presentation": {
"group": "1"
}
},
{
"name": "Run All Onyx Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
],
"presentation": {
"group": "1"
}
},
{
"name": "Web / Model / API",
"configurations": ["Web Server", "Model Server", "API Server"],
"presentation": {
"group": "1"
}
},
{
"name": "Celery (all)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
],
"presentation": {
"group": "1"
}
}
],
"configurations": [
{
// Dummy entry used to label the group
"name": "--- Individual ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "2",
"order": 0
}
},
{
"name": "Web Server",
"type": "node",
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"runtimeArgs": [
"run", "dev"
],
"presentation": {
"group": "2",
},
"console": "integratedTerminal",
"consoleTitle": "Web Server Console"
{
// Dummy entry used to label the group
"name": "--- Individual ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "2",
"order": 0
}
},
{
"name": "Web Server",
"type": "node",
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"runtimeArgs": ["run", "dev"],
"presentation": {
"group": "2"
},
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": [
"model_server.main:app",
"--reload",
"--port",
"9000"
],
"presentation": {
"group": "2",
},
"consoleTitle": "Model Server Console"
"console": "integratedTerminal",
"consoleTitle": "Web Server Console"
},
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": [
"onyx.main:app",
"--reload",
"--port",
"8080"
],
"presentation": {
"group": "2",
},
"consoleTitle": "API Server Console"
"args": ["model_server.main:app", "--reload", "--port", "9000"],
"presentation": {
"group": "2"
},
// For the listener to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2",
},
"consoleTitle": "Slack Bot Console"
"consoleTitle": "Model Server Console"
},
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
{
"name": "Celery primary",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery primary Console"
"args": ["onyx.main:app", "--reload", "--port", "8080"],
"presentation": {
"group": "2"
},
{
"name": "Celery light",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery light Console"
"consoleTitle": "API Server Console"
},
// For the listener to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
{
"name": "Celery heavy",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery heavy Console"
"presentation": {
"group": "2"
},
{
"name": "Celery indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery indexing Console"
"consoleTitle": "Slack Bot Console"
},
{
"name": "Celery primary",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery monitoring Console"
"args": [
"-A",
"onyx.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery"
],
"presentation": {
"group": "2"
},
{
"name": "Celery beat",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery beat Console"
"consoleTitle": "Celery primary Console"
},
{
"name": "Celery light",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
],
"presentation": {
"group": "2",
},
"consoleTitle": "Pytest Console"
"args": [
"-A",
"onyx.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert"
],
"presentation": {
"group": "2"
},
{
// Dummy entry used to label the group
"name": "--- Tasks ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "3",
"order": 0
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true,
"presentation": {
"group": "3",
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
{
// Celery jobs launched through a single background script (legacy)
// Recommend using the "Celery (all)" compound launch instead.
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
],
"presentation": {
"group": "2"
},
{
"name": "Install Python Requirements",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"-c",
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery indexing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery beat",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery user files indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_files_indexing@%n",
"-Q",
"user_files_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user files indexing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Pytest Console"
},
{
// Dummy entry used to label the group
"name": "--- Tasks ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "3",
"order": 0
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"${workspaceFolder}/backend/scripts/restart_containers.sh"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true,
"presentation": {
"group": "3"
}
},
{
// Celery jobs launched through a single background script (legacy)
// Recommend using the "Celery (all)" compound launch instead.
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
{
"name": "Install Python Requirements",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"-c",
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Debug React Web App in Chrome",
"type": "chrome",
"request": "launch",
"url": "http://localhost:3000",
"webRoot": "${workspaceFolder}/web"
}
]
}
}

View File

@@ -114,3 +114,4 @@ To try the Onyx Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

View File

@@ -8,7 +8,7 @@ Edition features outside of personal development or testing purposes. Please rea
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.8-dev
ARG ONYX_VERSION=0.0.0-dev
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true" \
@@ -102,6 +102,7 @@ COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
COPY ./static /app/static
# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging

View File

@@ -7,7 +7,7 @@ You can find it at https://hub.docker.com/r/onyx/onyx-model-server. For more det
visit https://github.com/onyx-dot-app/onyx."
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.8-dev
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
@@ -31,7 +31,8 @@ RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \

View File

@@ -84,7 +84,7 @@ keys = console
keys = generic
[logger_root]
level = WARN
level = INFO
handlers = console
qualname =

View File

@@ -25,6 +25,9 @@ from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
# hidden! (defaults to level=WARN)
# Alembic Config object
config = context.config
@@ -36,6 +39,7 @@ if config.config_file_name is not None and config.attributes.get(
target_metadata = [Base.metadata, ResultModelBase.metadata]
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
logger = logging.getLogger(__name__)
ssl_context: ssl.SSLContext | None = None
@@ -64,7 +68,7 @@ def include_object(
return True
def get_schema_options() -> tuple[str, bool, bool]:
def get_schema_options() -> tuple[str, bool, bool, bool]:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
@@ -76,6 +80,10 @@ def get_schema_options() -> tuple[str, bool, bool]:
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
# continue on error with individual tenant
# only applies to online migrations
continue_on_error = x_args.get("continue", "false").lower() == "true"
if (
MULTI_TENANT
and schema_name == POSTGRES_DEFAULT_SCHEMA
@@ -86,14 +94,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
"Please specify a tenant-specific schema."
)
return schema_name, create_schema, upgrade_all_tenants
return schema_name, create_schema, upgrade_all_tenants, continue_on_error
def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
logger.info(f"About to migrate schema: {schema_name}")
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
@@ -134,7 +140,12 @@ def provide_iam_token_for_alembic(
async def run_async_migrations() -> None:
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
(
schema_name,
create_schema,
upgrade_all_tenants,
continue_on_error,
) = get_schema_options()
engine = create_async_engine(
build_connection_string(),
@@ -151,9 +162,15 @@ async def run_async_migrations() -> None:
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
i_tenant = 0
num_tenants = len(tenant_schemas)
for schema in tenant_schemas:
i_tenant += 1
logger.info(
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
)
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
@@ -162,7 +179,12 @@ async def run_async_migrations() -> None:
)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
raise
if not continue_on_error:
logger.error("--continue is not set, raising exception!")
raise
logger.warning("--continue is set, continuing to next schema.")
else:
try:
logger.info(f"Migrating schema: {schema_name}")
@@ -180,7 +202,11 @@ async def run_async_migrations() -> None:
def run_migrations_offline() -> None:
schema_name, _, upgrade_all_tenants = get_schema_options()
"""This doesn't really get used when we migrate in the cloud."""
logger.info("run_migrations_offline starting.")
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
url = build_connection_string()
if upgrade_all_tenants:
@@ -230,6 +256,7 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
logger.info("run_migrations_online starting.")
asyncio.run(run_async_migrations())
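Assuming Alembic's standard -x argument syntax for the options parsed above, a multi-tenant upgrade that skips failing tenants might be invoked as (illustrative command, not shown in this diff):
alembic -x upgrade_all_tenants=true -x continue=true upgrade head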

View File

@@ -0,0 +1,51 @@
"""add chunk stats table
Revision ID: 3781a5eb12cb
Revises: df46c75b714e
Create Date: 2025-03-10 10:02:30.586666
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3781a5eb12cb"
down_revision = "df46c75b714e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"chunk_stats",
sa.Column("id", sa.String(), primary_key=True, index=True),
sa.Column(
"document_id",
sa.String(),
sa.ForeignKey("document.id"),
nullable=False,
index=True,
),
sa.Column("chunk_in_doc_id", sa.Integer(), nullable=False),
sa.Column("information_content_boost", sa.Float(), nullable=True),
sa.Column(
"last_modified",
sa.DateTime(timezone=True),
nullable=False,
index=True,
server_default=sa.func.now(),
),
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True, index=True),
sa.UniqueConstraint(
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
),
)
op.create_index(
"ix_chunk_sync_status", "chunk_stats", ["last_modified", "last_synced"]
)
def downgrade() -> None:
op.drop_index("ix_chunk_sync_status", table_name="chunk_stats")
op.drop_table("chunk_stats")
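One plausible use of the ix_chunk_sync_status index added above (an assumption; the consuming query is not part of this diff) is selecting chunks whose stats have changed since the last sync, e.g.:
SELECT id FROM chunk_stats WHERE last_synced IS NULL OR last_synced < last_modified;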

View File

@@ -0,0 +1,125 @@
"""Update GitHub connector repo_name to repositories
Revision ID: 3934b1bc7b62
Revises: b7c2b63c4a03
Create Date: 2025-03-05 10:50:30.516962
"""
from alembic import op
import sqlalchemy as sa
import json
import logging
# revision identifiers, used by Alembic.
revision = "3934b1bc7b62"
down_revision = "b7c2b63c4a03"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
# First get all GitHub connectors
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
# Update each connector's config
updated_count = 0
for connector_id, config in github_connectors:
try:
if not config:
logger.warning(f"Connector {connector_id} has no config, skipping")
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repo_name" not in config:
continue
# Create new config with repositories instead of repo_name
new_config = dict(config)
repo_name_value = new_config.pop("repo_name")
new_config["repositories"] = repo_name_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
)
updated_count += 1
except Exception as e:
logger.error(f"Error updating connector {connector_id}: {str(e)}")
def downgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
logger.debug(
"Starting rollback of GitHub connectors from repositories to repo_name"
)
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")
# Revert each GitHub connector to use repo_name instead of repositories
reverted_count = 0
for connector_id, config in github_connectors:
try:
if not config:
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repositories" not in config:
continue
# Create new config with repo_name instead of repositories
new_config = dict(config)
repositories_value = new_config.pop("repositories")
new_config["repo_name"] = repositories_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"new_config": json.dumps(new_config), "connector_id": connector_id},
)
reverted_count += 1
except Exception as e:
logger.error(f"Error reverting connector {connector_id}: {str(e)}")

View File

@@ -6,8 +6,7 @@ Create Date: 2025-02-26 13:07:56.217791
"""
from alembic import op
import time
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
@@ -28,357 +27,59 @@ depends_on = None
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
def upgrade():
# --- PART 1: chat_message table ---
# Step 1: Add nullable column (quick, minimal locking)
# op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv")
# op.execute("DROP TRIGGER IF EXISTS chat_message_tsv_trigger ON chat_message")
# op.execute("DROP FUNCTION IF EXISTS update_chat_message_tsv()")
# op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv")
# # Drop chat_session tsv trigger if it exists
# op.execute("DROP TRIGGER IF EXISTS chat_session_tsv_trigger ON chat_session")
# op.execute("DROP FUNCTION IF EXISTS update_chat_session_tsv()")
# op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS title_tsv")
# raise Exception("Stop here")
time.time()
op.execute("ALTER TABLE chat_message ADD COLUMN IF NOT EXISTS message_tsv tsvector")
def upgrade() -> None:
# First, drop any existing indexes to avoid conflicts
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
# Step 2: Create function and trigger for new/updated rows
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
# Drop existing columns if they exist
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""
CREATE OR REPLACE FUNCTION update_chat_message_tsv()
RETURNS TRIGGER AS $$
BEGIN
NEW.message_tsv = to_tsvector('english', NEW.message);
RETURN NEW;
END;
$$ LANGUAGE plpgsql
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
# Create trigger in a separate execute call
# Commit the current transaction before creating concurrent indexes
op.execute("COMMIT")
op.execute(
"""
CREATE TRIGGER chat_message_tsv_trigger
BEFORE INSERT OR UPDATE ON chat_message
FOR EACH ROW EXECUTE FUNCTION update_chat_message_tsv()
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message
USING GIN (message_tsv)
"""
)
# Step 3: Update existing rows in batches using Python
time.time()
# Get connection and count total rows
connection = op.get_bind()
total_count_result = connection.execute(
text("SELECT COUNT(*) FROM chat_message")
).scalar()
total_count = total_count_result if total_count_result is not None else 0
batch_size = 5000
batches = 0
# Calculate total batches needed
total_batches = (
(total_count + batch_size - 1) // batch_size if total_count > 0 else 0
# Also add a stored tsvector column for chat_session.description
op.execute(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
"""
)
# Process in batches - properly handling UUIDs by using OFFSET/LIMIT approach
for batch_num in range(total_batches):
offset = batch_num * batch_size
# Commit again before creating the second concurrent index
op.execute("COMMIT")
# Execute update for this batch using OFFSET/LIMIT which works with UUIDs
connection.execute(
text(
"""
UPDATE chat_message
SET message_tsv = to_tsvector('english', message)
WHERE id IN (
SELECT id FROM chat_message
WHERE message_tsv IS NULL
ORDER BY id
LIMIT :batch_size OFFSET :offset
)
"""
).bindparams(batch_size=batch_size, offset=offset)
)
# Commit each batch
connection.execute(text("COMMIT"))
# Start a new transaction
connection.execute(text("BEGIN"))
batches += 1
# Final check for any remaining NULL values
connection.execute(
text(
"""
UPDATE chat_message SET message_tsv = to_tsvector('english', message)
WHERE message_tsv IS NULL
"""
)
)
# Create GIN index concurrently
connection.execute(text("COMMIT"))
time.time()
connection.execute(
text(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message USING GIN (message_tsv)
"""
)
)
# First drop the trigger as it won't be needed anymore
connection.execute(
text(
"""
DROP TRIGGER IF EXISTS chat_message_tsv_trigger ON chat_message;
"""
)
)
connection.execute(
text(
"""
DROP FUNCTION IF EXISTS update_chat_message_tsv();
"""
)
)
# Add new generated column
time.time()
connection.execute(
text(
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv_gen tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
)
connection.execute(text("COMMIT"))
time.time()
connection.execute(
text(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv_gen
ON chat_message USING GIN (message_tsv_gen)
"""
)
)
# Drop old index and column
connection.execute(text("COMMIT"))
connection.execute(
text(
"""
DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;
"""
)
)
connection.execute(text("COMMIT"))
connection.execute(
text(
"""
ALTER TABLE chat_message DROP COLUMN message_tsv;
"""
)
)
# Rename new column to old name
connection.execute(
text(
"""
ALTER TABLE chat_message RENAME COLUMN message_tsv_gen TO message_tsv;
"""
)
)
# --- PART 2: chat_session table ---
# Step 1: Add nullable column (quick, minimal locking)
time.time()
connection.execute(
text(
"ALTER TABLE chat_session ADD COLUMN IF NOT EXISTS description_tsv tsvector"
)
)
# Step 2: Create function and trigger for new/updated rows - SPLIT INTO SEPARATE CALLS
connection.execute(
text(
"""
CREATE OR REPLACE FUNCTION update_chat_session_tsv()
RETURNS TRIGGER AS $$
BEGIN
NEW.description_tsv = to_tsvector('english', COALESCE(NEW.description, ''));
RETURN NEW;
END;
$$ LANGUAGE plpgsql
"""
)
)
# Create trigger in a separate execute call
connection.execute(
text(
"""
CREATE TRIGGER chat_session_tsv_trigger
BEFORE INSERT OR UPDATE ON chat_session
FOR EACH ROW EXECUTE FUNCTION update_chat_session_tsv()
"""
)
)
# Step 3: Update existing rows in batches using Python
time.time()
# Get the maximum ID to determine batch count
# Cast id to text for MAX function since it's a UUID
max_id_result = connection.execute(
text("SELECT COALESCE(MAX(id::text), '0') FROM chat_session")
).scalar()
max_id_result if max_id_result is not None else "0"
batch_size = 5000
batches = 0
# Get all IDs ordered to process in batches
rows = connection.execute(
text("SELECT id FROM chat_session ORDER BY id")
).fetchall()
total_rows = len(rows)
# Process in batches
for batch_num, batch_start in enumerate(range(0, total_rows, batch_size)):
batch_end = min(batch_start + batch_size, total_rows)
batch_ids = [row[0] for row in rows[batch_start:batch_end]]
if not batch_ids:
continue
# Use IN clause instead of BETWEEN for UUIDs
placeholders = ", ".join([f":id{i}" for i in range(len(batch_ids))])
params = {f"id{i}": id_val for i, id_val in enumerate(batch_ids)}
# Execute update for this batch
connection.execute(
text(
f"""
UPDATE chat_session
SET description_tsv = to_tsvector('english', COALESCE(description, ''))
WHERE id IN ({placeholders})
AND description_tsv IS NULL
"""
).bindparams(**params)
)
# Commit each batch
connection.execute(text("COMMIT"))
# Start a new transaction
connection.execute(text("BEGIN"))
batches += 1
# Final check for any remaining NULL values
connection.execute(
text(
"""
UPDATE chat_session SET description_tsv = to_tsvector('english', COALESCE(description, ''))
WHERE description_tsv IS NULL
"""
)
)
# Create GIN index concurrently
connection.execute(text("COMMIT"))
time.time()
connection.execute(
text(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session USING GIN (description_tsv)
"""
)
)
# After Final check for chat_session
# First drop the trigger as it won't be needed anymore
connection.execute(
text(
"""
DROP TRIGGER IF EXISTS chat_session_tsv_trigger ON chat_session;
"""
)
)
connection.execute(
text(
"""
DROP FUNCTION IF EXISTS update_chat_session_tsv();
"""
)
)
# Add new generated column
time.time()
connection.execute(
text(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv_gen tsvector
GENERATED ALWAYS AS (to_tsvector('english', COALESCE(description, ''))) STORED;
"""
)
)
# Create new index on generated column
connection.execute(text("COMMIT"))
time.time()
connection.execute(
text(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv_gen
ON chat_session USING GIN (description_tsv_gen)
"""
)
)
# Drop old index and column
connection.execute(text("COMMIT"))
connection.execute(
text(
"""
DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;
"""
)
)
connection.execute(text("COMMIT"))
connection.execute(
text(
"""
ALTER TABLE chat_session DROP COLUMN description_tsv;
"""
)
)
# Rename new column to old name
connection.execute(
text(
"""
ALTER TABLE chat_session RENAME COLUMN description_tsv_gen TO description_tsv;
"""
)
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session
USING GIN (description_tsv)
"""
)

View File

@@ -5,7 +5,10 @@ Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
import logging
from typing import cast
from alembic import op
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import text
@@ -15,21 +18,45 @@ down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
# Get database connection
"""Conflicts on lowercasing will result in the uppercased email getting a
unique integer suffix when converted to lowercase."""
connection = op.get_bind()
# Update all user emails to lowercase
connection.execute(
text(
"""
UPDATE "user"
SET email = LOWER(email)
WHERE email != LOWER(email)
"""
)
)
# Fetch all user emails that are not already lowercase
user_emails = connection.execute(
text('SELECT id, email FROM "user" WHERE email != LOWER(email)')
).fetchall()
for user_id, email in user_emails:
email = cast(str, email)
username, domain = email.rsplit("@", 1)
new_email = f"{username.lower()}@{domain.lower()}"
attempt = 1
while True:
try:
# Try updating the email
connection.execute(
text('UPDATE "user" SET email = :new_email WHERE id = :user_id'),
{"new_email": new_email, "user_id": user_id},
)
break # Success, exit loop
except IntegrityError:
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
# Email conflict occurred, append `_1`, `_2`, etc., to the username
logger.warning(
f"Conflict while lowercasing email: "
f"old_email={email} "
f"conflicting_email={new_email} "
f"next_email={next_email}"
)
new_email = next_email
attempt += 1
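As a worked example of the conflict handling above (hypothetical addresses): if the table contains both Alice@Example.com and alice@example.com, the first is lowercased to alice@example.com, hits the unique constraint, and is retried as alice_1@example.com (then alice_2@example.com on a further conflict).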
def downgrade() -> None:

View File

@@ -0,0 +1,117 @@
"""duplicated no-harm user file migration
Revision ID: 6a804aeb4830
Revises: 8e1ac4f39a9f
Create Date: 2025-04-01 07:26:10.539362
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import inspect
import datetime
# revision identifiers, used by Alembic.
revision = "6a804aeb4830"
down_revision = "8e1ac4f39a9f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Check if user_file table already exists
conn = op.get_bind()
inspector = inspect(conn)
if not inspector.has_table("user_file"):
# Create user_folder table without parent_id
op.create_table(
"user_folder",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("description", sa.String(length=255), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
sa.Column(
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create user_file table with folder_id instead of parent_folder_id
op.create_table(
"user_file",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column(
"folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
nullable=True,
),
sa.Column("link_url", sa.String(), nullable=True),
sa.Column("token_count", sa.Integer(), nullable=True),
sa.Column("file_type", sa.String(), nullable=True),
sa.Column("file_id", sa.String(length=255), nullable=False),
sa.Column("document_id", sa.String(length=255), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column(
"created_at",
sa.DateTime(),
default=datetime.datetime.utcnow,
),
sa.Column(
"cc_pair_id",
sa.Integer(),
sa.ForeignKey("connector_credential_pair.id"),
nullable=True,
unique=True,
),
)
# Create persona__user_file table
op.create_table(
"persona__user_file",
sa.Column(
"persona_id",
sa.Integer(),
sa.ForeignKey("persona.id"),
primary_key=True,
),
sa.Column(
"user_file_id",
sa.Integer(),
sa.ForeignKey("user_file.id"),
primary_key=True,
),
)
# Create persona__user_folder table
op.create_table(
"persona__user_folder",
sa.Column(
"persona_id",
sa.Integer(),
sa.ForeignKey("persona.id"),
primary_key=True,
),
sa.Column(
"user_folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
primary_key=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
)
# Update existing records to have is_user_file=False instead of NULL
op.execute(
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
)
def downgrade() -> None:
pass

View File

@@ -0,0 +1,50 @@
"""enable contextual retrieval
Revision ID: 8e1ac4f39a9f
Revises: 9aadf32dfeb4
Create Date: 2024-12-20 13:29:09.918661
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8e1ac4f39a9f"
down_revision = "9aadf32dfeb4"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"search_settings",
sa.Column(
"enable_contextual_rag",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"search_settings",
sa.Column(
"contextual_rag_llm_name",
sa.String(),
nullable=True,
),
)
op.add_column(
"search_settings",
sa.Column(
"contextual_rag_llm_provider",
sa.String(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("search_settings", "enable_contextual_rag")
op.drop_column("search_settings", "contextual_rag_llm_name")
op.drop_column("search_settings", "contextual_rag_llm_provider")

View File

@@ -0,0 +1,113 @@
"""add user files
Revision ID: 9aadf32dfeb4
Revises: 3781a5eb12cb
Create Date: 2025-01-26 16:08:21.551022
"""
import sqlalchemy as sa
import datetime
from alembic import op
# revision identifiers, used by Alembic.
revision = "9aadf32dfeb4"
down_revision = "3781a5eb12cb"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create user_folder table without parent_id
op.create_table(
"user_folder",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("description", sa.String(length=255), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
sa.Column(
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create user_file table with folder_id instead of parent_folder_id
op.create_table(
"user_file",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column(
"folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
nullable=True,
),
sa.Column("link_url", sa.String(), nullable=True),
sa.Column("token_count", sa.Integer(), nullable=True),
sa.Column("file_type", sa.String(), nullable=True),
sa.Column("file_id", sa.String(length=255), nullable=False),
sa.Column("document_id", sa.String(length=255), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column(
"created_at",
sa.DateTime(),
default=datetime.datetime.utcnow,
),
sa.Column(
"cc_pair_id",
sa.Integer(),
sa.ForeignKey("connector_credential_pair.id"),
nullable=True,
unique=True,
),
)
# Create persona__user_file table
op.create_table(
"persona__user_file",
sa.Column(
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
),
sa.Column(
"user_file_id",
sa.Integer(),
sa.ForeignKey("user_file.id"),
primary_key=True,
),
)
# Create persona__user_folder table
op.create_table(
"persona__user_folder",
sa.Column(
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
),
sa.Column(
"user_folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
primary_key=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
)
# Update existing records to have is_user_file=False instead of NULL
op.execute(
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
)
def downgrade() -> None:
# Drop the persona__user_folder table
op.drop_table("persona__user_folder")
# Drop the persona__user_file table
op.drop_table("persona__user_file")
# Drop the user_file table
op.drop_table("user_file")
# Drop the user_folder table
op.drop_table("user_folder")
op.drop_column("connector_credential_pair", "is_user_file")

View File

@@ -0,0 +1,36 @@
"""add_default_vision_provider_to_llm_provider
Revision ID: df46c75b714e
Revises: 3934b1bc7b62
Create Date: 2025-03-11 16:20:19.038945
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "df46c75b714e"
down_revision = "3934b1bc7b62"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"llm_provider",
sa.Column(
"is_default_vision_provider",
sa.Boolean(),
nullable=True,
server_default=sa.false(),
),
)
op.add_column(
"llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("llm_provider", "default_vision_model")
op.drop_column("llm_provider", "is_default_vision_provider")

View File

@@ -0,0 +1,50 @@
"""add prompt length limit
Revision ID: f71470ba9274
Revises: 6a804aeb4830
Create Date: 2025-04-01 15:07:14.977435
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f71470ba9274"
down_revision = "6a804aeb4830"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=8000),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=8000),
existing_nullable=False,
)
def downgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.String(length=8000),
type_=sa.TEXT(),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.String(length=8000),
type_=sa.TEXT(),
existing_nullable=False,
)

View File

@@ -0,0 +1,77 @@
"""updated constraints for ccpairs
Revision ID: f7505c5b0284
Revises: f71470ba9274
Create Date: 2025-04-01 17:50:42.504818
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f7505c5b0284"
down_revision = "f71470ba9274"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Drop the old foreign-key constraints
op.drop_constraint(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
op.drop_constraint(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
# 2) Re-add them with ondelete='CASCADE'
op.create_foreign_key(
"document_by_connector_credential_pair_connector_id_fkey",
source_table="document_by_connector_credential_pair",
referent_table="connector",
local_cols=["connector_id"],
remote_cols=["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"document_by_connector_credential_pair_credential_id_fkey",
source_table="document_by_connector_credential_pair",
referent_table="credential",
local_cols=["credential_id"],
remote_cols=["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Reverse the changes for rollback
op.drop_constraint(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
op.drop_constraint(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
# Recreate without CASCADE
op.create_foreign_key(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
"connector",
["connector_id"],
["id"],
)
op.create_foreign_key(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
"credential",
["credential_id"],
["id"],
)
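The practical effect of the CASCADE constraints above is that deleting a connector or credential row now also removes its matching document_by_connector_credential_pair rows instead of failing with a foreign-key violation, e.g. (hypothetical id):
DELETE FROM connector WHERE id = 42;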

View File

@@ -0,0 +1,33 @@
"""add new available tenant table
Revision ID: 3b45e0018bf1
Revises: ac842f85f932
Create Date: 2025-03-06 09:55:18.229910
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "3b45e0018bf1"
down_revision = "ac842f85f932"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create available_tenant table
op.create_table(
"available_tenant",
sa.Column("tenant_id", sa.String(), nullable=False),
sa.Column("alembic_version", sa.String(), nullable=False),
sa.Column("date_created", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("tenant_id"),
)
def downgrade() -> None:
# Drop available_tenant table
op.drop_table("available_tenant")

View File

@@ -0,0 +1,51 @@
"""new column user tenant mapping
Revision ID: ac842f85f932
Revises: 34e3630c7f32
Create Date: 2025-03-03 13:30:14.802874
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "ac842f85f932"
down_revision = "34e3630c7f32"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add active column with default value of True
op.add_column(
"user_tenant_mapping",
sa.Column(
"active",
sa.Boolean(),
nullable=False,
server_default="true",
),
schema="public",
)
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
# Create a unique index for active=true records
# This ensures a user can only be active in one tenant at a time
op.execute(
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
)
def downgrade() -> None:
# Drop the unique index for active=true records
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
op.create_unique_constraint(
"uq_email", "user_tenant_mapping", ["email"], schema="public"
)
# Remove the active column
op.drop_column("user_tenant_mapping", "active", schema="public")
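The partial unique index above still lets one email map to several tenants, but at most one of those mappings can have active = true; for example (hypothetical rows), alice@example.com may appear for tenant_a with active = true and for tenant_b with active = false, while a second active row for alice@example.com would be rejected.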

View File

@@ -93,12 +93,12 @@ def _get_access_for_documents(
)
# To avoid collisions of group names between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess(
user_emails=non_ee_access.user_emails,
user_groups=set(user_group_info.get(document_id, [])),
access_map[document_id] = DocumentAccess.build(
user_emails=list(non_ee_access.user_emails),
user_groups=user_group_info.get(document_id, []),
is_public=is_public_anywhere,
external_user_emails=ext_u_emails,
external_user_group_ids=ext_u_groups,
external_user_emails=list(ext_u_emails),
external_user_group_ids=list(ext_u_groups),
)
return access_map

View File

@@ -2,7 +2,6 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.chat.models import AllCitations
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.process_message import ChatPacketStream
@@ -32,8 +31,6 @@ def gather_stream_for_answer_api(
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
elif isinstance(packet, OnyxContexts):
response.contexts = packet
if answer:
response.answer = answer

View File

@@ -25,6 +25,10 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
#####
# Auto Permission Sync
#####
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
@@ -39,6 +43,7 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
@@ -72,6 +77,13 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
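Each of the sync frequencies above is read from its environment variable and interpreted in seconds, falling back to 5 minutes; for example (illustrative value):
SLACK_PERMISSION_DOC_SYNC_FREQUENCY=900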
# The posthog client does not accept empty API keys or hosts however it fails silently
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app

View File

@@ -27,6 +27,8 @@ def get_empty_chat_messages_entries__paginated(
first element is the most recent timestamp out of the sessions iterated
- this timestamp can be used to paginate forward in time
second element is a list of messages belonging to all the sessions iterated
Only messages of type USER are returned
"""
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=period[0],

View File

@@ -2,6 +2,7 @@
Rules defined here:
https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.html
"""
from collections.abc import Generator
from typing import Any
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
@@ -263,13 +264,11 @@ def _fetch_all_page_restrictions(
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
For all pages, if a page has restrictions, then use those restrictions.
Otherwise, use the space's restrictions.
"""
document_restrictions: list[DocExternalAccess] = []
for slim_doc in slim_docs:
if callback:
if callback.should_stop():
@@ -286,11 +285,9 @@ def _fetch_all_page_restrictions(
confluence_client=confluence_client,
perm_sync_data=slim_doc.perm_sync_data,
):
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=restrictions,
)
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=restrictions,
)
# If there are restrictions, then we don't need to use the space's restrictions
continue
@@ -324,11 +321,9 @@ def _fetch_all_page_restrictions(
continue
# If there are no restrictions, then use the space's restrictions
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
if (
not space_permissions.is_public
@@ -342,13 +337,12 @@ def _fetch_all_page_restrictions(
)
logger.debug("Finished fetching all page restrictions for space")
return document_restrictions
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exist in postgres, we create
@@ -387,7 +381,7 @@ def confluence_doc_sync(
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
return _fetch_all_page_restrictions(
yield from _fetch_all_page_restrictions(
confluence_client=confluence_connector.confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,

View File

@@ -1,3 +1,4 @@
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
@@ -34,7 +35,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exist in postgres, we create
@@ -48,7 +49,6 @@ def gmail_doc_sync(
cc_pair, gmail_connector, callback=callback
)
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
@@ -60,17 +60,14 @@ def gmail_doc_sync(
if slim_doc.perm_sync_data is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue
if user_email := slim_doc.perm_sync_data.get("user_email"):
ext_access = ExternalAccess(
external_user_emails=set([user_email]),
external_user_group_ids=set(),
is_public=False,
)
document_external_access.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=ext_access,
)
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=ext_access,
)
return document_external_access

View File

@@ -1,3 +1,4 @@
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -147,7 +148,7 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exist in postgres, we create
@@ -161,7 +162,6 @@ def gdrive_doc_sync(
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
document_external_accesses = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
@@ -174,10 +174,7 @@ def gdrive_doc_sync(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,
)
document_external_accesses.append(
DocExternalAccess(
external_access=ext_access,
doc_id=slim_doc.id,
)
yield DocExternalAccess(
external_access=ext_access,
doc_id=slim_doc.id,
)
return document_external_accesses

View File

@@ -55,7 +55,7 @@ def _post_query_chunk_censoring(
# if user is None, permissions are not enforced
return chunks
chunks_to_keep = []
final_chunk_dict: dict[str, InferenceChunk] = {}
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
sources_to_censor = _get_all_censoring_enabled_sources()
@@ -64,7 +64,7 @@ def _post_query_chunk_censoring(
if chunk.source_type in sources_to_censor:
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
else:
chunks_to_keep.append(chunk)
final_chunk_dict[chunk.unique_id] = chunk
# For each source, filter out the chunks using the permission
# check function for that source
@@ -79,6 +79,16 @@ def _post_query_chunk_censoring(
f" chunks for this source and continuing: {e}"
)
continue
chunks_to_keep.extend(censored_chunks)
return chunks_to_keep
for censored_chunk in censored_chunks:
final_chunk_dict[censored_chunk.unique_id] = censored_chunk
# IMPORTANT: make sure to retain the same ordering as the original `chunks` passed in
final_chunk_list: list[InferenceChunk] = []
for chunk in chunks:
# only if the chunk is in the final censored chunks, add it to the final list
# if it is missing, that means it was intentionally left out
if chunk.unique_id in final_chunk_dict:
final_chunk_list.append(final_chunk_dict[chunk.unique_id])
return final_chunk_list

View File

@@ -58,6 +58,7 @@ def _get_objects_access_for_user_email_from_salesforce(
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
)
if user_id is None:
logger.warning(f"User '{user_email}' not found in Salesforce")
return None
# This is the only query that is not cached in the function
@@ -65,6 +66,7 @@ def _get_objects_access_for_user_email_from_salesforce(
object_id_to_access = get_objects_access_for_user_id(
salesforce_client, user_id, list(object_ids)
)
logger.debug(f"Object ID to access: {object_id_to_access}")
return object_id_to_access

View File

@@ -42,11 +42,18 @@ def get_any_salesforce_client_for_doc_id(
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
query = f"SELECT Id FROM User WHERE Username = '{user_email}' AND IsActive = true"
result = sf_client.query(query)
if len(result["records"]) == 0:
return None
return result["records"][0]["Id"]
if len(result["records"]) > 0:
return result["records"][0]["Id"]
# Fall back to matching on the Email field
query = f"SELECT Id FROM User WHERE Email = '{user_email}' AND IsActive = true"
result = sf_client.query(query)
if len(result["records"]) > 0:
return result["records"][0]["Id"]
return None
# This contains only the user_ids that we have found in Salesforce.

View File

@@ -1,3 +1,5 @@
from collections.abc import Generator
from slack_sdk import WebClient
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
@@ -14,35 +16,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> dict[str, list[str]]:
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
if channel_id not in channel_doc_map:
channel_doc_map[channel_id] = []
channel_doc_map[channel_id].append(doc_metadata.id)
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
return channel_doc_map
def _fetch_workspace_permissions(
user_id_to_email_map: dict[str, str],
) -> ExternalAccess:
@@ -122,10 +95,37 @@ def _fetch_channel_permissions(
return channel_permissions
def _get_slack_document_access(
cc_pair: ConnectorCredentialPair,
channel_permissions: dict[str, ExternalAccess],
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
yield DocExternalAccess(
external_access=channel_permissions[channel_id],
doc_id=doc_metadata.id,
)
if callback:
if callback.should_stop():
raise RuntimeError("_get_slack_document_access: Stop signal detected")
callback.progress("_get_slack_document_access", 1)
def slack_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exist in postgres, we create
@@ -136,9 +136,12 @@ def slack_doc_sync(
token=cc_pair.credential.credential_json["slack_bot_token"]
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
cc_pair=cc_pair, callback=callback
)
if not user_id_to_email_map:
raise ValueError(
"No user id to email map found. Please check to make sure that "
"your Slack bot token has the `users:read.email` scope"
)
workspace_permissions = _fetch_workspace_permissions(
user_id_to_email_map=user_id_to_email_map,
)
@@ -148,18 +151,8 @@ def slack_doc_sync(
user_id_to_email_map=user_id_to_email_map,
)
document_external_accesses = []
for channel_id, ext_access in channel_permissions.items():
doc_ids = channel_doc_map.get(channel_id)
if not doc_ids:
# No documents found for the channel_id
continue
for doc_id in doc_ids:
document_external_accesses.append(
DocExternalAccess(
external_access=ext_access,
doc_id=doc_id,
)
)
return document_external_accesses
yield from _get_slack_document_access(
cc_pair=cc_pair,
channel_permissions=channel_permissions,
callback=callback,
)

View File

@@ -1,7 +1,10 @@
from collections.abc import Callable
from collections.abc import Generator
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
@@ -23,7 +26,7 @@ DocSyncFuncType = Callable[
ConnectorCredentialPair,
IndexingHeartbeatInterface | None,
],
list[DocExternalAccess],
Generator[DocExternalAccess, None, None],
]
GroupSyncFuncType = Callable[
@@ -65,13 +68,13 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: 5 * 60,
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 5 minutes by default
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
}

View File

@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.oauth.api import router as oauth_router
from ee.onyx.server.oauth.api import router as ee_oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -64,7 +64,15 @@ def get_application() -> FastAPI:
add_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -87,6 +95,16 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.OIDC:
# Ensure we request offline_access for refresh tokens
try:
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
if "offline_access" not in oidc_scopes:
oidc_scopes.append("offline_access")
except Exception as e:
logger.warning(f"Error configuring OIDC scopes: {e}")
# Fall back to default scopes if there's an error
oidc_scopes = BASE_SCOPES
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -94,8 +112,8 @@ def get_application() -> FastAPI:
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# BASE_SCOPES is the same as not setting this
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
# Use the configured scopes
base_scopes=oidc_scopes,
),
auth_backend,
USER_AUTH_SECRET,
@@ -128,7 +146,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, oauth_router)
include_router_with_global_prefix_prepended(application, ee_oauth_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(

View File

@@ -15,8 +15,8 @@ from sqlalchemy.orm import Session
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import _LOGO_FILENAME
from ee.onyx.server.enterprise_settings.store import _LOGOTYPE_FILENAME
from ee.onyx.server.enterprise_settings.store import get_logo_filename
from ee.onyx.server.enterprise_settings.store import get_logotype_filename
from ee.onyx.server.enterprise_settings.store import load_analytics_script
from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
@@ -28,7 +28,7 @@ from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.utils.logger import setup_logger
admin_router = APIRouter(prefix="/admin/enterprise-settings")
@@ -131,31 +131,49 @@ def put_logo(
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
def fetch_logo_or_logotype(is_logotype: bool, db_session: Session) -> Response:
def fetch_logo_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
filename = _LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME
file_io = file_store.read_file(filename, mode="b")
# NOTE: specifying "image/jpeg" here, but it still works for pngs
# TODO: do this properly
return Response(content=file_io.read(), media_type="image/jpeg")
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
except Exception:
raise HTTPException(
status_code=404,
detail=f"No {'logotype' if is_logotype else 'logo'} file found",
detail="No logo file found",
)
else:
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
def fetch_logotype_helper(db_session: Session) -> Response:
try:
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
except Exception:
raise HTTPException(
status_code=404,
detail="No logotype file found",
)
else:
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
@basic_router.get("/logotype")
def fetch_logotype(db_session: Session = Depends(get_session)) -> Response:
return fetch_logo_or_logotype(is_logotype=True, db_session=db_session)
return fetch_logotype_helper(db_session)
@basic_router.get("/logo")
def fetch_logo(
is_logotype: bool = False, db_session: Session = Depends(get_session)
) -> Response:
return fetch_logo_or_logotype(is_logotype=is_logotype, db_session=db_session)
if is_logotype:
return fetch_logotype_helper(db_session)
return fetch_logo_helper(db_session)
@admin_router.put("/custom-analytics-script")

View File

@@ -13,6 +13,7 @@ from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import KV_CUSTOM_ANALYTICS_SCRIPT_KEY
from onyx.configs.constants import KV_ENTERPRISE_SETTINGS_KEY
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.file_store.file_store import get_default_file_store
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -21,8 +22,18 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_LOGO_FILENAME = "__logo__"
_LOGOTYPE_FILENAME = "__logotype__"
def load_settings() -> EnterpriseSettings:
"""Loads settings data directly from DB. This should be used primarily
for checking what is actually in the DB, aka for editing and saving back settings.
Runtime settings actually used by the application should be checked with
load_runtime_settings as defaults may be applied at runtime.
"""
dynamic_config_store = get_kv_store()
try:
settings = EnterpriseSettings(
@@ -36,9 +47,24 @@ def load_settings() -> EnterpriseSettings:
def store_settings(settings: EnterpriseSettings) -> None:
"""Stores settings directly to the kv store / db."""
get_kv_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump())
def load_runtime_settings() -> EnterpriseSettings:
"""Loads settings from DB and applies any defaults or transformations for use
at runtime.
Should not be stored back to the DB.
"""
enterprise_settings = load_settings()
if not enterprise_settings.application_name:
enterprise_settings.application_name = ONYX_DEFAULT_APPLICATION_NAME
return enterprise_settings
_CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY")
@@ -60,10 +86,6 @@ def store_analytics_script(analytics_script_upload: AnalyticsScriptUpload) -> No
get_kv_store().store(KV_CUSTOM_ANALYTICS_SCRIPT_KEY, analytics_script_upload.script)
_LOGO_FILENAME = "__logo__"
_LOGOTYPE_FILENAME = "__logotype__"
def is_valid_file_type(filename: str) -> bool:
valid_extensions = (".png", ".jpg", ".jpeg")
return filename.endswith(valid_extensions)
@@ -116,3 +138,11 @@ def upload_logo(
file_type=file_type,
)
return True
def get_logo_filename() -> str:
return _LOGO_FILENAME
def get_logotype_filename() -> str:
return _LOGOTYPE_FILENAME

View File

@@ -44,7 +44,7 @@ async def _get_tenant_id_from_request(
Attempt to extract tenant_id from:
1) The API key header
2) The Redis-based token (stored in Cookie: fastapiusersauth)
3) Reset token cookie
3) The anonymous user cookie
Fallback: POSTGRES_DEFAULT_SCHEMA
"""
# Check for API key
@@ -52,41 +52,55 @@ async def _get_tenant_id_from_request(
if tenant_id is not None:
return tenant_id
# Check for anonymous user cookie
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
if anonymous_user_cookie:
try:
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
except Exception as e:
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
# Continue and attempt to authenticate
try:
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
if not token_data:
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
if token_data:
tenant_id_from_payload = token_data.get(
"tenant_id", POSTGRES_DEFAULT_SCHEMA
)
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
else None
)
# Since token_data.get() can return None, ensure we have a string
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
else POSTGRES_DEFAULT_SCHEMA
if tenant_id and not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
# Check for anonymous user cookie
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
if anonymous_user_cookie:
try:
anonymous_user_data = decode_anonymous_user_jwt_token(
anonymous_user_cookie
)
tenant_id = anonymous_user_data.get(
"tenant_id", POSTGRES_DEFAULT_SCHEMA
)
if not tenant_id or not is_valid_schema_name(tenant_id):
raise HTTPException(
status_code=400, detail="Invalid tenant ID format"
)
return tenant_id
except Exception as e:
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
# Continue and attempt to authenticate
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")

View File

@@ -80,6 +80,7 @@ class ConfluenceCloudOAuth:
"search:confluence%20"
# granular scope
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
"read:content-details:confluence%20" # for permission sync
"offline_access"
)

View File

@@ -1,26 +1,35 @@
import re
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import AgentAnswer
from ee.onyx.server.query_and_chat.models import AgentSubQuery
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AllCitations
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
@@ -46,25 +55,6 @@ logger = setup_logger()
router = APIRouter(prefix="/chat")
def _translate_doc_response_to_simple_doc(
doc_response: QADocsResponse,
) -> list[SimpleDoc]:
return [
SimpleDoc(
id=doc.document_id,
semantic_identifier=doc.semantic_identifier,
link=doc.link,
blurb=doc.blurb,
match_highlights=[
highlight for highlight in doc.match_highlights if highlight
],
source_type=doc.source_type,
metadata=doc.metadata,
)
for doc in doc_response.top_documents
]
def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
top_docs: list[SavedSearchDoc] | None,
@@ -89,14 +79,26 @@ def _convert_packet_stream_to_response(
final_context_docs: list[LlmDoc] = []
answer = ""
# accumulate stream data with these dicts
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.top_documents = packet.top_documents
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if id in agent_sub_questions:
agent_sub_questions[id].document_ids = [
saved_search_doc.document_id
for saved_search_doc in packet.top_documents
]
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
@@ -113,11 +115,104 @@ def _convert_packet_stream_to_response(
citation.citation_num: citation.document_id
for citation in packet.citations
}
# agentic packets
elif isinstance(packet, SubQuestionPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_sub_questions.get(id) is None:
agent_sub_questions[id] = AgentSubQuestion(
level=packet.level,
level_question_num=packet.level_question_num,
sub_question=packet.sub_question,
document_ids=[],
)
else:
agent_sub_questions[id].sub_question += packet.sub_question
elif isinstance(packet, AgentAnswerPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_answers.get(id) is None:
agent_answers[id] = AgentAnswer(
level=packet.level,
level_question_num=packet.level_question_num,
answer=packet.answer_piece,
answer_type=packet.answer_type,
)
else:
agent_answers[id].answer += packet.answer_piece
elif isinstance(packet, SubQueryPiece):
if packet.level is not None and packet.level_question_num is not None:
sub_query_id = (
packet.level,
packet.level_question_num,
packet.query_id,
)
if agent_sub_queries.get(sub_query_id) is None:
agent_sub_queries[sub_query_id] = AgentSubQuery(
level=packet.level,
level_question_num=packet.level_question_num,
sub_query=packet.sub_query,
query_id=packet.query_id,
)
else:
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
elif isinstance(packet, ExtendedToolResponse):
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
logger.warning(
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
)
elif isinstance(packet, RefinedAnswerImprovement):
response.agent_refined_answer_improvement = (
packet.refined_answer_improvement
)
else:
logger.warning(
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
)
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
# organize / sort agent metadata for output
if len(agent_sub_questions) > 0:
response.agent_sub_questions = cast(
dict[int, list[AgentSubQuestion]],
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
)
if len(agent_answers) > 0:
# return the agent_level_answer from the first level or the last one depending
# on agent_refined_answer_improvement
response.agent_answers = cast(
dict[int, list[AgentAnswer]],
SubQuestionIdentifier.make_dict_by_level(agent_answers),
)
if response.agent_answers:
selected_answer_level = (
0
if not response.agent_refined_answer_improvement
else len(response.agent_answers) - 1
)
level_answers = response.agent_answers[selected_answer_level]
for level_answer in level_answers:
if level_answer.answer_type != "agent_level_answer":
continue
answer = level_answer.answer
break
if len(agent_sub_queries) > 0:
# subqueries are often emitted with trailing whitespace ... clean it up here
# perhaps fix at the source?
for v in agent_sub_queries.values():
v.sub_query = v.sub_query.strip()
response.agent_sub_queries = (
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)
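The loop above folds streamed pieces into dicts keyed by (level, level_question_num) and appends each fragment to the entry it belongs to. A minimal, self-contained sketch of that accumulation pattern, using a stand-in dataclass rather than the real SubQuestionPiece/AgentSubQuestion models:

from dataclasses import dataclass

@dataclass
class Piece:  # stand-in for a streamed packet; not a real model in this repo
    level: int
    level_question_num: int
    text: str

def accumulate(pieces: list[Piece]) -> dict[tuple[int, int], str]:
    """Concatenate streamed fragments per (level, question_num), as the endpoint does."""
    acc: dict[tuple[int, int], str] = {}
    for p in pieces:
        key = (p.level, p.level_question_num)
        acc[key] = acc.get(key, "") + p.text
    return acc

stream = [Piece(0, 1, "What is "), Piece(0, 2, "Where is it "), Piece(0, 1, "Onyx?")]
assert accumulate(stream) == {(0, 1): "What is Onyx?", (0, 2): "Where is it "}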

View File

@@ -1,3 +1,5 @@
from collections import OrderedDict
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
@@ -6,9 +8,9 @@ from pydantic import model_validator
from ee.onyx.server.manage.models import StandardAnswer
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@@ -88,6 +90,64 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
sub_query: str
query_id: int
@staticmethod
def make_dict_by_level_and_question_index(
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
returns a dict of level to dict[question num to list of query_id's]
Ordering is asc for readability.
"""
# In this function, when we sort int | None, we deliberately push None to the end
# map entries to the level_question_dict
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
for k1, obj in original_dict.items():
level = k1[0]
question = k1[1]
if level not in level_question_dict:
level_question_dict[level] = {}
if question not in level_question_dict[level]:
level_question_dict[level][question] = []
level_question_dict[level][question].append(obj)
# sort each query_id list and question_index
for key1, obj1 in level_question_dict.items():
for key2, value2 in obj1.items():
# sort the query_id list of each question_index
level_question_dict[key1][key2] = sorted(
value2, key=lambda o: o.query_id
)
# sort the question_index dict of level
level_question_dict[key1] = OrderedDict(
sorted(level_question_dict[key1].items(), key=lambda x: (x[0] is None, x[0]))
)
# sort the top dict of levels
sorted_dict = OrderedDict(
sorted(level_question_dict.items(), key=lambda x: (x[0] is None, x[0]))
)
return sorted_dict
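A self-contained sketch of the grouping and sorting that make_dict_by_level_and_question_index performs, using a namedtuple stand-in instead of real AgentSubQuery instances; the inline grouping mirrors the helper above rather than calling it:

from collections import namedtuple

SQ = namedtuple("SQ", ["query_id", "sub_query"])  # stand-in for AgentSubQuery

original = {
    (0, 2, 1): SQ(1, "b"),
    (0, 1, 2): SQ(2, "a2"),
    (0, 1, 1): SQ(1, "a1"),
    (1, 1, 1): SQ(1, "c"),
}

# level -> question_num -> sub-queries sorted by query_id
grouped: dict[int, dict[int, list[SQ]]] = {}
for (level, question, _), sq in original.items():
    grouped.setdefault(level, {}).setdefault(question, []).append(sq)
for by_question in grouped.values():
    for question, items in by_question.items():
        by_question[question] = sorted(items, key=lambda o: o.query_id)

assert grouped == {
    0: {1: [SQ(1, "a1"), SQ(2, "a2")], 2: [SQ(1, "b")]},
    1: {1: [SQ(1, "c")]},
}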
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
@@ -103,10 +163,14 @@ class ChatBasicResponse(BaseModel):
cited_documents: dict[int, str] | None = None
# FOR BACKWARDS COMPATIBILITY
# TODO: deprecate both of these
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
agent_answers: dict[int, list[AgentAnswer]] | None = None
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
class OneShotQARequest(ChunkContext):
# Supports simpler APIs that don't deal with chat histories or message edits
@@ -153,4 +217,3 @@ class OneShotQAResponse(BaseModel):
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
contexts: OnyxContexts | None = None

View File

@@ -48,10 +48,15 @@ def fetch_and_process_chat_session_history(
feedback_type: QAFeedbackType | None,
limit: int | None = 500,
) -> list[ChatSessionSnapshot]:
# observed to be slow at a scale of 8192 sessions and 4 messages per session
# this is a little slow (5 seconds)
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=start, end=end, db_session=db_session, limit=limit
)
# this is VERY slow (80 seconds) due to create_chat_chain being called
# for each session. Needs optimizing.
chat_session_snapshots = [
snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
for chat_session in chat_sessions
@@ -246,6 +251,8 @@ def get_query_history_as_csv(
detail="Query history has been disabled by the administrator.",
)
# this call is very expensive and is timing out via endpoint
# TODO: optimize call and/or generate via background task
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
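One hedged sketch of the TODO above (generating the export in the background instead of on the request path), using FastAPI's BackgroundTasks; every name here is hypothetical and only illustrates the shape of such a change:

import uuid
from fastapi import APIRouter, BackgroundTasks

router = APIRouter()

def _build_history_csv(task_id: str) -> None:
    # hypothetically: call fetch_and_process_chat_session_history and write the
    # resulting CSV to storage keyed by task_id for later download
    ...

@router.post("/admin/query-history-csv-export")  # hypothetical route, not in the repo
async def start_csv_export(background_tasks: BackgroundTasks) -> dict[str, str]:
    task_id = str(uuid.uuid4())
    background_tasks.add_task(_build_history_csv, task_id)
    # the client would poll a status endpoint for the finished file
    return {"task_id": task_id}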

View File

@@ -38,6 +38,7 @@ router = APIRouter(prefix="/auth/saml")
async def upsert_saml_user(email: str) -> User:
logger.debug(f"Attempting to upsert SAML user with email: {email}")
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
) # type:ignore
@@ -48,9 +49,13 @@ async def upsert_saml_user(email: str) -> User:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
try:
return await user_manager.get_by_email(email)
user = await user_manager.get_by_email(email)
# If user has a non-authenticated role, treat as non-existent
if not user.role.is_web_login():
raise exceptions.UserNotExists()
return user
except exceptions.UserNotExists:
logger.notice("Creating user from SAML login")
logger.info("Creating user from SAML login")
user_count = await get_user_count()
role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC
@@ -59,11 +64,10 @@ async def upsert_saml_user(email: str) -> User:
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
user: User = await user_manager.create(
user = await user_manager.create(
UserCreate(
email=email,
password=hashed_pass,
is_verified=True,
role=role,
)
)

View File

@@ -0,0 +1,45 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response
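A hedged client-side sketch of calling the impersonation endpoint above; the base URL, any additional API prefix in a deployed instance, and the presence of a superuser session cookie are assumptions:

import httpx

BASE_URL = "http://localhost:8080"  # assumption, not a documented default

def impersonate(email: str, superuser_cookies: dict[str, str]) -> httpx.Cookies:
    resp = httpx.post(
        f"{BASE_URL}/tenants/impersonate",
        json={"email": email},
        cookies=superuser_cookies,
    )
    resp.raise_for_status()
    # on success the response carries the fastapiusersauth cookie for the impersonated user
    return resp.cookies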

View File

@@ -0,0 +1,98 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response
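The anonymous-user cookie set above is later read back by the request handling at the top of this diff, which pulls tenant_id out of the decoded token. A minimal illustration (not the project's implementation) of that round trip with PyJWT, under the assumption that the token carries a tenant_id claim signed with a configured secret:

import jwt  # PyJWT, assumed available for this illustration only

SECRET = "example-secret"  # assumption; the real key comes from configuration

def make_anonymous_token(tenant_id: str) -> str:
    return jwt.encode({"tenant_id": tenant_id}, SECRET, algorithm="HS256")

def read_tenant_id(token: str) -> str:
    return jwt.decode(token, SECRET, algorithms=["HS256"])["tenant_id"]

assert read_tenant_id(make_anonymous_token("tenant_abc")) == "tenant_abc"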

View File

@@ -1,269 +1,24 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
)
from ee.onyx.server.tenants.user_invitations_api import (
router as user_invitations_router,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Create a main router to include all sub-routers
# Note: We don't add a prefix here as each router already has the /tenants prefix
router = APIRouter()
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)
# Include all the individual routers
router.include_router(admin_router)
router.include_router(anonymous_users_router)
router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)

View File

@@ -0,0 +1,96 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -67,3 +67,30 @@ class ProductGatingResponse(BaseModel):
class SubscriptionSessionResponse(BaseModel):
sessionId: str
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int
creator_email: str
class TenantByDomainRequest(BaseModel):
email: str
class RequestInviteRequest(BaseModel):
tenant_id: str
class RequestInviteResponse(BaseModel):
success: bool
message: str
class PendingUserSnapshot(BaseModel):
email: str
class ApproveUserRequest(BaseModel):
email: str

View File

@@ -48,4 +48,5 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
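The decode step above is needed because redis-py returns raw bytes from smembers unless the client was created with decode_responses=True. A tiny standalone illustration of the same conversion (no Redis connection required):

raw_members: set[bytes] = {b"tenant_1", b"tenant_2"}  # what smembers returns by default
decoded = {member.decode("utf-8") for member in raw_members}
assert decoded == {"tenant_1", "tenant_2"}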

View File

@@ -4,6 +4,7 @@ import uuid
import aiohttp # Async HTTP client
import httpx
import requests
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
@@ -14,6 +15,7 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
@@ -26,11 +28,12 @@ from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import AvailableTenant
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
@@ -55,43 +58,77 @@ logger = logging.getLogger(__name__)
async def get_or_provision_tenant(
email: str, referral_source: str | None = None, request: Request | None = None
) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
"""
Get existing tenant ID for an email or create a new tenant if none exists.
This function should only be called after we have verified we want this user's tenant to exist.
It returns the tenant ID associated with the email, creating a new tenant if necessary.
"""
# Early return for non-multi-tenant mode
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
if referral_source and request:
await submit_to_hubspot(email, referral_source, request)
# First, check if the user already has a tenant
tenant_id: str | None = None
try:
tenant_id = get_tenant_id_for_email(email)
return tenant_id
except exceptions.UserNotExists:
# If the tenant does not exist and we are in multi-tenant mode, provision a new tenant
try:
# User doesn't exist, so we need to create a new tenant or assign an existing one
pass
try:
# Try to get a pre-provisioned tenant
tenant_id = await get_available_tenant()
if tenant_id:
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
else:
# If no pre-provisioned tenant is available, create a new one on-demand
tenant_id = await create_tenant(email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
if not tenant_id:
# Notify control plane if we have created / assigned a new tenant
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
return tenant_id
except Exception as e:
# If we've encountered an error, log and raise an exception
error_msg = "Failed to provision tenant"
logger.error(error_msg, exc_info=e)
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
status_code=500,
detail="Failed to provision tenant. Please try again later.",
)
return tenant_id
async def create_tenant(email: str, referral_source: str | None = None) -> str:
"""
Create a new tenant on-demand when no pre-provisioned tenants are available.
This is the fallback method when we can't use a pre-provisioned tenant.
"""
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
logger.info(f"Creating new tenant {tenant_id} for user {email}")
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
logger.exception(f"Tenant provisioning failed: {str(e)}")
# Attempt to rollback the tenant provisioning
try:
await rollback_tenant_provisioning(tenant_id)
except Exception:
logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
return tenant_id
@@ -105,54 +142,25 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
)
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
token = None
try:
# Create the schema for the tenant
if not create_schema_if_not_exists(tenant_id):
logger.debug(f"Created schema for tenant {tenant_id}")
else:
logger.debug(f"Schema already exists for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Set up the tenant with all necessary configurations
await setup_tenant(tenant_id)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
# Assign the tenant to the user
await assign_tenant_to_user(tenant_id, email)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
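The token returned by CURRENT_TENANT_ID_CONTEXTVAR.set() and the reset in the finally block follow the standard contextvars pattern for temporarily scoping a value. A minimal self-contained sketch (the variable name and the "public" default are illustrative only):

from contextvars import ContextVar

TENANT_ID: ContextVar[str] = ContextVar("tenant_id", default="public")

def with_tenant(tenant_id: str) -> str:
    token = TENANT_ID.set(tenant_id)
    try:
        # everything called in here sees the overridden tenant id
        return TENANT_ID.get()
    finally:
        # always restore the previous value, even if the body raises
        TENANT_ID.reset(token)

assert with_tenant("tenant_x") == "tenant_x"
assert TENANT_ID.get() == "public"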
async def notify_control_plane(
@@ -183,20 +191,74 @@ async def notify_control_plane(
async def rollback_tenant_provisioning(tenant_id: str) -> None:
# Logic to rollback tenant provisioning on data plane
"""
Logic to rollback tenant provisioning on data plane.
Handles each step independently to ensure maximum cleanup even if some steps fail.
"""
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
# Track if any part of the rollback fails
rollback_errors = []
# 1. Try to drop the tenant's schema
try:
drop_schema(tenant_id)
logger.info(f"Successfully dropped schema for tenant {tenant_id}")
except Exception as e:
logger.error(f"Failed to rollback tenant provisioning: {e}")
error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# 2. Try to remove tenant mapping
try:
with get_session_with_shared_schema() as db_session:
db_session.begin()
try:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
logger.info(
f"Successfully removed user mappings for tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
raise e
except Exception as e:
error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# 3. If this tenant was in the available tenants table, remove it
try:
with get_session_with_shared_schema() as db_session:
db_session.begin()
try:
available_tenant = (
db_session.query(AvailableTenant)
.filter(AvailableTenant.tenant_id == tenant_id)
.first()
)
if available_tenant:
db_session.delete(available_tenant)
db_session.commit()
logger.info(
f"Removed tenant {tenant_id} from available tenants table"
)
except Exception as e:
db_session.rollback()
raise e
except Exception as e:
error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# Log summary of rollback operation
if rollback_errors:
logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors")
else:
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
def configure_default_api_keys(db_session: Session) -> None:
@@ -209,6 +271,7 @@ def configure_default_api_keys(db_session: Session) -> None:
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
api_key_changed=True,
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
@@ -221,7 +284,7 @@ def configure_default_api_keys(db_session: Session) -> None:
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
openai_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
@@ -229,9 +292,10 @@ def configure_default_api_keys(db_session: Session) -> None:
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
api_key_changed=True,
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
full_provider = upsert_llm_provider(openai_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
@@ -349,3 +413,154 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)
def get_tenant_by_domain_from_control_plane(
domain: str,
tenant_id: str,
) -> TenantByDomainResponse | None:
"""
Fetches tenant information from the control plane based on the email domain.
Args:
domain: The email domain to search for (e.g., "example.com")
Returns:
A TenantByDomainResponse containing tenant information if found, None otherwise
"""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
try:
response = requests.get(
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
headers=headers,
json={"domain": domain, "tenant_id": tenant_id},
)
if response.status_code != 200:
logger.error(f"Control plane tenant lookup failed: {response.text}")
return None
response_data = response.json()
if not response_data:
return None
return TenantByDomainResponse(
tenant_id=response_data.get("tenant_id"),
number_of_users=response_data.get("number_of_users"),
creator_email=response_data.get("creator_email"),
)
except Exception as e:
logger.error(f"Error fetching tenant by domain: {str(e)}")
return None
async def get_available_tenant() -> str | None:
"""
Get an available pre-provisioned tenant from the NewAvailableTenant table.
Returns the tenant_id if one is available, None otherwise.
Uses row-level locking to prevent race conditions when multiple processes
try to get an available tenant simultaneously.
"""
if not MULTI_TENANT:
return None
with get_session_with_shared_schema() as db_session:
try:
db_session.begin()
# Get the oldest available tenant with FOR UPDATE lock to prevent race conditions
available_tenant = (
db_session.query(AvailableTenant)
.order_by(AvailableTenant.date_created)
.with_for_update(skip_locked=True) # Skip locked rows to avoid blocking
.first()
)
if available_tenant:
tenant_id = available_tenant.tenant_id
# Remove the tenant from the available tenants table
db_session.delete(available_tenant)
db_session.commit()
logger.info(f"Using pre-provisioned tenant {tenant_id}")
return tenant_id
else:
db_session.rollback()
return None
except Exception:
logger.exception("Error getting available tenant")
db_session.rollback()
return None
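with_for_update(skip_locked=True) above makes concurrent workers skip rows that another transaction has already locked, so each one claims a different pre-provisioned tenant instead of blocking or double-claiming. Roughly the SQL shape it corresponds to, wrapped in a SQLAlchemy text() for illustration; the table and column names are assumptions based on the AvailableTenant model attributes:

from sqlalchemy import text

claim_stmt = text(
    """
    SELECT tenant_id
    FROM available_tenant
    ORDER BY date_created
    LIMIT 1
    FOR UPDATE SKIP LOCKED
    """
)
# Executed inside a transaction, each concurrent caller locks (and then deletes)
# a different row; a caller that finds no unlocked rows simply gets no result.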
async def setup_tenant(tenant_id: str) -> None:
"""
Set up a tenant with all necessary configurations.
This is a centralized function that handles all tenant setup logic.
"""
token = None
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Run Alembic migrations in a way that isolates it from the current event loop
# Offload the synchronous migration to a worker thread
loop = asyncio.get_event_loop()
# Use run_in_executor which properly isolates the thread execution
await loop.run_in_executor(None, lambda: run_alembic_migrations(tenant_id))
# Configure the tenant with default settings
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
# Configure default API keys
configure_default_api_keys(db_session)
# Set up Onyx with appropriate settings
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
except Exception as e:
logger.exception(f"Failed to set up tenant {tenant_id}")
raise e
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def assign_tenant_to_user(
tenant_id: str, email: str, referral_source: str | None = None
) -> None:
"""
Assign a tenant to a user and perform necessary operations.
Uses transaction handling to ensure atomicity and includes retry logic
for control plane notifications.
"""
# First, add the user to the tenant in a transaction
try:
add_users_to_tenant([email], tenant_id)
# Create milestone record in the same transaction context as the tenant assignment
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception:
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
raise Exception("Failed to assign tenant to user")

View File

@@ -74,3 +74,21 @@ def drop_schema(tenant_id: str) -> None:
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
{"schema_name": tenant_id},
)
def get_current_alembic_version(tenant_id: str) -> str:
"""Get the current Alembic version for a tenant."""
from alembic.runtime.migration import MigrationContext
from sqlalchemy import text
engine = get_sqlalchemy_engine()
# Set the search path to the tenant's schema
with engine.connect() as connection:
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
# Get the current version from the alembic_version table
context = MigrationContext.configure(connection)
current_rev = context.get_current_revision()
return current_rev or "head"

View File

@@ -0,0 +1,67 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/leave-team")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -0,0 +1,39 @@
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
"gmail",
"outlook",
"yahoo",
"hotmail",
"icloud",
"msn",
"hotmail",
"hotmail.co.uk",
]
@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
user: User | None = Depends(current_user),
) -> TenantByDomainResponse | None:
if not user:
return None
domain = user.email.split("@")[1]
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
return None
tenant_id = get_current_tenant_id()
return get_tenant_by_domain_from_control_plane(domain, tenant_id)

View File

@@ -0,0 +1,90 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
from ee.onyx.server.tenants.models import RequestInviteRequest
from ee.onyx.server.tenants.user_mapping import accept_user_invite
from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/users/invite/request")
async def request_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_admin_user),
) -> None:
if user is None:
raise HTTPException(status_code=401, detail="User not authenticated")
try:
invite_self_to_tenant(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/users/pending")
def list_pending_users(
_: User | None = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
pending_emails = get_pending_users()
return [PendingUserSnapshot(email=email) for email in pending_emails]
@router.post("/users/invite/approve")
async def approve_user(
approve_user_request: ApproveUserRequest,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
approve_user_invite(approve_user_request.email, tenant_id)
@router.post("/users/invite/accept")
async def accept_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Accept an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
accept_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to accept invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to accept invitation")
@router.post("/users/invite/deny")
async def deny_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Deny an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
deny_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to deny invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to deny invitation")

View File

@@ -1,27 +1,56 @@
import logging
from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.setup import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = logging.getLogger(__name__)
logger = setup_logger()
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
try:
with get_session_with_shared_schema() as db_session:
# First try to get an active tenant
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
tenant_id = mapping.tenant_id if mapping else None
# If no active tenant found, try to get the first inactive one
if tenant_id is None:
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
if mapping:
# Mark this mapping as active
mapping.active = True
db_session.commit()
tenant_id = mapping.tenant_id
except Exception as e:
logger.exception(f"Error getting tenant id for email {email}: {e}")
raise exceptions.UserNotExists()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
@@ -38,13 +67,56 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
"""
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
# Start a transaction
db_session.begin()
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
# Check if the user already has a mapping to this tenant
existing_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
)
.with_for_update()
.first()
)
# If user already has an active mapping, add this one as inactive
if not existing_mapping:
# Check if the user already has an active mapping to any tenant
has_active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
db_session.add(
UserTenantMapping(
email=email,
tenant_id=tenant_id,
active=False if has_active_mapping else True,
)
)
# Commit the transaction
db_session.commit()
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.commit()
db_session.rollback()
raise
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
@@ -76,3 +148,187 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_pending_users()
if email in pending_users:
return
write_pending_users(pending_users + [email])
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def approve_user_invite(email: str, tenant_id: str) -> None:
"""
Approve a user invite to a tenant.
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete all existing records for this email
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email
).delete()
# Create a new mapping entry for the user in this tenant
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
db_session.add(new_mapping)
db_session.commit()
# Also remove the user from pending users list
# Remove from pending users
pending_users = get_pending_users()
if email in pending_users:
pending_users.remove(email)
write_pending_users(pending_users)
# Add to invited users
invited_users = get_invited_users()
if email not in invited_users:
invited_users.append(email)
write_invited_users(invited_users)
def accept_user_invite(email: str, tenant_id: str) -> None:
"""
Accept an invitation to join a tenant.
This activates the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
try:
# First check if there's an active mapping for this user and tenant
active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
# If an active mapping exists, delete it
if active_mapping:
db_session.delete(active_mapping)
logger.info(
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
)
# Find the inactive mapping for this user and tenant
mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if mapping:
# Set all other mappings for this user to inactive
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
).update({"active": False})
# Activate this mapping
mapping.active = True
db_session.commit()
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
logger.exception(
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
)
raise
def deny_user_invite(email: str, tenant_id: str) -> None:
"""
Deny an invitation to join a tenant.
This removes the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete the mapping for this user and tenant
result = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.delete()
)
db_session.commit()
if result:
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_invited_users()
if email in pending_users:
pending_users.remove(email)
write_invited_users(pending_users)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def get_tenant_count(tenant_id: str) -> int:
"""
Get the number of active users for this tenant
"""
with get_session_with_shared_schema() as db_session:
# Count the number of active users for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return user_count
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
"""
Get the first tenant invitation for this user
"""
with get_session_with_shared_schema() as db_session:
# Get the first tenant invitation for this user
invitation = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if invitation:
# Get the user count for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == invitation.tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return TenantSnapshot(
tenant_id=invitation.tenant_id, number_of_users=user_count
)
return None
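Taken together, the functions above form an invitation lifecycle: a user requests to join (pending), an admin approves, or the user accepts/denies an existing invite. A hedged usage sketch is below; it assumes the surrounding module and a configured database, and the emails/tenant IDs are placeholders.

# Hypothetical sequencing of the invite flow defined above (values are placeholders).
def admin_approval_path(email: str, tenant_id: str) -> None:
    # The user asked to join; an admin approves, which creates the active mapping
    # and moves the email from the pending list to the invited list.
    invite_self_to_tenant(email, tenant_id)
    approve_user_invite(email, tenant_id)


def user_acceptance_path(email: str, tenant_id: str) -> None:
    # An admin invited the user earlier (an inactive mapping exists); the user
    # either accepts, or declines with deny_user_invite(email, tenant_id).
    accept_user_invite(email, tenant_id)
    print(get_tenant_count(tenant_id))      # active users in the tenant
    print(get_tenant_invitation(email))     # remaining pending invitations, if any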

BIN
backend/hello-vmlinux.bin Normal file

Binary file not shown.

View File

@@ -3,6 +3,7 @@ from shared_configs.enums import EmbedTextType
MODEL_WARM_UP_STRING = "hi " * 512
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"

View File

@@ -1,11 +1,14 @@
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from setfit import SetFitModel # type: ignore[import]
from transformers import AutoTokenizer # type: ignore
from transformers import BatchEncoding # type: ignore
from transformers import PreTrainedTokenizer # type: ignore
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.onyx_torch_model import ConnectorClassifier
from model_server.onyx_torch_model import HybridClassifier
@@ -13,11 +16,22 @@ from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE,
)
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import INFORMATION_CONTENT_MODEL_TAG
from shared_configs.configs import INFORMATION_CONTENT_MODEL_VERSION
from shared_configs.configs import INTENT_MODEL_TAG
from shared_configs.configs import INTENT_MODEL_VERSION
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import ContentClassificationPrediction
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
@@ -31,6 +45,10 @@ _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # specific to the model version!
def get_connector_classifier_tokenizer() -> AutoTokenizer:
global _CONNECTOR_CLASSIFIER_TOKENIZER
@@ -85,7 +103,7 @@ def get_intent_model_tokenizer() -> AutoTokenizer:
def get_local_intent_model(
model_name_or_path: str = INTENT_MODEL_VERSION,
tag: str = INTENT_MODEL_TAG,
tag: str | None = INTENT_MODEL_TAG,
) -> HybridClassifier:
global _INTENT_MODEL
if _INTENT_MODEL is None:
@@ -102,7 +120,9 @@ def get_local_intent_model(
try:
# Attempt to download the model snapshot
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
except Exception as e:
logger.error(
@@ -112,6 +132,44 @@ def get_local_intent_model(
return _INTENT_MODEL
def get_local_information_content_model(
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
) -> SetFitModel:
global _INFORMATION_CONTENT_MODEL
if _INFORMATION_CONTENT_MODEL is None:
try:
# Load from the local HF cache if the model snapshot is already available
logger.notice(
f"Loading content information model from local cache: {model_name_or_path}"
)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
logger.notice(
f"Loaded content information model from local cache: {local_path}"
)
except Exception as e:
logger.warning(f"Failed to load content information model directly: {e}")
try:
# Attempt to download the model snapshot
logger.notice(
f"Downloading content information model snapshot for {model_name_or_path}"
)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
except Exception as e:
logger.error(
f"Failed to load content information model even after attempted snapshot download: {e}"
)
raise
return _INFORMATION_CONTENT_MODEL
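The loader above follows a "local cache first, then network" pattern built on huggingface_hub.snapshot_download. A stripped-down sketch of that pattern, independent of the SetFit specifics, is shown below; the repo ID is a placeholder.

# Generic cache-first download pattern used by the loaders above.
# "some-org/some-model" would be a placeholder repo ID, not a real dependency.
from huggingface_hub import snapshot_download


def resolve_model_path(repo_id: str, revision: str | None = None) -> str:
    try:
        # Only consult the local HF cache; raises if the snapshot is missing.
        return snapshot_download(repo_id=repo_id, revision=revision, local_files_only=True)
    except Exception:
        # Fall back to downloading the snapshot from the Hub.
        return snapshot_download(repo_id=repo_id, revision=revision, local_files_only=False)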
def tokenize_connector_classification_query(
connectors: list[str],
query: str,
@@ -195,6 +253,13 @@ def warm_up_intent_model() -> None:
)
def warm_up_information_content_model() -> None:
logger.notice("Warming up Content Model") # TODO: add version if needed
information_content_model = get_local_information_content_model()
information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
@simple_log_function_time()
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
intent_model = get_local_intent_model()
@@ -218,6 +283,117 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
return intent_probabilities.tolist(), token_positive_probs
@simple_log_function_time()
def run_content_classification_inference(
text_inputs: list[str],
) -> list[ContentClassificationPrediction]:
"""
Assign a score to each input segment. The model returned by get_local_information_content_model()
produces a raw 'model score' based on its training, which is then mapped onto a 0.0-1.0 scale.
Outside of the model server, that scaled score is converted into the actual
boost factor.
"""
def _prob_to_score(prob: float) -> float:
"""
Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
"""
_MIN_BASE_SCORE = 0.25
_MAX_BASE_SCORE = 0.75
if prob < _MIN_BASE_SCORE:
raw_score = 0.0
elif prob < _MAX_BASE_SCORE:
raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
else:
raw_score = 1.0
return (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
+ (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
- INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
)
* raw_score
)
_BATCH_SIZE = 32
content_model = get_local_information_content_model()
# Process inputs in batches
all_output_classes: list[int] = []
all_base_output_probabilities: list[float] = []
for i in range(0, len(text_inputs), _BATCH_SIZE):
batch = text_inputs[i : i + _BATCH_SIZE]
batch_with_prefix = []
batch_indices = []
# Pre-allocate results for this batch
batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
# Pre-process batch to handle long input exceptions
for j, text in enumerate(batch):
if len(text) == 0:
# if no input, treat as non-informative from the model's perspective
batch_output_classes[j] = np.array(0)
batch_probabilities[j] = np.array(0.0)
logger.warning("Input for Content Information Model is empty")
elif (
len(text.split())
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
):
# if input is short, use the model
batch_with_prefix.append(
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
)
batch_indices.append(j)
else:
# if longer than cutoff, treat as informative (stay with default), but issue warning
logger.warning("Input for Content Information Model too long")
if batch_with_prefix: # Only run model if we have valid inputs
# Get predictions for the batch
model_output_classes = content_model(batch_with_prefix)
model_output_probabilities = content_model.predict_proba(batch_with_prefix)
# Place results in the correct positions
for idx, batch_idx in enumerate(batch_indices):
batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
batch_probabilities[batch_idx] = model_output_probabilities[idx][
1
].numpy() # x[1] is prob of the positive class
all_output_classes.extend([int(x) for x in batch_output_classes])
all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
logits = [
np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
for p in all_base_output_probabilities
]
scaled_logits = [
logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
for logit in logits
]
output_probabilities_with_temp = [
np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
for scaled_logit in scaled_logits
]
prediction_scores = [
_prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
]
content_classification_predictions = [
ContentClassificationPrediction(
predicted_label=predicted_label, content_boost_factor=output_score
)
for predicted_label, output_score in zip(all_output_classes, prediction_scores)
]
return content_classification_predictions
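The scoring above chains several transforms: positive-class probability, then logit, then temperature scaling, then sigmoid, then the clamp/rescale in _prob_to_score. A small worked example of that chain is sketched below; the constants are illustrative stand-ins for the configured INDEXING_INFORMATION_CONTENT_CLASSIFICATION_* values, not the real settings.

# Worked example of the probability -> boost-score chain used above (illustrative constants).
import math

_MIN_SCORE, _MAX_SCORE = 0.7, 1.4   # assumed output range
_TEMPERATURE = 4.0                  # assumed temperature
_MIN_BASE, _MAX_BASE = 0.25, 0.75   # same base cutoffs as _prob_to_score


def boost_from_probability(p: float) -> float:
    # probability -> logit, guarding the endpoints like the code above
    logit = math.log(p / (1 - p)) if 0.0 < p < 1.0 else (100 if p == 1.0 else -100)
    # temperature scaling flattens over-confident predictions
    p_temp = 1 / (1 + math.exp(-logit / _TEMPERATURE))
    # clamp between the base cutoffs, then rescale into the score range
    if p_temp < _MIN_BASE:
        raw = 0.0
    elif p_temp < _MAX_BASE:
        raw = (p_temp - _MIN_BASE) / (_MAX_BASE - _MIN_BASE)
    else:
        raw = 1.0
    return _MIN_SCORE + (_MAX_SCORE - _MIN_SCORE) * raw


if __name__ == "__main__":
    for p in (0.1, 0.5, 0.9):
        print(p, round(boost_from_probability(p), 3))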
def map_keywords(
input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
) -> list[str]:
@@ -362,3 +538,10 @@ async def process_analysis_request(
is_keyword, keywords = run_analysis(intent_request)
return IntentResponse(is_keyword=is_keyword, keywords=keywords)
@router.post("/content-classification")
async def process_content_classification_request(
content_classification_requests: list[str],
) -> list[ContentClassificationPrediction]:
return run_content_classification_inference(content_classification_requests)
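The new route accepts a JSON array of strings as the request body and returns one prediction per input. A hedged client-side sketch follows; the host, port, and route prefix are assumptions about how the model server is exposed in a given deployment.

# Assumed model-server base URL and prefix; adjust to your deployment.
import httpx

MODEL_SERVER_URL = "http://localhost:9000"


def classify_contents(texts: list[str]) -> list[dict]:
    # FastAPI parses a bare list parameter from the JSON request body.
    resp = httpx.post(
        f"{MODEL_SERVER_URL}/custom/content-classification",
        json=texts,
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()  # [{"predicted_label": ..., "content_boost_factor": ...}, ...]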

View File

@@ -62,6 +62,60 @@ _OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding calling
_COHERE_MAX_INPUT_LEN = 96
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
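A brief, illustrative composition of the helpers above: classify an arbitrary provider exception, then surface it as the typed AuthenticationError so the API layer can map it to a 401. The real call sites are the except blocks in CloudEmbedding.embed below; the provider name here is a placeholder.

# Illustrative only; not the real call site.
sample_error = RuntimeError("401 Unauthorized: invalid api key")

if is_authentication_error(sample_error):
    # Surfaced as the typed error so the endpoint can return a 401.
    typed = AuthenticationError(provider="ExampleProvider")
    print(typed)  # ExampleProvider authentication failed: API key is invalid or expired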
class CloudEmbedding:
def __init__(
@@ -92,31 +146,17 @@ class CloudEmbedding:
)
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
@@ -155,7 +195,6 @@ class CloudEmbedding:
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
@@ -239,22 +278,51 @@ class CloudEmbedding:
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e, str(self.provider), model_name or deployment_name, self.provider
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
@staticmethod
def create(
@@ -569,6 +637,13 @@ async def process_embed_request(
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except AuthenticationError as e:
# Handle authentication errors consistently
logger.error(f"Authentication error: {e.provider}")
raise HTTPException(
status_code=401,
detail=f"Authentication failed: {e.message}",
)
except RateLimitError as e:
raise HTTPException(
status_code=429,
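With the handler above, authentication failures surface to callers as HTTP 401 while rate limits remain 429, so a client can distinguish a misconfigured key (do not retry) from transient throttling (back off and retry). A hedged client-side sketch, with an assumed endpoint URL and payload shape:

# Hedged sketch of a caller reacting to the 401/429 split above; URL and payload are assumptions.
import time
import httpx


def embed_with_retry(payload: dict, url: str, max_attempts: int = 3) -> dict:
    for attempt in range(max_attempts):
        resp = httpx.post(url, json=payload, timeout=60)
        if resp.status_code == 401:
            # Bad or expired API key: retrying will not help.
            raise RuntimeError(f"Embedding auth failed: {resp.json().get('detail')}")
        if resp.status_code == 429 and attempt < max_attempts - 1:
            time.sleep(2 ** attempt)  # simple exponential backoff
            continue
        resp.raise_for_status()
        return resp.json()
    raise RuntimeError("unreachable")  # loop always returns or raises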

View File

@@ -13,6 +13,7 @@ from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging # type:ignore
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_information_content_model
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
@@ -64,19 +65,31 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
app.state.gpu_type = gpu_type
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
try:
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
except Exception as e:
logger.warning(
f"Error moving contents of temp_huggingface to huggingface cache: {e}. "
"This is not a critical error and the model server will continue to run."
)
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
if not INDEXING_ONLY:
logger.notice(
"The intent model should run on the model server. The information content model should not run here."
)
warm_up_intent_model()
else:
logger.notice("This model server should only run document indexing.")
logger.notice(
"The content information model should run on the indexing model server. The intent model should not run here."
)
warm_up_information_content_model()
yield

View File

@@ -18,7 +18,7 @@ def _get_access_for_document(
document_id=document_id,
)
doc_access = DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],
@@ -26,6 +26,8 @@ def _get_access_for_document(
is_public=info[2] if info else False,
)
return doc_access
def get_access_for_document(
document_id: str,
@@ -38,12 +40,12 @@ def get_access_for_document(
def get_null_document_access() -> DocumentAccess:
return DocumentAccess.build(
user_emails=[],
user_groups=[],
is_public=False,
external_user_emails=[],
external_user_group_ids=[],
)
@@ -55,19 +57,18 @@ def _get_access_for_documents(
db_session=db_session,
document_ids=document_ids,
)
doc_access = {}
for document_id, user_emails, is_public in document_access_info:
doc_access[document_id] = DocumentAccess.build(
user_emails=[email for email in user_emails if email],
# MIT version will wipe all groups and external groups on update
user_groups=[],
is_public=is_public,
external_user_emails=[],
external_user_group_ids=[],
)
# Sometimes the document has not been indexed by the indexing job yet, in those cases
# the document does not exist and so we use least permissive. Specifically the EE version
# checks the MIT version permissions and creates a superset. This ensures that this flow
# does not fail even if the Document has not yet been indexed.

View File

@@ -20,7 +20,7 @@ class ExternalAccess:
class DocExternalAccess:
"""
This is just a class to wrap the external access and the document ID
together. It's used for syncing document permissions to Redis.
together. It's used for syncing document permissions to Vespa.
"""
external_access: ExternalAccess
@@ -56,34 +56,46 @@ class DocExternalAccess:
)
@dataclass(frozen=True)
@dataclass(frozen=True, init=False)
class DocumentAccess(ExternalAccess):
# User emails for Onyx users, None indicates admin
user_emails: set[str | None]
# Names of user groups associated with this document
user_groups: set[str]
def to_acl(self) -> set[str]:
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
external_user_emails: set[str]
external_user_group_ids: set[str]
is_public: bool
def __init__(self) -> None:
raise TypeError(
"Use `DocumentAccess.build(...)` instead of creating an instance directly."
)
def to_acl(self) -> set[str]:
# The ACL entries emitted by this function are prefixed by type;
# to get the raw values, access the member variables directly.
acl_set: set[str] = set()
for user_email in self.user_emails:
if user_email:
acl_set.add(prefix_user_email(user_email))
for group_name in self.user_groups:
acl_set.add(prefix_user_group(group_name))
for external_user_email in self.external_user_emails:
acl_set.add(prefix_user_email(external_user_email))
for external_group_id in self.external_user_group_ids:
acl_set.add(prefix_external_group(external_group_id))
if self.is_public:
acl_set.add(PUBLIC_DOC_PAT)
return acl_set
@classmethod
def build(
cls,
@@ -93,29 +105,32 @@ class DocumentAccess(ExternalAccess):
external_user_group_ids: list[str],
is_public: bool,
) -> "DocumentAccess":
return cls(
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_groups=set(user_groups),
is_public=is_public,
"""Don't prefix incoming data with the ACL type; prefixes are applied on read in to_acl()."""
obj = object.__new__(cls)
object.__setattr__(
obj, "user_emails", {user_email for user_email in user_emails if user_email}
)
object.__setattr__(obj, "user_groups", set(user_groups))
object.__setattr__(
obj,
"external_user_emails",
{external_email for external_email in external_user_emails},
)
object.__setattr__(
obj,
"external_user_group_ids",
{external_group_id for external_group_id in external_user_group_ids},
)
object.__setattr__(obj, "is_public", is_public)
return obj
default_public_access = DocumentAccess(
external_user_emails=set(),
external_user_group_ids=set(),
user_emails=set(),
user_groups=set(),
default_public_access = DocumentAccess.build(
external_user_emails=[],
external_user_group_ids=[],
user_emails=[],
user_groups=[],
is_public=True,
)
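Since prefixes are now applied when to_acl() is read rather than at build time, the member sets hold the raw values. A short hedged usage sketch of the new build()/to_acl() contract, with placeholder emails and group names:

# Hedged usage of the build()/to_acl() contract above; values are placeholders.
access = DocumentAccess.build(
    user_emails=["alice@example.com"],
    user_groups=["engineering"],
    external_user_emails=["bob@partner.com"],
    external_user_group_ids=["confluence__space-admins"],  # already source-prefixed upstream
    is_public=False,
)

# Raw values are stored unprefixed...
assert "alice@example.com" in access.user_emails

# ...and the type prefixes are only added when the ACL is emitted.
print(sorted(access.to_acl()))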

View File

@@ -7,7 +7,6 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
@@ -24,7 +23,7 @@ def process_llm_stream(
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")

View File

@@ -31,6 +31,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
@@ -92,6 +93,7 @@ def check_sub_answer(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
quality_str: str = cast(str, response.content)

View File

@@ -46,6 +46,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -119,6 +120,7 @@ def generate_sub_answer(
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -62,6 +63,7 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
@@ -153,8 +155,8 @@ def generate_initial_answer(
)
for tool_response in yield_search_responses(
query=question,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -278,6 +280,9 @@ def generate_initial_answer(
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -34,6 +34,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
@@ -141,6 +142,7 @@ def decompose_orig_question(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),

View File

@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
@@ -112,6 +113,7 @@ def compare_answers(
model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
except (LLMTimeoutError, TimeoutError):

View File

@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
)
@@ -144,6 +145,7 @@ def create_refined_sub_questions(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),

View File

@@ -50,13 +50,7 @@ def decide_refinement_need(
)
]
if graph_config.behavior.allow_refinement:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=decision,
log_messages=log_messages,
)
else:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=log_messages,
)
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
log_messages=log_messages,
)

View File

@@ -21,6 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
@@ -96,6 +97,7 @@ def extract_entities_terms(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (

View File

@@ -46,6 +46,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -68,6 +69,8 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
@@ -179,8 +182,8 @@ def generate_validate_refined_answer(
)
for tool_response in yield_search_responses(
query=question,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -302,7 +305,11 @@ def generate_validate_refined_answer(
def stream_refined_answer() -> list[str]:
for message in model.stream(
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -409,6 +416,7 @@ def generate_validate_refined_answer(
validation_model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),

View File

@@ -13,7 +13,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import IndexFilters
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
@@ -144,8 +143,6 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
if result.query_info is not None:
query_info = result.query_info
break
return query_info or SearchQueryInfo(
predicted_search=None,
final_filters=IndexFilters(access_control_list=None),
recency_bias_multiplier=1.0,
)
assert query_info is not None, "must have query info"
return query_info

View File

@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
)
@@ -96,6 +97,7 @@ def expand_queries(
model.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)

View File

@@ -56,8 +56,8 @@ def format_results(
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
reranked_sections=state.retrieved_documents,
final_context_sections=reranked_documents,
get_retrieved_sections=lambda: reranked_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,

View File

@@ -91,7 +91,7 @@ def retrieve_documents(
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0]
pre_rerank_docs = callback_container[0] if callback_container else []
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,

View File

@@ -25,6 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -93,6 +94,7 @@ def verify_documents(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
assert isinstance(response.content, str)

View File

@@ -44,7 +44,9 @@ def call_tool(
tool = tool_choice.tool
tool_args = tool_choice.tool_args
tool_id = tool_choice.id
tool_runner = ToolRunner(tool, tool_args)
tool_runner = ToolRunner(
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
)
tool_kickoff = tool_runner.kickoff()
emit_packet(tool_kickoff, writer)

View File

@@ -15,8 +15,17 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -25,6 +34,7 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
@log_function_time(print_only=True)
def choose_tool(
state: ToolChoiceState,
config: RunnableConfig,
@@ -37,6 +47,31 @@ def choose_tool(
should_stream_answer = state.should_stream_answer
agent_config = cast(GraphConfig, config["metadata"]["config"])
force_use_tool = agent_config.tooling.force_use_tool
embedding_thread: TimeoutThread[Embedding] | None = None
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
override_kwargs: SearchToolOverrideKwargs | None = None
if (
not agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and (
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
)
):
override_kwargs = SearchToolOverrideKwargs()
# Run in a background thread to avoid blocking the main thread
embedding_thread = run_in_background(
get_query_embedding,
agent_config.inputs.search_request.query,
agent_config.persistence.db_session,
)
keyword_thread = run_in_background(
query_analysis,
agent_config.inputs.search_request.query,
)
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
@@ -47,7 +82,6 @@ def choose_tool(
tools = [
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
]
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
@@ -71,11 +105,22 @@ def choose_tool(
# If we have a tool and tool args, we are ready to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
if embedding_thread and tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
search_tool_override_kwargs=override_kwargs,
),
)
@@ -153,10 +198,22 @@ def choose_tool(
logger.debug(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
if embedding_thread and selected_tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and selected_tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
search_tool_override_kwargs=override_kwargs,
),
)
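The node above kicks off query embedding and keyword analysis in background threads before the tool decision is made, and only waits on them if the search tool is actually chosen. A generic, hedged sketch of that "precompute, then conditionally await" pattern, written with the standard library rather than the project's run_in_background/wait_on_background helpers:

# Generic sketch of the pattern used above; the embedding function is a stand-in.
from concurrent.futures import Future, ThreadPoolExecutor


def expensive_embedding(query: str) -> list[float]:
    return [float(len(query))]  # stand-in for the real embedding call


def choose(query: str, will_use_search: bool) -> list[float] | None:
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Start the expensive work immediately, before the decision is made.
        fut: Future[list[float]] = pool.submit(expensive_embedding, query)

        # ... tool-selection logic would run here ...

        if will_use_search:
            # Only pay the wait if the result is actually needed.
            return fut.result()
        return None  # result is discarded; the thread finishes in the background


print(choose("what is onyx?", will_use_search=True))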

View File

@@ -9,18 +9,21 @@ from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_DOC_CONTENT_ID,
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
@log_function_time(print_only=True)
def basic_use_tool_response(
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BasicOutput:
@@ -50,11 +53,11 @@ def basic_use_tool_response(
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = cast(OnyxContexts, yield_item.response).contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
initial_search_results.append(doc)
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(section_to_llm_doc(section))
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -2,6 +2,7 @@ from pydantic import BaseModel
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
@@ -35,6 +36,7 @@ class ToolChoice(BaseModel):
tool: Tool
tool_args: dict
id: str | None
search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -13,6 +13,11 @@ AGENT_NEGATIVE_VALUE_STR = "no"
AGENT_ANSWER_SEPARATOR = "Answer:"
EMBEDDING_KEY = "embedding"
IS_KEYWORD_KEY = "is_keyword"
KEYWORDS_KEY = "keywords"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"

View File

@@ -42,6 +42,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
)
@@ -61,6 +62,7 @@ from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
@@ -319,8 +321,10 @@ def dispatch_separated(
sep: str = DISPATCH_SEP_CHAR,
) -> list[BaseMessage_Content]:
num = 1
accumulated_tokens = ""
streamed_tokens: list[BaseMessage_Content] = []
for token in tokens:
accumulated_tokens += cast(str, token.content)
content = cast(str, token.content)
if sep in content:
sub_question_parts = content.split(sep)
@@ -402,6 +406,7 @@ def summarize_history(
llm.invoke,
history_context_prompt,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
max_tokens=AGENT_MAX_TOKENS_HISTORY_SUMMARY,
)
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - summarize history")
@@ -505,3 +510,9 @@ def get_deduplicated_structured_subquestion_documents(
cited_documents=dedup_inference_section_list(cited_docs),
context_documents=dedup_inference_section_list(context_docs),
)
def _should_restrict_tokens(llm_config: LLMConfig) -> bool:
return not (
llm_config.model_provider == "openai" and llm_config.model_name.startswith("o")
)

View File

@@ -1,5 +1,6 @@
import smtplib
from datetime import datetime
from email.mime.image import MIMEImage
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
@@ -13,10 +14,16 @@ from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.configs.constants import ONYX_SLACK_URL
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.file import FileWithMimeType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
<html lang="en">
@@ -56,6 +63,11 @@ HTML_EMAIL_TEMPLATE = """\
}}
.header img {{
max-width: 140px;
width: 140px;
height: auto;
filter: brightness(1.1) contrast(1.2);
border-radius: 8px;
padding: 5px;
}}
.body-content {{
padding: 20px 30px;
@@ -72,12 +84,16 @@ HTML_EMAIL_TEMPLATE = """\
}}
.cta-button {{
display: inline-block;
padding: 12px 20px;
background-color: #000000;
padding: 14px 24px;
background-color: #0055FF;
color: #ffffff !important;
text-decoration: none;
border-radius: 4px;
font-weight: 500;
font-weight: 600;
font-size: 16px;
margin-top: 10px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
text-align: center;
}}
.footer {{
font-size: 13px;
@@ -97,8 +113,8 @@ HTML_EMAIL_TEMPLATE = """\
<td class="header">
<img
style="background-color: #ffffff; border-radius: 8px;"
src="https://www.onyx.app/logos/customer/onyx.png"
alt="Onyx Logo"
src="cid:logo.png"
alt="{application_name} Logo"
>
</td>
</tr>
@@ -113,9 +129,8 @@ HTML_EMAIL_TEMPLATE = """\
</tr>
<tr>
<td class="footer">
© {year} Onyx. All rights reserved.
<br>
Have questions? Join our Slack community <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA">here</a>.
© {year} {application_name}. All rights reserved.
{slack_fragment}
</td>
</tr>
</table>
@@ -125,17 +140,27 @@ HTML_EMAIL_TEMPLATE = """\
def build_html_email(
application_name: str | None,
heading: str,
message: str,
cta_text: str | None = None,
cta_link: str | None = None,
) -> str:
slack_fragment = ""
if application_name == ONYX_DEFAULT_APPLICATION_NAME:
slack_fragment = f'<br>Have questions? Join our Slack community <a href="{ONYX_SLACK_URL}">here</a>.'
if cta_text and cta_link:
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
else:
cta_block = ""
return HTML_EMAIL_TEMPLATE.format(
application_name=application_name,
title=heading,
heading=heading,
message=message,
cta_block=cta_block,
slack_fragment=slack_fragment,
year=datetime.now().year,
)
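A short hedged usage of the updated signature, with a placeholder application name; note the Slack footer fragment only renders when the application name matches the default Onyx name.

# Hedged usage of the updated build_html_email signature; values are placeholders.
html = build_html_email(
    application_name="Acme Search",
    heading="You've Been Invited!",
    message="<p>You have been invited to join an organization on Acme Search.</p>",
    cta_text="Join Organization",
    cta_link="https://example.com/auth/signup?email=new.user@example.com",
)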
@@ -146,22 +171,44 @@ def send_email(
html_body: str,
text_body: str,
mail_from: str = EMAIL_FROM,
inline_png: tuple[str, bytes] | None = None,
) -> None:
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
# Create a multipart/alternative message - this indicates these are alternative versions of the same content
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")
# Add text part first (lowest priority)
text_part = MIMEText(text_body, "plain")
msg.attach(text_part)
if inline_png:
# For HTML with images, create a multipart/related container
related = MIMEMultipart("related")
# Add the HTML part to the related container
html_part = MIMEText(html_body, "html")
related.attach(html_part)
# Add image with proper Content-ID to the related container
img = MIMEImage(inline_png[1], _subtype="png")
img.add_header("Content-ID", f"<{inline_png[0]}>")
img.add_header("Content-Disposition", "inline", filename=inline_png[0])
related.attach(img)
# Add the related part to the message (higher priority than text)
msg.attach(related)
else:
# No images, just add HTML directly (higher priority than text)
html_part = MIMEText(html_body, "html")
msg.attach(html_part)
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
@@ -173,8 +220,21 @@ def send_email(
def send_subscription_cancellation_email(user_email: str) -> None:
"""This is templated but isn't meaningful for whitelabeling."""
# Example usage of the reusable HTML
subject = "Your Onyx Subscription Has Been Canceled"
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"Your {application_name} Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>We're sorry to see you go.</p>"
@@ -183,23 +243,48 @@ def send_subscription_cancellation_email(user_email: str) -> None:
)
cta_text = "Renew Subscription"
cta_link = "https://www.onyx.app/pricing"
html_content = build_html_email(heading, message, cta_text, cta_link)
html_content = build_html_email(
application_name,
heading,
message,
cta_text,
cta_link,
)
text_content = (
"We're sorry to see you go.\n"
"Your subscription has been canceled and will end on your next billing date.\n"
"If you change your mind, visit https://www.onyx.app/pricing"
)
send_email(user_email, subject, html_content, text_content)
send_email(
user_email,
subject,
html_content,
text_content,
inline_png=("logo.png", onyx_file.data),
)
def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
subject = "Invitation to Join Onyx Organization"
onyx_file: FileWithMimeType | None = None
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"Invitation to Join {application_name} Organization"
heading = "You've Been Invited!"
# the exact action taken by the user, and thus the message, depends on the auth type
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
message = f"<p>You have been invited by {current_user.email} to join an organization on {application_name}.</p>"
if auth_type == AuthType.CLOUD:
message += (
"<p>To join the organization, please click the button below to set a password "
@@ -225,19 +310,32 @@ def send_user_email_invite(
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
html_content = build_html_email(
application_name,
heading,
message,
cta_text,
cta_link,
)
# text content is the fallback for clients that don't support HTML
# not as critical, so not having special cases for each auth type
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
f"You have been invited by {current_user.email} to join an organization on {application_name}.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
)
if auth_type == AuthType.CLOUD:
text_content += "You'll be asked to set a password or login with Google to complete your registration."
send_email(user_email, subject, html_content, text_content)
send_email(
user_email,
subject,
html_content,
text_content,
inline_png=("logo.png", onyx_file.data),
)
def send_forgot_password_email(
@@ -247,27 +345,80 @@ def send_forgot_password_email(
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if MULTI_TENANT:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)
text_content = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"Reset Your {application_name} Password"
heading = "Reset Your Password"
tenant_param = f"&tenant={tenant_id}" if tenant_id and MULTI_TENANT else ""
message = "<p>Please click the button below to reset your password. This link will expire in 24 hours.</p>"
cta_text = "Reset Password"
cta_link = f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
html_content = build_html_email(
application_name,
heading,
message,
cta_text,
cta_link,
)
text_content = (
f"Please click the following link to reset your password. This link will expire in 24 hours.\n"
f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
)
send_email(
user_email,
subject,
html_content,
text_content,
mail_from,
inline_png=("logo.png", onyx_file.data),
)
def send_user_verification_email(
user_email: str,
token: str,
new_organization: bool = False,
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a verification email
subject = "Onyx Email Verification"
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"{application_name} Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
if new_organization:
link = add_url_params(link, {"first_user": "true"})
message = (
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
)
html_content = build_html_email("Verify Your Email", message)
html_content = build_html_email(
application_name,
"Verify Your Email",
message,
)
text_content = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)
send_email(
user_email,
subject,
html_content,
text_content,
mail_from,
inline_png=("logo.png", onyx_file.data),
)
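The runtime-settings lookup and logo fetch are repeated verbatim in the invite, forgot-password, and verification builders above. A minimal sketch of factoring that into a shared helper, reusing only the imports and calls already shown in this diff; the helper name and the constant's import path are hypothetical, not part of the change:

# Hypothetical helper (not in the diff) that centralizes the branding lookup
# used by the three email builders above.
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.variable_functionality import fetch_versioned_implementation

def _get_email_branding():  # returns (application_name, emailable logo file)
    try:
        load_runtime_settings_fn = fetch_versioned_implementation(
            "onyx.server.enterprise_settings.store", "load_runtime_settings"
        )
        application_name = load_runtime_settings_fn().application_name
    except ModuleNotFoundError:
        # EE settings module not present -> fall back to the default name
        application_name = ONYX_DEFAULT_APPLICATION_NAME  # import path assumed
    return application_name, OnyxRuntime.get_emailable_logo()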

View File

@@ -1,5 +1,6 @@
from typing import cast
from onyx.configs.constants import KV_PENDING_USERS_KEY
from onyx.configs.constants import KV_USER_STORE_KEY
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -18,3 +19,17 @@ def write_invited_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)
def get_pending_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_PENDING_USERS_KEY))
except KvKeyNotFoundError:
return list()
def write_pending_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
return len(emails)
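A hedged usage sketch of the new pending-user helpers; the wrapper below is hypothetical and assumes it lives alongside get_pending_users / write_pending_users:

# Hypothetical caller (not in the diff) showing the read-modify-write pattern
# the KV helpers above support.
def add_pending_user(email: str) -> int:
    pending = get_pending_users()
    if email not in pending:
        pending.append(email)
    # returns the new number of pending users, mirroring write_invited_users
    return write_pending_users(pending)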

View File

@@ -0,0 +1,211 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
import httpx
from fastapi_users.manager import BaseUserManager
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Standard OAuth refresh token endpoints
REFRESH_ENDPOINTS = {
"google": "https://oauth2.googleapis.com/token",
}
# NOTE: Keeping this as a utility function for potential future debugging,
# but not using it in production code
async def _test_expire_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
expire_in_seconds: int = 10,
) -> bool:
"""
Utility function for testing - Sets an OAuth token to expire in a short time
to facilitate testing of the refresh flow.
Not used in production code.
"""
try:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
)
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
return True
except Exception as e:
logger.exception(f"Error setting artificial expiration: {str(e)}")
return False
async def refresh_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> bool:
"""
Attempt to refresh an OAuth token that's about to expire or has expired.
Returns True if successful, False otherwise.
"""
if not oauth_account.refresh_token:
logger.warning(
f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
)
return False
provider = oauth_account.oauth_name
if provider not in REFRESH_ENDPOINTS:
logger.warning(f"Refresh endpoint not configured for provider: {provider}")
return False
try:
logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")
async with httpx.AsyncClient() as client:
response = await client.post(
REFRESH_ENDPOINTS[provider],
data={
"client_id": OAUTH_CLIENT_ID,
"client_secret": OAUTH_CLIENT_SECRET,
"refresh_token": oauth_account.refresh_token,
"grant_type": "refresh_token",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
if response.status_code != 200:
logger.error(
f"Failed to refresh OAuth token: Status {response.status_code}"
)
return False
token_data = response.json()
new_access_token = token_data.get("access_token")
new_refresh_token = token_data.get(
"refresh_token", oauth_account.refresh_token
)
expires_in = token_data.get("expires_in")
# Calculate new expiry time if provided
new_expires_at: Optional[int] = None
if expires_in:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expires_in)
)
# Update the OAuth account
updated_data: Dict[str, Any] = {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}
if new_expires_at:
updated_data["expires_at"] = new_expires_at
# Update oidc_expiry in user model if we're tracking it
if TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(
new_expires_at, tz=timezone.utc
)
await user_manager.user_db.update(
user, {"oidc_expiry": oidc_expiry}
)
# Update the OAuth account
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
logger.info(f"Successfully refreshed OAuth token for {user.email}")
return True
except Exception as e:
logger.exception(f"Error refreshing OAuth token: {str(e)}")
return False
async def check_and_refresh_oauth_tokens(
user: User,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> None:
"""
Check if any OAuth tokens are expired or about to expire and refresh them.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return
now_timestamp = datetime.now(timezone.utc).timestamp()
# Buffer time to refresh tokens before they expire (in seconds)
buffer_seconds = 300 # 5 minutes
for oauth_account in user.oauth_accounts:
# Skip accounts without refresh tokens
if not oauth_account.refresh_token:
continue
# If token is about to expire, refresh it
if (
oauth_account.expires_at
and oauth_account.expires_at - now_timestamp < buffer_seconds
):
logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
success = await refresh_oauth_token(
user, oauth_account, db_session, user_manager
)
if not success:
logger.warning(
"Failed to refresh OAuth token. User may need to re-authenticate."
)
async def check_oauth_account_has_refresh_token(
user: User,
oauth_account: OAuthAccount,
) -> bool:
"""
Check if an OAuth account has a refresh token.
Returns True if a refresh token exists, False otherwise.
"""
return bool(oauth_account.refresh_token)
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
"""
Returns a list of OAuth accounts for a user that are missing refresh tokens.
These accounts will need re-authentication to get refresh tokens.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return []
accounts_needing_refresh = []
for oauth_account in user.oauth_accounts:
has_refresh_token = await check_oauth_account_has_refresh_token(
user, oauth_account
)
if not has_refresh_token:
accounts_needing_refresh.append(oauth_account)
return accounts_needing_refresh
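A minimal sketch of how a caller might combine the helpers in this new module; the wrapper name is hypothetical, and db_session / user_manager are assumed to come from the caller's dependency injection, as the refresh endpoint later in this diff does:

# Hypothetical orchestration (not in the diff): refresh anything close to
# expiry, then report accounts that can never be refreshed.
async def refresh_and_report(user, db_session, user_manager):
    await check_and_refresh_oauth_tokens(user, db_session, user_manager)
    missing_refresh = await get_oauth_accounts_requiring_refresh_token(user)
    if missing_refresh:
        logger.warning(
            f"{len(missing_refresh)} OAuth account(s) for {user.email} "
            "have no refresh token and will need re-authentication"
        )
    return missing_refresh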

View File

@@ -5,12 +5,16 @@ import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Protocol
from typing import Tuple
from typing import TypeVar
import jwt
from email_validator import EmailNotValidError
@@ -52,6 +56,7 @@ from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
@@ -100,10 +105,12 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
@@ -354,7 +361,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reason="Password must contain at least one special character from the following set: "
f"{PASSWORD_SPECIAL_CHARS}."
)
return
async def oauth_callback(
@@ -508,6 +514,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
return user
async def on_after_login(
self,
user: User,
request: Optional[Request] = None,
response: Optional[Response] = None,
) -> None:
try:
if response and request and ANONYMOUS_USER_COOKIE_NAME in request.cookies:
response.delete_cookie(
ANONYMOUS_USER_COOKIE_NAME,
# Ensure cookie deletion doesn't override other cookies by setting the same path/domain
path="/",
domain=None,
secure=WEB_DOMAIN.startswith("https"),
)
logger.debug(f"Deleted anonymous user cookie for user {user.email}")
except Exception:
logger.exception("Error deleting anonymous user cookie")
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
@@ -579,22 +604,30 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
logger.notice(
f"Verification requested for user {user.id}. Verification token: {token}"
)
send_user_verification_email(user.email, token)
user_count = await get_user_count()
send_user_verification_email(
user.email, token, new_organization=user_count == 1
)
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[User]:
email = credentials.username
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=email,
)
tenant_id: str | None = None
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_tenant_id_for_email",
POSTGRES_DEFAULT_SCHEMA,
)(
email=email,
)
except Exception as e:
logger.warning(
f"User attempted to login with invalid credentials: {str(e)}"
)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
@@ -680,16 +713,20 @@ cookie_transport = CookieTransport(
)
def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
T = TypeVar("T", covariant=True)
ID = TypeVar("ID", contravariant=True)
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
# Protocol for strategies that support token refreshing without inheritance.
class RefreshableStrategy(Protocol):
"""Protocol for authentication strategies that support token refreshing."""
async def refresh_token(self, token: Optional[str], user: Any) -> str:
"""
Refresh an existing token by extending its lifetime.
Returns either the same token with extended expiration or a new token.
"""
...
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
@@ -748,6 +785,75 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
redis = await get_async_redis_connection()
await redis.delete(f"{self.key_prefix}{token}")
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by extending its expiration time in Redis."""
if token is None:
# If no token provided, create a new one
return await self.write_token(user)
redis = await get_async_redis_connection()
token_key = f"{self.key_prefix}{token}"
# Check if token exists
token_data_str = await redis.get(token_key)
if not token_data_str:
# Token not found, create new one
return await self.write_token(user)
# Token exists, extend its lifetime
token_data = json.loads(token_data_str)
await redis.set(
token_key,
json.dumps(token_data),
ex=self.lifetime_seconds,
)
return token
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
"""Database strategy with token refreshing capabilities."""
def __init__(
self,
access_token_db: AccessTokenDatabase[AccessToken],
lifetime_seconds: Optional[int] = None,
):
super().__init__(access_token_db, lifetime_seconds)
self._access_token_db = access_token_db
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by updating its expiration time in the database."""
if token is None:
return await self.write_token(user)
# Find the token in database
access_token = await self._access_token_db.get_by_token(token)
if access_token is None:
# Token not found, create new one
return await self.write_token(user)
# Update expiration time
new_expires = datetime.now(timezone.utc) + timedelta(
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
)
await self._access_token_db.update(access_token, {"expires": new_expires})
return token
def get_redis_strategy() -> TenantAwareRedisStrategy:
return TenantAwareRedisStrategy()
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> RefreshableDatabaseStrategy:
return RefreshableDatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
if AUTH_BACKEND == AuthBackend.REDIS:
auth_backend = AuthenticationBackend(
@@ -798,6 +904,88 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
return router
def get_refresh_router(
self,
backend: AuthenticationBackend,
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
) -> APIRouter:
"""
Provide a router for session token refreshing.
"""
# Import the oauth_refresher here to avoid circular imports
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
router = APIRouter()
get_current_user_token = self.authenticator.current_user_token(
active=True, verified=requires_verification
)
refresh_responses: OpenAPIResponseType = {
**{
status.HTTP_401_UNAUTHORIZED: {
"description": "Missing token or inactive user."
}
},
**backend.transport.get_openapi_login_responses_success(),
}
@router.post(
"/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
)
async def refresh(
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(
get_user_manager
),
db_session: AsyncSession = Depends(get_async_session),
) -> Response:
try:
user, token = user_token
logger.info(f"Processing token refresh request for user {user.email}")
# Check if user has OAuth accounts that need refreshing
await check_and_refresh_oauth_tokens(
user=cast(User, user),
db_session=db_session,
user_manager=cast(Any, user_manager),
)
# Check if strategy supports refreshing
supports_refresh = hasattr(strategy, "refresh_token") and callable(
getattr(strategy, "refresh_token")
)
if supports_refresh:
try:
refresh_method = getattr(strategy, "refresh_token")
new_token = await refresh_method(token, user)
logger.info(
f"Successfully refreshed session token for user {user.email}"
)
return await backend.transport.get_login_response(new_token)
except Exception as e:
logger.error(f"Error refreshing session token: {str(e)}")
# Fallback to logout and login if refresh fails
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
# Fallback: logout and login again
logger.info(
"Strategy doesn't support refresh - using logout/login flow"
)
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
except Exception as e:
logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Token refresh failed: {str(e)}",
)
return router
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
get_user_manager, [auth_backend]
@@ -888,7 +1076,7 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accesssible_user(
async def current_chat_accessible_user(
user: User | None = Depends(optional_user),
) -> User | None:
tenant_id = get_current_tenant_id()
@@ -1031,12 +1219,20 @@ def get_oauth_router(
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
# Get the basic authorization URL
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
)
# For Google OAuth, add parameters to request refresh tokens
if oauth_client.name == "google":
authorization_url = add_url_params(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@router.get(
@@ -1089,6 +1285,12 @@ def get_oauth_router(
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(account_email)
except exceptions.UserNotExists:
tenant_id = None
request.state.referral_source = referral_source
@@ -1122,11 +1324,22 @@ def get_oauth_router(
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
redirect_response = RedirectResponse(next_url, status_code=302)
if tenant_id is None:
# Use URL utility to add parameters
redirect_url = add_url_params(next_url, {"new_team": "true"})
redirect_response = RedirectResponse(redirect_url, status_code=302)
else:
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers and other attributes from 'response' to 'redirect_response'
# Copy headers from auth response to redirect response, with special handling for Set-Cookie
for header_name, header_value in response.headers.items():
redirect_response.headers[header_name] = header_value
# FastAPI can have multiple Set-Cookie headers as a list
if header_name.lower() == "set-cookie" and isinstance(header_value, list):
for cookie_value in header_value:
redirect_response.headers.append(header_name, cookie_value)
else:
redirect_response.headers[header_name] = header_value
if hasattr(response, "body"):
redirect_response.body = response.body
@@ -1134,6 +1347,7 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router
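A hedged client-side sketch of exercising the new session-refresh route; only the POST "/refresh" handler comes from the diff, while the "/auth" mount prefix and the idea of forwarding the browser's session cookies are assumptions:

import httpx

def refresh_session(base_url: str, session_cookies: dict[str, str]) -> httpx.Response:
    # POSTs to the refresh route registered by get_refresh_router above;
    # the "/auth" prefix is an assumption about where the router is mounted.
    with httpx.Client(base_url=base_url, cookies=session_cookies) as client:
        return client.post("/auth/refresh")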

View File

@@ -1,5 +1,6 @@
import logging
import multiprocessing
import os
import time
from typing import Any
from typing import cast
@@ -34,7 +35,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import PlainFormatter
@@ -225,7 +225,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
@@ -306,12 +306,12 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
logger.info(f"Running as a secondary celery worker: pid={os.getpid()}")
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")

View File

@@ -1,6 +1,5 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -10,12 +9,10 @@ from celery.utils.log import get_task_logger
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
@@ -141,8 +138,6 @@ class DynamicTenantScheduler(PersistentScheduler):
"""Only updates the actual beat schedule on the celery app when it changes"""
do_update = False
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
task_logger.debug("_try_updating_schedule starting")
tenant_ids = get_all_tenant_ids()
@@ -152,16 +147,7 @@ class DynamicTenantScheduler(PersistentScheduler):
current_schedule = self.schedule.items()
# get potential new state
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
task_logger.error(
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
)
beat_multiplier = OnyxRuntime.get_beat_multiplier()
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)

View File

@@ -0,0 +1,7 @@
from celery import Celery
import onyx.background.celery.apps.app_base as app_base
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.client")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
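A hedged sketch of enqueueing work through this client-only app from a non-worker process; the import path, task name, queue, and kwargs below are placeholders, not values from the diff:

# Hypothetical producer-side usage of the client-only Celery app above.
from onyx.background.celery.apps.client import celery_app  # import path assumed

def enqueue_example_task() -> str:
    result = celery_app.send_task(
        "example.task.name",              # placeholder task name
        kwargs={"tenant_id": "public"},   # placeholder arguments
        queue="celery",                   # placeholder queue
    )
    return result.id  # send_task returns an AsyncResult whose id can be tracked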

View File

@@ -111,5 +111,8 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -92,5 +92,6 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -1,4 +1,5 @@
import logging
import os
from typing import Any
from typing import cast
@@ -38,10 +39,11 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -94,7 +96,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
logger.info("Running as the primary celery worker.")
logger.info(f"Running as the primary celery worker: pid={os.getpid()}")
# Fewer startup checks in the multi-tenant case
if MULTI_TENANT:
@@ -102,7 +104,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
@@ -173,6 +175,9 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
logger.exception(
f"Marking attempt {attempt.id} as canceled due to validation error 2"
)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
@@ -235,7 +240,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
lock: RedisLock = worker.primary_worker_lock
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
@@ -284,5 +289,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.user_file_folder_sync",
]
)

View File

@@ -0,0 +1,16 @@
import onyx.background.celery.configs.base as shared_config
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

View File

@@ -0,0 +1,73 @@
# backend/onyx/background/celery/memory_monitoring.py
import logging
import os
from logging.handlers import RotatingFileHandler
import psutil
from onyx.utils.logger import is_running_in_container
from onyx.utils.logger import setup_logger
# Regular application logger
logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE,
maxBytes=MEMORY_LOG_MAX_BYTES,
backupCount=MEMORY_LOG_BACKUP_COUNT,
)
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
else:
# Create a null logger when not in container
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.addHandler(logging.NullHandler())
def emit_process_memory(
pid: int, process_name: str, additional_metadata: dict[str, str | int]
) -> None:
# Skip memory monitoring if not in container
if not is_running_in_container():
return
try:
process = psutil.Process(pid)
memory_info = process.memory_info()
cpu_percent = process.cpu_percent(interval=0.1)
# Build metadata string from additional_metadata dictionary
metadata_str = " ".join(
[f"{key}={value}" for key, value in additional_metadata.items()]
)
metadata_str = f" {metadata_str}" if metadata_str else ""
memory_logger.info(
f"PROCESS_MEMORY process_name={process_name} pid={pid} "
f"rss_mb={memory_info.rss / (1024 * 1024):.2f} "
f"vms_mb={memory_info.vms / (1024 * 1024):.2f} "
f"cpu={cpu_percent:.2f}{metadata_str}"
)
except Exception:
logger.exception("Error monitoring process memory.")

View File

@@ -21,6 +21,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# we have a better implementation (backpressure, etc)
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
# tasks that run in either self-hosted or cloud
beat_task_templates: list[dict] = []
@@ -63,6 +64,15 @@ beat_task_templates.extend(
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-user-file-folder-sync",
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
@@ -167,6 +177,16 @@ beat_cloud_tasks: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
"task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
"schedule": timedelta(minutes=10),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
# tasks that only run self-hosted
@@ -184,6 +204,16 @@ if not MULTI_TENANT:
"queue": OnyxCeleryQueues.MONITORING,
},
},
{
"name": "monitor-process-memory",
"task": OnyxCeleryTask.MONITOR_PROCESS_MEMORY,
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
)

View File

@@ -30,6 +30,9 @@ from onyx.db.connector_credential_pair import (
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import (
delete_all_documents_by_connector_credential_pair__no_commit,
)
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine import get_session_with_current_tenant
@@ -386,6 +389,8 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
credential_id_to_delete: int | None = None
connector_id_to_delete: int | None = None
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
@@ -440,16 +445,35 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
)
# Store IDs before potentially expiring cc_pair
connector_id_to_delete = cc_pair.connector_id
credential_id_to_delete = cc_pair.credential_id
# Explicitly delete document by connector credential pair records before deleting the connector
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
delete_all_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
)
# Flush to ensure document deletion happens before connector deletion
db_session.flush()
# Expire the cc_pair to ensure SQLAlchemy doesn't try to manage its state
# related to the deleted DocumentByConnectorCredentialPair during commit
db_session.expire(cc_pair)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
connector_id=connector_id_to_delete,
)
if not connector or not len(connector.credentials):
task_logger.info(
@@ -482,15 +506,15 @@ def monitor_connector_deletion_taskset(
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
)
raise e
task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"connector={connector_id_to_delete} "
f"credential={credential_id_to_delete} "
f"docs_deleted={fence_data.num_tasks}"
)
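The hunk above works around the missing cascade by capturing the IDs, deleting the dependent document rows, flushing, expiring the ORM object, and only then deleting the cc-pair. A generic SQLAlchemy sketch of that ordering, with placeholder model and column names:

from sqlalchemy.orm import Session

def delete_parent_with_manual_child_cleanup(db_session: Session, parent, ChildModel, parent_id) -> None:
    # 1) delete dependent rows that cascades will not handle
    db_session.query(ChildModel).filter(ChildModel.parent_id == parent_id).delete()
    # 2) flush so the child deletes reach the database before the parent delete
    db_session.flush()
    # 3) expire the parent so SQLAlchemy does not manage stale child state on commit
    db_session.expire(parent)
    # 4) delete the parent
    db_session.delete(parent)
    db_session.commit()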
@@ -540,7 +564,7 @@ def validate_connector_deletion_fences(
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_tasks: set[str],
queued_upsert_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
@@ -627,7 +651,7 @@ def validate_connector_deletion_fence(
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
if member_str in queued_upsert_tasks:
continue
tasks_not_in_celery += 1

View File

@@ -17,6 +17,7 @@ from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
@@ -46,7 +47,6 @@ from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
@@ -64,11 +64,14 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
logger = setup_logger()
@@ -105,9 +108,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
if not source_sync_period:
return True
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
@@ -285,7 +289,7 @@ def try_creating_permissions_sync_task(
),
queue=OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.MEDIUM,
)
# fill in the celery task id
@@ -420,12 +424,7 @@ def connector_permission_sync_generator_task(
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
# TODO: add some notification to the admins here
raise
source_type = cc_pair.connector.source
@@ -453,23 +452,23 @@ def connector_permission_sync_generator_task(
redis_connector.permissions.set_fence(new_payload)
callback = PermissionSyncCallback(redis_connector, lock, r)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(
cc_pair, callback
)
document_external_accesses = doc_sync_func(cc_pair, callback)
task_logger.info(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.permissions.generate_tasks(
celery_app=self.app,
lock=lock,
new_permissions=document_external_accesses,
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
if tasks_generated is None:
return None
tasks_generated = 0
for doc_external_access in document_external_accesses:
redis_connector.permissions.generate_tasks(
celery_app=self.app,
lock=lock,
new_permissions=[doc_external_access],
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
tasks_generated += 1
task_logger.info(
f"RedisConnector.permissions.generate_tasks finished. "
@@ -881,6 +880,18 @@ def monitor_ccpair_permissions_taskset(
f"remaining={remaining} "
f"initial={initial}"
)
# Add telemetry for permission syncing progress
optional_telemetry(
record_type=RecordType.PERMISSION_SYNC_PROGRESS,
data={
"cc_pair_id": cc_pair_id,
"total_docs_synced": initial if initial is not None else 0,
"remaining_docs_to_sync": remaining,
},
tenant_id=tenant_id,
)
if remaining > 0:
return
@@ -892,6 +903,13 @@ def monitor_ccpair_permissions_taskset(
f"num_synced={initial}"
)
# Add telemetry for permission syncing complete
optional_telemetry(
record_type=RecordType.PERMISSION_SYNC_COMPLETE,
data={"cc_pair_id": cc_pair_id},
tenant_id=tenant_id,
)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,

View File

@@ -41,7 +41,6 @@ from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -272,7 +271,7 @@ def try_creating_external_group_sync_task(
),
queue=OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.MEDIUM,
)
payload.celery_task_id = result.id
@@ -402,12 +401,7 @@ def connector_external_group_sync_generator_task(
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
# TODO: add some notification to the admins here
raise
source_type = cc_pair.connector.source
@@ -425,12 +419,9 @@ def connector_external_group_sync_generator_task(
try:
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
except ConnectorValidationError as e:
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
# TODO: add some notification to the admins here
logger.exception(
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
)
raise e

View File

@@ -23,6 +23,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import should_index
@@ -71,6 +72,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_utils import is_fence
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -363,6 +365,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
Occasionally does some validation of existing state to clear up error conditions"""
time_start = time.monotonic()
task_logger.warning("check_for_indexing - Starting")
tasks_created = 0
locked = False
@@ -400,7 +403,11 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
logger.warning(f"Adding {key_bytes} to the lookup table.")
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
redis_client.set(
OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE,
1,
ex=OnyxRuntime.get_build_fence_lookup_table_interval(),
)
# 1/3: KICKOFF
@@ -427,7 +434,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
lock_beat.reacquire()
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
cc_pairs = fetch_connector_credential_pairs(db_session)
cc_pairs = fetch_connector_credential_pairs(
db_session, include_user_files=True
)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
@@ -446,12 +455,18 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
not search_settings_instance.status.is_current()
and not search_settings_instance.background_reindex_enabled
):
task_logger.warning("SKIPPING DUE TO NON-LIVE SEARCH SETTINGS")
continue
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
if redis_connector_index.fenced:
task_logger.info(
f"check_for_indexing - Skipping fenced connector: "
f"cc_pair={cc_pair_id} search_settings={search_settings_instance.id}"
)
continue
cc_pair = get_connector_credential_pair_from_id(
@@ -459,6 +474,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
cc_pair_id=cc_pair_id,
)
if not cc_pair:
task_logger.warning(
f"check_for_indexing - CC pair not found: cc_pair={cc_pair_id}"
)
continue
last_attempt = get_last_attempt_for_cc_pair(
@@ -472,7 +490,20 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
task_logger.info(
f"check_for_indexing - Not indexing cc_pair_id: {cc_pair_id} "
f"search_settings={search_settings_instance.id}, "
f"last_attempt={last_attempt.id if last_attempt else None}, "
f"secondary_index_building={len(search_settings_list) > 1}"
)
continue
else:
task_logger.info(
f"check_for_indexing - Will index cc_pair_id: {cc_pair_id} "
f"search_settings={search_settings_instance.id}, "
f"last_attempt={last_attempt.id if last_attempt else None}, "
f"secondary_index_building={len(search_settings_list) > 1}"
)
reindex = False
if search_settings_instance.status.is_current():
@@ -511,6 +542,12 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
f"search_settings={search_settings_instance.id}"
)
tasks_created += 1
else:
task_logger.info(
f"Failed to create indexing task: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id}"
)
lock_beat.reacquire()
@@ -984,6 +1021,9 @@ def connector_indexing_proxy_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
@@ -1024,6 +1064,23 @@ def connector_indexing_proxy_task(
job.release()
break
# log the memory usage for tracking down memory leaks / connector-specific memory issues
pid = job.process.pid
if pid is not None:
# Only emit memory info once per minute (60 seconds)
current_time = time.monotonic()
if current_time - last_memory_emit_time >= 60.0:
emit_process_memory(
pid,
"indexing_worker",
{
"cc_pair_id": cc_pair_id,
"search_settings_id": search_settings_id,
"index_attempt_id": index_attempt_id,
},
)
last_memory_emit_time = current_time
# if a termination signal is detected, break (exit point will clean up)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
@@ -1123,6 +1180,9 @@ def connector_indexing_proxy_task(
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
try:
with get_session_with_current_tenant() as db_session:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to termination signal"
)
mark_attempt_canceled(
index_attempt_id,
db_session,
@@ -1170,6 +1230,7 @@ def connector_indexing_proxy_task(
return
# primary
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
@@ -1217,6 +1278,7 @@ def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
)
# light worker
@shared_task(
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
bind=True,

View File

@@ -371,6 +371,7 @@ def should_index(
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
print(f"Not indexing cc_pair={cc_pair.id}: NOT_APPLICABLE source")
return False
# User can still manually create single indexing attempts via the UI for the
@@ -380,6 +381,9 @@ def should_index(
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
print(
f"Not indexing cc_pair={cc_pair.id}: DISABLE_INDEX_UPDATE_ON_SWAP is True and secondary index building"
)
return False
# When switching over models, always index at least once
@@ -388,19 +392,31 @@ def should_index(
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
print(
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with successful last index attempt={last_index.id}"
)
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
print(
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with NOT_STARTED last index attempt={last_index.id}"
)
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
print(
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with IN_PROGRESS last index attempt={last_index.id}"
)
return False
else:
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
print(
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with Ingestion API source"
)
return False
return True
@@ -412,6 +428,9 @@ def should_index(
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
print(
f"Not indexing cc_pair={cc_pair.id}: Connector is paused or is Ingestion API"
)
return False
if search_settings_instance.status.is_current():
@@ -424,11 +443,16 @@ def should_index(
return True
if connector.refresh_freq is None:
print(f"Not indexing cc_pair={cc_pair.id}: refresh_freq is None")
return False
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index.time_updated
if time_since_index.total_seconds() < connector.refresh_freq:
print(
f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index.id} "
f"too recent ({time_since_index.total_seconds()}s < {connector.refresh_freq}s)"
)
return False
return True
@@ -508,6 +532,13 @@ def try_creating_indexing_task(
custom_task_id = redis_connector_index.generate_generator_task_id()
# Determine which queue to use based on whether this is a user file
queue = (
OnyxCeleryQueues.USER_FILES_INDEXING
if cc_pair.is_user_file
else OnyxCeleryQueues.CONNECTOR_INDEXING
)
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
@@ -518,7 +549,7 @@ def try_creating_indexing_task(
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=OnyxCeleryQueues.CONNECTOR_INDEXING,
queue=queue,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
)

Some files were not shown because too many files have changed in this diff.