Compare commits

..

135 Commits

Author SHA1 Message Date
pablonyx
aacdf775da add basic user invite flow 2025-03-11 11:14:52 -07:00
pablonyx
59a388ce0a fix tests 2025-03-11 11:12:35 -07:00
rkuo-danswer
9cd3cbb978 fix versions (#4250)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-10 23:50:07 -07:00
pablonyx
ab1b6b487e descrease model server logspam (#4166) 2025-03-10 18:29:27 +00:00
Chris Weaver
6ead9510a4 Small notion tweaks (#4244)
* Small notion tweaks

* Add comment
2025-03-10 15:51:12 +00:00
Chris Weaver
965f9e98bf Eliminate extremely long log line for large checkpointds (#4236)
* Eliminate extremely long log line for large checkpointds

* address greptile
2025-03-10 15:50:50 +00:00
rkuo-danswer
426883bbf5 Feature/agentic buffered (#4231)
* rename agent test script to prevent pytest autodiscovery

* first cut

* fix log message

* fix up typing

* add a sample test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-10 15:48:42 +00:00
rkuo-danswer
6ca400ced9 Bugfix/delete document tags slow (#4232)
* Add Missing Date and Message-ID Headers to Ensure Email Delivery

* fix issue Performance issue during connector deletion #4191

* fix ruff

* bump to rebuild PR

---------

Co-authored-by: ThomaciousD <2194608+ThomaciousD@users.noreply.github.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-10 03:07:30 +00:00
Weves
104c4b9f4d small modal improvement 2025-03-09 20:54:53 -07:00
pablonyx
8b5e8bd5b9 k (#4240) 2025-03-10 03:06:13 +00:00
Weves
7f7621d7c0 SMall gitbook tweaks 2025-03-09 14:46:44 -07:00
pablonyx
06dcc28d05 Improved login experience (#4178)
* functional initial auth modal

* k

* k

* k

* looking good

* k

* k

* k

* k

* update

* k

* k

* misc bunch

* improvements

* k

* address comments

* k

* nit

* update

* k
2025-03-09 01:06:20 +00:00
pablonyx
18df63dfd9 Fix local background jobs (#4241) 2025-03-08 14:47:56 -08:00
Chris Weaver
0d3c72acbf Add basic memory logging (#4234)
* Add basic memory logging

* Small tweaks

* Switch to monotonic
2025-03-08 03:49:47 +00:00
rkuo-danswer
9217243e3e Bugfix/query history notes (#4204)
* early work in progress

* rename utility script

* move actual data seeding to a shareable function

* add test

* make the test pass with the fix

* fix comment

* slight improvements and notes to query history and seeding

* update test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-07 19:52:30 +00:00
rkuo-danswer
61ccba82a9 light worker needs to discover some indexing tasks (#4209)
* light worker needs to discover some indexing tasks

* fix formatting

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-07 11:52:09 -08:00
Weves
9e8eba23c3 Fix frozen model issue 2025-03-07 09:05:43 -08:00
evan-danswer
0c29743538 use max_tokens to do better rate limit handling (#4224)
* use max_tokens to do better rate limit handling

* fix unti tests

* address greptile comment, thanks greptile
2025-03-06 18:12:05 -08:00
pablonyx
08b2421947 fix 2025-03-06 17:30:31 -08:00
pablonyx
ed518563db minor typing update 2025-03-06 17:02:39 -08:00
pablonyx
a32f7dc936 Fix Connector tests (confluence) (#4221) 2025-03-06 17:00:01 -08:00
rkuo-danswer
798e10c52f revert to always building model server (#4213)
* revert to always building model server

* fix just in case

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-06 23:49:45 +00:00
pablonyx
bf4983e35a Ensure consistent UX (#4222)
* ux consistent

* nit

* Update web/src/app/admin/configuration/llm/interfaces.ts

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-03-06 23:13:32 +00:00
evan-danswer
b7da91e3ae improved basic search latency (#4186)
* improved basic search latency

* address PR comments + minor cleanup
2025-03-06 22:22:59 +00:00
Weves
29382656fc Stop trying a million times for the user validity check 2025-03-06 15:35:49 -08:00
pablonyx
7d6db8d500 Comma separated list for Github repos (#4199) 2025-03-06 14:46:57 -08:00
Chris Weaver
a7a374dc81 Confluence fixes (#4220)
* Confluence fixes

* Small tweak

* Address greptile comments
2025-03-06 20:57:07 +00:00
rkuo-danswer
facc8cc2fa add scope needed for permission sync (#4198)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-06 20:03:38 +00:00
rkuo-danswer
2c0af0a0ca Feature/helm updates (#4201)
* add ingress for api and web

* helm setup docs

* add letsencrypt. close blocks

* use pathType ImplementationSpecific as Prefix is deprecated

* fix backend labels. configure nginx routes. update annotations

* fix linting

---------

Co-authored-by: Sajjad Anwar <sajjadkm@gmail.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-06 19:48:20 +00:00
pablonyx
bfbc1cd954 k (#4172) 2025-03-06 18:55:12 +00:00
pablonyx
626da583aa Fix gated tenants (#4177)
* fix

* mypy .
2025-03-06 18:07:15 +00:00
pablonyx
92faca139d Fix extra tenant mystery (#4197)
* fix extra tenant mystery

* nit
2025-03-06 18:06:49 +00:00
pablonyx
cec05c5ee9 Revert "k"
This reverts commit 687122911d.
2025-03-06 09:38:31 -08:00
Richard Kuo (Danswer)
eaf054ef06 oauth router went missing? 2025-03-05 15:50:23 -08:00
pablonyx
a7a1a24658 minor nit 2025-03-05 15:35:02 -08:00
pablonyx
687122911d k 2025-03-05 15:27:14 -08:00
pablonyx
40953bd4fe Workspace configs (#4202) 2025-03-05 12:28:44 -08:00
rkuo-danswer
a7acc07e79 fix usage report pagination (#4183)
* early work in progress

* rename utility script

* move actual data seeding to a shareable function

* add test

* make the test pass with the fix

* fix comment

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-05 19:13:51 +00:00
pablonyx
b6e9e65bb8 * Replaces Amazon and Anthropic Icons with version better suitable fo… (#4190)
* * Replaces Amazon and Anthropic Icons with version better suitable for both Dark and  Light modes;
* Adds icon for DeepSeek;
* Simplify logic on icon selection;
* Adds entries for Phi-4, Claude 3.7, Ministral and Gemini 2.0 models

* nit

* k

* k

---------

Co-authored-by: Emerson Gomes <emerson.gomes@thalesgroup.com>
2025-03-05 17:57:39 +00:00
pablonyx
20f2b9b2bb Add image support for search (#4090)
* add support for image search

* quick fix up

* k

* k

* k

* k

* nit

* quick fix for connector tests
2025-03-05 17:44:18 +00:00
Chris Weaver
f731beca1f Add ONYX_QUERY_HISTORY_TYPE to the dev compose files (#4196) 2025-03-05 17:34:55 +00:00
Weves
fe246aecbb Attempt to address tool happy claude 2025-03-05 09:47:27 -08:00
pablonyx
50ad066712 Better filtering (#4185)
* k

* k

* k

* k

* k
2025-03-05 04:35:50 +00:00
rkuo-danswer
870b59a1cc Bugfix/vertex crash (#4181)
* Update text embedding model to version 005 and enhance embedding retrieval process

* re

* Fix formatting issues

* Add support for Bedrock reranking provider and AWS credentials handling

* fix: improve AWS key format validation and error messages

* Fix vertex embedding model crash

* feat: add environment template for local development setup

* Add display name for Claude 3.7 Sonnet model

* Add display names for Gemini 2.0 models and update Claude 3.7 Sonnet entry

* Fix ruff errors by ensuring lines are within 130 characters

* revert to currently default onyx browser settings

* add / fix boto requirements

---------

Co-authored-by: ferdinand loesch <f.loesch@sportradar.com>
Co-authored-by: Ferdinand Loesch <ferdinandloesch@me.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-05 01:59:46 +00:00
pablonyx
5c896cb0f7 add minor fixes (#4170) 2025-03-04 20:29:28 +00:00
pablonyx
184b30643d Nit: logging adjustments (#4182) 2025-03-04 11:39:53 -08:00
pablonyx
ae585fd84c Delete all chats (#4171)
* nit

* k
2025-03-04 10:00:08 -08:00
rkuo-danswer
61e8f371b9 fix blowing up the entire task on exception and trying to reuse an in… (#4179)
* fix blowing up the entire task on exception and trying to reuse an invalid db session

* list comprehension

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-04 00:57:27 +00:00
rkuo-danswer
33cc4be492 Bugfix/GitHub validation (#4173)
* fixing unexpected errors disabling connectors

* rename UnexpectedError to UnexpectedValidationError

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-04 00:09:49 +00:00
joachim-danswer
117c8c0d78 Enable ephemeral message responses by Onyx Slack Bots (#4142)
A new setting 'is_ephemeral' has been added to the Slack channel configurations. 

Key features/effects:

  - if is_ephemeral is set for standard channel (and a Search Assistant is chosen):
     - the answer is only shown to user as an ephemeral message
     - the user has access to his private documents for a search (as the answer is only shown to them) 
     - the user has the ability to share the answer with the channel or keep private
     - a recipient list cannot be defined if the channel is set up as ephemeral
 
  - if is_ephemeral is set and DM with bot:
    - the user has access to private docs in searches
    - the message is not sent as ephemeral, as it is a 1:1 discussion with bot

 - if is_ephemeral is not set but recipient list is set:
    - the user search does *not* have access to their private documents as the information goes to the recipient list team members, and they may have different access rights

 - Overall:
     - Unless the channel is set to is_ephemeral or it is a direct conversation with the Bot, only public docs are accessible  
     - The ACL is never bypassed, also not in cases where the admin explicitly attached a document set to the bot config.
2025-03-03 15:02:21 -08:00
rkuo-danswer
9bb8cdfff1 fix web connector tests to handle new deduping (#4175)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-03 20:54:20 +00:00
Weves
a52d0d29be Small tweak to NumberInput 2025-03-03 11:20:53 -08:00
Chris Weaver
f25e1e80f6 Add option to not re-index (#4157)
* Add option to not re-index

* Add quantizaton / dimensionality override support

* Fix build / ut
2025-03-03 10:54:11 -08:00
Yuhong Sun
39fd6919ad Fix web scrolling 2025-03-03 09:00:05 -08:00
Yuhong Sun
7f0653d173 Handling of #! sites (#4169) 2025-03-03 08:18:44 -08:00
SubashMohan
e9905a398b Enhance iframe content extraction and add thresholds for JavaScript disabled scenarios (#4167) 2025-03-02 19:29:10 -08:00
Brad Slavin
3ed44e8bae Update Unstructured documentation URL to new location (#4168) 2025-03-02 19:16:38 -08:00
pablonyx
64158a5bdf silence_logs (#4165) 2025-03-02 19:00:59 +00:00
pablonyx
afb2393596 fix dark mode index attempt failure (#4163) 2025-03-02 01:23:16 +00:00
pablonyx
d473c4e876 Fix curator default persona editing (#4158)
* k

* k
2025-03-02 00:40:14 +00:00
pablonyx
692058092f fix typo 2025-03-01 13:00:07 -08:00
pablonyx
e88325aad6 bump version (#4164) 2025-03-01 01:58:45 +00:00
pablonyx
7490250e91 Fix user group edge case (#4159)
* fix user group

* k
2025-02-28 23:55:21 +00:00
pablonyx
e5369fcef8 Update warning copy (#4160)
* k

* k

* quick nit
2025-02-28 23:46:21 +00:00
Yuhong Sun
b0f00953bc Add CODEOWNERS 2025-02-28 13:57:33 -08:00
rkuo-danswer
f6a75c86c6 Bugfix/emit background error (#4156)
* print the test name when it runs

* type hints

* can't reuse session after an exception

* better logging

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-28 18:35:24 +00:00
pablonyx
ed9989282f nit- update casing enforcement on frontend 2025-02-28 10:09:06 -08:00
pablonyx
e80a0f2716 Improved google connector flow (#4155)
* fix handling

* k

* k

* fix function

* k

* k
2025-02-28 05:13:39 +00:00
rkuo-danswer
909403a648 Feature/confluence oauth (#3477)
* first cut at slack oauth flow

* fix usage of hooks

* fix button spacing

* add additional error logging

* no dev redirect

* early cut at google drive oauth

* second pass

* switch to production uri's

* try handling oauth_interactive differently

* pass through client id and secret if uploaded

* fix call

* fix test

* temporarily disable check for testing

* Revert "temporarily disable check for testing"

This reverts commit 4b5a022a5f.

* support visibility in test

* missed file

* first cut at confluence oauth

* work in progress

* work in progress

* work in progress

* work in progress

* work in progress

* first cut at distributed locking

* WIP to make test work

* add some dev mode affordances and gate usage of redis behind dynamic credentials

* mypy and credentials provider fixes

* WIP

* fix created at

* fix setting initialValue on everything

* remove debugging, fix ??? some TextFormField issues

* npm fixes

* comment cleanup

* fix comments

* pin the size of the card section

* more review fixes

* more fixes

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-28 03:48:51 +00:00
pablonyx
cd84b65011 quick fix (#4154) 2025-02-28 02:03:34 +00:00
pablonyx
413f21cec0 Filter assistants fix (#4153)
* k

* quick nit

* minor assistant filtering fix
2025-02-28 02:03:21 +00:00
pablonyx
eb369384a7 Log server side auth error + slackbot pagination fix (#4149) 2025-02-27 18:05:28 -08:00
pablonyx
0a24dbc52c k# Please enter the commit message for your changes. Lines starting (#4144) 2025-02-27 23:34:20 +00:00
pablonyx
a7ba0da8cc Lowercase multi tenant email mapping (#4141) 2025-02-27 15:33:40 -08:00
Richard Kuo (Danswer)
aaced6d551 scan images 2025-02-27 15:25:29 -08:00
Richard Kuo (Danswer)
4c230f92ea trivy test 2025-02-27 15:05:03 -08:00
Richard Kuo (Danswer)
07d75b04d1 enable trivy scan 2025-02-27 14:22:44 -08:00
evan-danswer
a8d10750c1 fix propagation of is_agentic (#4150) 2025-02-27 11:56:51 -08:00
pablonyx
85e3ed57f1 Order chat sessions by time updated, not created (#4143)
* order chat sessions by time updated, not created

* quick update

* k
2025-02-27 17:35:42 +00:00
pablonyx
e10cc8ccdb Multi tenant user google auth fix (#4145) 2025-02-27 10:35:38 -08:00
pablonyx
7018bc974b Better looking errors (#4050)
* add error handling

* fix

* k
2025-02-27 04:58:25 +00:00
pablonyx
9c9075d71d Minor improvements to provisioning (#4109)
* quick fix

* k

* nit
2025-02-27 04:57:31 +00:00
pablonyx
338e084062 Improved tenant handling for slack bot (#4099) 2025-02-27 04:06:26 +00:00
pablonyx
2f64031f5c Improved tenant handling for slack bot1 (#4104) 2025-02-27 03:40:50 +00:00
pablonyx
abb74f2eaa Improved chat search (#4137)
* functional + fast

* k

* adapt

* k

* nit

* k

* k

* fix typing

* k
2025-02-27 02:27:45 +00:00
pablonyx
a3e3d83b7e Improve viewable assistant logic (#4125)
* k

* quick fix

* k
2025-02-27 01:24:39 +00:00
pablonyx
4dc88ca037 debug playwright failure case 2025-02-26 17:32:26 -08:00
rkuo-danswer
11e7e1c4d6 log processed tenant count (#4139)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-26 17:26:48 -08:00
pablonyx
f2d74ce540 Address Auth Edge Case (#4138) 2025-02-26 17:24:23 -08:00
rkuo-danswer
25389c5120 first cut at anonymizing query history (#4123)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-02-26 21:32:01 +00:00
pablonyx
ad0721ecd8 update (#4086) 2025-02-26 18:12:07 +00:00
pablonyx
426a8842ae Markdown copying / html formatting (#4120)
* k

* delete unnecessary util
2025-02-26 04:56:38 +00:00
pablonyx
a98dcbc7de Update tenant logic (#4122)
* k

* k

* k

* quick nit

* nit
2025-02-26 03:53:46 +00:00
pablonyx
6f389dc100 Improve lengthy chats (#4126)
* remove scroll

* working well

* nit

* k

* nit
2025-02-26 03:22:21 +00:00
pablonyx
d56177958f fix email headers (#4100) 2025-02-26 03:12:30 +00:00
Kaveen Jayamanna
0e42ae9024 Content of .xlsl are not properly read during indexing. (#4035) 2025-02-25 21:10:47 -08:00
Weves
ce2b4de245 temp remove 2025-02-25 20:46:55 -08:00
Chris Weaver
a515aa78d2 Fix confluence test (#4130) 2025-02-26 03:03:54 +00:00
Weves
23073d91b9 reduce number of chars to index for search 2025-02-25 19:27:50 -08:00
Chris Weaver
f767b1f476 Fix confluence permission syncing at scale (#4129)
* Fix confluence permission syncing at scale

* Remove line

* Better log message

* Adjust log
2025-02-25 19:22:52 -08:00
pablonyx
9ffc8cb2c4 k 2025-02-25 18:15:49 -08:00
pablonyx
98bfb58147 Handle bad slack configurations– multi tenant (#4118)
* k

* quick nit

* k

* k
2025-02-25 22:22:54 +00:00
evan-danswer
6ce810e957 faster indexing status at scale plus minor cleanups (#4081)
* faster indexing status at scale plus minor cleanups

* mypy

* address chris comments

* remove extra prints
2025-02-25 21:22:26 +00:00
pablonyx
07b0b57b31 (nit) bump timeout 2025-02-25 14:10:30 -08:00
pablonyx
118cdd7701 Chat search (#4113)
* add chat search

* don't add the bible

* base functional

* k

* k

* functioning

* functioning well

* functioning well

* k

* delete bible

* quick cleanup

* quick cleanup

* k

* fixed frontend hooks

* delete bible

* nit

* nit

* nit

* fix build

* k

* improved debouncing

* address comments

* fix alembic

* k
2025-02-25 20:49:46 +00:00
rkuo-danswer
ac83b4c365 validate connector deletion (#4108)
* validate connector deletion

* fixes

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-25 20:35:21 +00:00
pablonyx
fa408ff447 add 3.7 (#4116) 2025-02-25 12:41:40 -08:00
rkuo-danswer
4aa8eb8b75 fix scrolling test (#4117)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-02-25 10:23:04 -08:00
rkuo-danswer
60bd9271f7 Bugfix/model tests (#4092)
* trying out a fix

* add ability to manually run model tests

* add log dump

* check status code, not text?

* just the model server

* add port mapping to host

* pass through more api keys

* add azure tests

* fix litellm env vars

* fix env vars in github workflow

* temp disable litellm test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-25 04:53:51 +00:00
Weves
5d58a5e3ea Add ability to index all of Github 2025-02-24 18:56:36 -08:00
Chris Weaver
a99dd05533 Add option to index all Jira projects (#4106)
* Add option to index all Jira projects

* Fix test

* Fix web build

* Address comment
2025-02-25 02:07:00 +00:00
pablonyx
0dce67094e Prettier formatting for bedrock (#4111)
* k

* k
2025-02-25 02:05:29 +00:00
pablonyx
ffd14435a4 Text overflow logic (#4051)
* proper components

* k

* k

* k
2025-02-25 01:05:22 +00:00
rkuo-danswer
c9a3b45ad4 more aggressive handling of tasks blocking deletion (#4093)
* more aggressive handling of tasks blocking deletion

* comment updated

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-24 22:41:13 +00:00
pablonyx
7d40676398 Heavy task improvements, logging, and validation (#4058) 2025-02-24 13:48:53 -08:00
rkuo-danswer
b9e79e5db3 tighten up logs (#4076)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-24 19:23:00 +00:00
rkuo-danswer
558bbe16e4 Bugfix/termination cleanup (#4077)
* move activity timeout cleanup to the function exit

* fix excessive logging

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-24 19:21:55 +00:00
evan-danswer
076619ce2c make Settings model match db (#4087) 2025-02-24 19:04:36 +00:00
pablonyx
1263e21eb5 k (#4102) 2025-02-24 17:44:18 +00:00
pablonyx
f0c13b6558 fix starter message editing (#4101) 2025-02-24 01:01:01 +00:00
evan-danswer
a7125662f1 Fix gpt o-series code block formatting (#4089)
* prompt addition for gpt o-series to encourage markdown formatting of code blocks

* fix to match https://simonwillison.net/tags/markdown/

* chris comment

* chris comment
2025-02-24 00:59:48 +00:00
evan-danswer
4a4e4a6c50 thread utils respect contextvars (#4074)
* thread utils respect contextvars now

* address pablo comments

* removed tenant id from places it was already being passed

* fix rate limit check and pablo comment
2025-02-24 00:43:21 +00:00
pablonyx
1f2af373e1 improve scroll (#4096) 2025-02-23 19:20:07 +00:00
Weves
bdaa293ae4 Fix nginx for prod compose file 2025-02-21 16:57:54 -08:00
pablonyx
5a131f4547 Fix integration tests (#4059) 2025-02-21 15:56:11 -08:00
rkuo-danswer
ffb7d5b85b enable manual testing for model server (#4003)
* trying out a fix

* add ability to manually run model tests

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-21 14:00:32 -08:00
rkuo-danswer
fe8a5d671a don't spam the logs with texts on auth errors (#4085)
* don't spam the logs with texts on auth errors

* refactor the logging a bit

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-21 13:40:07 -08:00
Yuhong Sun
6de53ebf60 README Touchup (#4088) 2025-02-21 13:31:07 -08:00
rkuo-danswer
61d536c782 tool fixes (#4075) 2025-02-21 12:30:33 -08:00
Chris Weaver
e1ff9086a4 Fix LLM selection (#4078) 2025-02-21 11:32:57 -08:00
evan-danswer
ba21bacbbf coerce useLanggraph to boolean (#4084)
* coerce useLanggraph to boolean
2025-02-21 09:43:46 -08:00
pablonyx
158bccc3fc Default on for non-ee (#4083) 2025-02-21 09:11:45 -08:00
Weves
599b7705c2 Fix gitbook connector issues 2025-02-20 15:29:11 -08:00
rkuo-danswer
4958a5355d try more efficient query (#4047) 2025-02-20 12:58:50 -08:00
Chris Weaver
c4b8519381 Add support for sending email invites for single tenant users (#4065) 2025-02-19 21:05:23 -08:00
448 changed files with 18412 additions and 6075 deletions

1
.github/CODEOWNERS vendored Normal file
View File

@@ -0,0 +1 @@
* @onyx-dot-app/onyx-core-team

View File

@@ -12,29 +12,40 @@ env:
BUILDKIT_PROGRESS: plain
jobs:
# 1) Preliminary job to check if the changed files are relevant
# Bypassing this for now as the idea of not building is glitching
# releases and builds that depends on everything being tagged in docker
# 1) Preliminary job to check if the changed files are relevant
# check_model_server_changes:
# runs-on: ubuntu-latest
# outputs:
# changed: ${{ steps.check.outputs.changed }}
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
#
# - name: Check if relevant files changed
# id: check
# run: |
# # Default to "false"
# echo "changed=false" >> $GITHUB_OUTPUT
#
# # Compare the previous commit (github.event.before) to the current one (github.sha)
# # If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# # set changed=true
# if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
# | grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
# echo "changed=true" >> $GITHUB_OUTPUT
# fi
check_model_server_changes:
runs-on: ubuntu-latest
outputs:
changed: ${{ steps.check.outputs.changed }}
changed: "true"
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check if relevant files changed
id: check
run: |
# Default to "false"
echo "changed=false" >> $GITHUB_OUTPUT
# Compare the previous commit (github.event.before) to the current one (github.sha)
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# set changed=true
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
echo "changed=true" >> $GITHUB_OUTPUT
fi
- name: Bypass check and set output
run: echo "changed=true" >> $GITHUB_OUTPUT
build-amd64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'

View File

@@ -53,24 +53,90 @@ jobs:
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: ${{ always() }}
if: always()
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scanners: license
format: table
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
# be careful enabling the sarif and upload as it may spam the security tab
# with a huge amount of items. Work out the issues before enabling upload.
# - name: Run Trivy vulnerability scanner in repo mode
# if: always()
# uses: aquasecurity/trivy-action@0.29.0
# with:
# sarif_file: trivy-results.sarif
# scan-type: fs
# scan-ref: .
# scanners: license
# format: table
# severity: HIGH,CRITICAL
# # format: sarif
# # output: trivy-results.sarif
#
# # - name: Upload Trivy scan results to GitHub Security tab
# # uses: github/codeql-action/upload-sarif@v3
# # with:
# # sarif_file: trivy-results.sarif
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# Backend
- name: Pull backend docker image
run: docker pull onyxdotapp/onyx-backend:latest
- name: Run Trivy vulnerability scanner on backend
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-backend:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
# Web server
- name: Pull web server docker image
run: docker pull onyxdotapp/onyx-web-server:latest
- name: Run Trivy vulnerability scanner on web server
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-web-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# Model server
- name: Pull model server docker image
run: docker pull onyxdotapp/onyx-model-server:latest
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-model-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0

View File

@@ -145,7 +145,7 @@ jobs:
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
@@ -157,6 +157,7 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
@@ -199,7 +200,7 @@ jobs:
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |

View File

@@ -1,6 +1,7 @@
name: Connector Tests
on:
merge_group:
pull_request:
branches: [main]
schedule:
@@ -47,11 +48,13 @@ env:
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
# Notion
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
@@ -76,7 +79,7 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors

View File

@@ -1,18 +1,29 @@
name: Connector Tests
name: Model Server Tests
on:
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
workflow_dispatch:
inputs:
branch:
description: 'Branch to run the workflow on'
required: false
default: 'main'
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
# OpenAI
# API keys for testing
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
AZURE_API_URL: ${{ secrets.AZURE_API_URL }}
jobs:
model-check:
@@ -26,6 +37,23 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Model Server Docker image
run: |
docker pull onyxdotapp/onyx-model-server:latest
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
- name: Set up Python
uses: actions/setup-python@v5
with:
@@ -41,6 +69,49 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
@@ -56,3 +127,23 @@ jobs:
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v

View File

@@ -26,12 +26,12 @@
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
There are over 40 supported connectors such as Google Drive, Slack, Confluence, Salesforce, etc. which keep knowledge and permissions up to date.
Create custom AI agents with unique prompts, knowledge, and actions the agents can take.
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
<h3>Feature Showcase</h3>
<h3>Feature Highlights</h3>
**Deep research over your team's knowledge:**
@@ -63,22 +63,21 @@ We also have built-in support for high-availability/scalable deployment on Kuber
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
## 🚧 Roadmap
- Extensions to the Chrome Plugin
- Latest methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- Personalized Search
- Organizational understanding and ability to locate and suggest experts from your team.
- Code Search
- SQL and Structured Query Language
## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models only through Onyx + learn from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
## 🔌 Connectors
Keep knowledge and access up to sync across 40+ connectors:
@@ -115,3 +114,4 @@ To try the Onyx Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

View File

@@ -0,0 +1,125 @@
"""Update GitHub connector repo_name to repositories
Revision ID: 3934b1bc7b62
Revises: b7c2b63c4a03
Create Date: 2025-03-05 10:50:30.516962
"""
from alembic import op
import sqlalchemy as sa
import json
import logging
# revision identifiers, used by Alembic.
revision = "3934b1bc7b62"
down_revision = "b7c2b63c4a03"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
# First get all GitHub connectors
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
# Update each connector's config
updated_count = 0
for connector_id, config in github_connectors:
try:
if not config:
logger.warning(f"Connector {connector_id} has no config, skipping")
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repo_name" not in config:
continue
# Create new config with repositories instead of repo_name
new_config = dict(config)
repo_name_value = new_config.pop("repo_name")
new_config["repositories"] = repo_name_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
)
updated_count += 1
except Exception as e:
logger.error(f"Error updating connector {connector_id}: {str(e)}")
def downgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
logger.debug(
"Starting rollback of GitHub connectors from repositories to repo_name"
)
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")
# Revert each GitHub connector to use repo_name instead of repositories
reverted_count = 0
for connector_id, config in github_connectors:
try:
if not config:
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repositories" not in config:
continue
# Create new config with repo_name instead of repositories
new_config = dict(config)
repositories_value = new_config.pop("repositories")
new_config["repo_name"] = repositories_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"new_config": json.dumps(new_config), "connector_id": connector_id},
)
reverted_count += 1
except Exception as e:
logger.error(f"Error reverting connector {connector_id}: {str(e)}")

View File

@@ -0,0 +1,84 @@
"""improved index
Revision ID: 3bd4c84fe72f
Revises: 8f43500ee275
Create Date: 2025-02-26 13:07:56.217791
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
down_revision = "8f43500ee275"
branch_labels = None
depends_on = None
# NOTE:
# This migration addresses issues with the previous migration (8f43500ee275) which caused
# an outage by creating an index without using CONCURRENTLY. This migration:
#
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
# (see: https://github.com/sqlalchemy/alembic/issues/277)
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
def upgrade() -> None:
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
# Commit the current transaction before creating concurrent indexes
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message
USING GIN (message_tsv)
"""
)
# Also add a stored tsvector column for chat_session.description
op.execute(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
"""
)
# Commit again before creating the second concurrent index
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session
USING GIN (description_tsv)
"""
)
def downgrade() -> None:
# Drop the indexes first (use CONCURRENTLY for dropping too)
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
# Then drop the columns
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -0,0 +1,32 @@
"""add index
Revision ID: 8f43500ee275
Revises: da42808081e3
Create Date: 2025-02-24 17:35:33.072714
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8f43500ee275"
down_revision = "da42808081e3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create a basic index on the lowercase message column for direct text matching
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
# op.execute(
# """
# CREATE INDEX idx_chat_message_message_lower
# ON chat_message (LOWER(substring(message, 1, 1500)))
# """
# )
pass
def downgrade() -> None:
# Drop the index
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -0,0 +1,55 @@
"""add background_reindex_enabled field
Revision ID: b7c2b63c4a03
Revises: f11b408e39d3
Create Date: 2024-03-26 12:34:56.789012
"""
from alembic import op
import sqlalchemy as sa
from onyx.db.enums import EmbeddingPrecision
# revision identifiers, used by Alembic.
revision = "b7c2b63c4a03"
down_revision = "f11b408e39d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add background_reindex_enabled column with default value of True
op.add_column(
"search_settings",
sa.Column(
"background_reindex_enabled",
sa.Boolean(),
nullable=False,
server_default="true",
),
)
# Add embedding_precision column with default value of FLOAT
op.add_column(
"search_settings",
sa.Column(
"embedding_precision",
sa.Enum(EmbeddingPrecision, native_enum=False),
nullable=False,
server_default=EmbeddingPrecision.FLOAT.name,
),
)
# Add reduced_dimension column with default value of None
op.add_column(
"search_settings",
sa.Column("reduced_dimension", sa.Integer(), nullable=True),
)
def downgrade() -> None:
# Remove the background_reindex_enabled column
op.drop_column("search_settings", "background_reindex_enabled")
op.drop_column("search_settings", "embedding_precision")
op.drop_column("search_settings", "reduced_dimension")

View File

@@ -0,0 +1,120 @@
"""migrate jira connectors to new format
Revision ID: da42808081e3
Revises: f13db29f3101
Create Date: 2025-02-24 11:24:54.396040
"""
from alembic import op
import sqlalchemy as sa
import json
from onyx.configs.constants import DocumentSource
from onyx.connectors.onyx_jira.utils import extract_jira_project
# revision identifiers, used by Alembic.
revision = "da42808081e3"
down_revision = "f13db29f3101"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config
for connector_id, old_config in jira_connectors:
if not old_config:
continue
# Extract project key from URL if it exists
new_config: dict[str, str | None] = {}
if project_url := old_config.get("jira_project_url"):
# Parse the URL to get base and project
try:
jira_base, project_key = extract_jira_project(project_url)
new_config = {"jira_base_url": jira_base, "project_key": project_key}
except ValueError:
# If URL parsing fails, just use the URL as the base
new_config = {
"jira_base_url": project_url.split("/projects/")[0],
"project_key": None,
}
else:
# For connectors without a project URL, we need admin intervention
# Mark these for review
print(
f"WARNING: Jira connector {connector_id} has no project URL configured"
)
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :id
"""
),
{"id": connector_id, "new_config": json.dumps(new_config)},
)
def downgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config back to the old format
for connector_id, new_config in jira_connectors:
if not new_config:
continue
old_config = {}
base_url = new_config.get("jira_base_url")
project_key = new_config.get("project_key")
if base_url and project_key:
old_config = {"jira_project_url": f"{base_url}/projects/{project_key}"}
elif base_url:
old_config = {"jira_project_url": base_url}
else:
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :old_config
WHERE id = :id
"""
),
{"id": connector_id, "old_config": old_config},
)

View File

@@ -0,0 +1,36 @@
"""force lowercase all users
Revision ID: f11b408e39d3
Revises: 3bd4c84fe72f
Create Date: 2025-02-26 17:04:55.683500
"""
# revision identifiers, used by Alembic.
revision = "f11b408e39d3"
down_revision = "3bd4c84fe72f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Convert all existing user emails to lowercase
from alembic import op
op.execute(
"""
UPDATE "user"
SET email = LOWER(email)
"""
)
# 2) Add a check constraint to ensure emails are always lowercase
op.create_check_constraint("ensure_lowercase_email", "user", "email = LOWER(email)")
def downgrade() -> None:
# Drop the check constraint
from alembic import op
op.drop_constraint("ensure_lowercase_email", "user", type_="check")

View File

@@ -0,0 +1,27 @@
"""Add composite index for last_modified and last_synced to document
Revision ID: f13db29f3101
Revises: b388730a2899
Create Date: 2025-02-18 22:48:11.511389
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f13db29f3101"
down_revision = "acaab4ef4507"
branch_labels: str | None = None
depends_on: str | None = None
def upgrade() -> None:
op.create_index(
"ix_document_sync_status",
"document",
["last_modified", "last_synced"],
unique=False,
)
def downgrade() -> None:
op.drop_index("ix_document_sync_status", table_name="document")

View File

@@ -0,0 +1,42 @@
"""lowercase multi-tenant user auth
Revision ID: 34e3630c7f32
Revises: a4f6ee863c47
Create Date: 2025-02-26 15:03:01.211894
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "34e3630c7f32"
down_revision = "a4f6ee863c47"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Convert all existing rows to lowercase
op.execute(
"""
UPDATE user_tenant_mapping
SET email = LOWER(email)
"""
)
# 2) Add a check constraint so that emails cannot be written in uppercase
op.create_check_constraint(
"ensure_lowercase_email",
"user_tenant_mapping",
"email = LOWER(email)",
schema="public",
)
def downgrade() -> None:
# Drop the check constraint
op.drop_constraint(
"ensure_lowercase_email",
"user_tenant_mapping",
schema="public",
type_="check",
)

View File

@@ -0,0 +1,51 @@
"""new column user tenant mapping
Revision ID: ac842f85f932
Revises: 34e3630c7f32
Create Date: 2025-03-03 13:30:14.802874
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "ac842f85f932"
down_revision = "34e3630c7f32"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add active column with default value of True
op.add_column(
"user_tenant_mapping",
sa.Column(
"active",
sa.Boolean(),
nullable=False,
server_default="true",
),
schema="public",
)
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
# Create a unique index for active=true records
# This ensures a user can only be active in one tenant at a time
op.execute(
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
)
def downgrade() -> None:
# Drop the unique index for active=true records
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
op.create_unique_constraint(
"uq_email", "user_tenant_mapping", ["email"], schema="public"
)
# Remove the active column
op.drop_column("user_tenant_mapping", "active", schema="public")

View File

@@ -4,12 +4,11 @@ from ee.onyx.server.reporting.usage_export_generation import create_new_usage_re
from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.db.chat import delete_chat_sessions_older_than
from onyx.db.engine import get_session_with_tenant
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -18,11 +17,28 @@ logger = setup_logger()
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(
retention_limit_days: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
with get_session_with_current_tenant() as db_session:
old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
try:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
#####
@@ -35,24 +51,19 @@ def perform_ttl_management_task(
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:
def check_ttl_management_task(*, tenant_id: str) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
token = None
if MULTI_TENANT and tenant_id is not None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(
retention_limit_days=retention_limit_days, tenant_id=tenant_id
),
)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@celery_app.task(
@@ -60,9 +71,9 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,

View File

@@ -18,7 +18,7 @@ logger = setup_logger()
def monitor_usergroup_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
"""This function is likely to move in the worker refactor happening next."""
fence_key = key_bytes.decode("utf-8")

View File

@@ -59,10 +59,14 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
)
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
)
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""

View File

@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import UserGroup__ConnectorCredentialPair
@@ -35,10 +36,11 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit(
def get_cc_pairs_by_source(
db_session: Session,
source_type: DocumentSource,
only_sync: bool,
access_type: AccessType | None = None,
status: ConnectorCredentialPairStatus | None = None,
) -> list[ConnectorCredentialPair]:
"""
Get all cc_pairs for a given source type (and optionally only sync)
Get all cc_pairs for a given source type with optional filtering by access_type and status
result is sorted by cc_pair id
"""
query = (
@@ -48,8 +50,11 @@ def get_cc_pairs_by_source(
.order_by(ConnectorCredentialPair.id)
)
if only_sync:
query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC)
if access_type is not None:
query = query.filter(ConnectorCredentialPair.access_type == access_type)
if status is not None:
query = query.filter(ConnectorCredentialPair.status == status)
cc_pairs = query.all()
return cc_pairs

View File

@@ -134,7 +134,9 @@ def fetch_chat_sessions_eagerly_by_time(
limit: int | None = 500,
initial_time: datetime | None = None,
) -> list[ChatSession]:
time_order: UnaryExpression = desc(ChatSession.time_created)
"""Sorted by oldest to newest, then by message id"""
asc_time_order: UnaryExpression = asc(ChatSession.time_created)
message_order: UnaryExpression = asc(ChatMessage.id)
filters: list[ColumnElement | BinaryExpression] = [
@@ -147,8 +149,7 @@ def fetch_chat_sessions_eagerly_by_time(
subquery = (
db_session.query(ChatSession.id, ChatSession.time_created)
.filter(*filters)
.order_by(ChatSession.id, time_order)
.distinct(ChatSession.id)
.order_by(asc_time_order)
.limit(limit)
.subquery()
)
@@ -164,7 +165,7 @@ def fetch_chat_sessions_eagerly_by_time(
ChatMessage.chat_message_feedbacks
),
)
.order_by(time_order, message_order)
.order_by(asc_time_order, message_order)
)
chat_sessions = query.all()

View File

@@ -16,13 +16,20 @@ from onyx.db.models import UsageReport
from onyx.file_store.file_store import get_default_file_store
# Gets skeletons of all message
# Gets skeletons of all messages in the given range
def get_empty_chat_messages_entries__paginated(
db_session: Session,
period: tuple[datetime, datetime],
limit: int | None = 500,
initial_time: datetime | None = None,
) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]:
"""Returns a tuple where:
first element is the most recent timestamp out of the sessions iterated
- this timestamp can be used to paginate forward in time
second element is a list of messages belonging to all the sessions iterated
Only messages of type USER are returned
"""
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=period[0],
end=period[1],
@@ -52,18 +59,17 @@ def get_empty_chat_messages_entries__paginated(
if len(chat_sessions) == 0:
return None, []
return chat_sessions[0].time_created, message_skeletons
return chat_sessions[-1].time_created, message_skeletons
def get_all_empty_chat_message_entries(
db_session: Session,
period: tuple[datetime, datetime],
) -> Generator[list[ChatMessageSkeleton], None, None]:
"""period is the range of time over which to fetch messages."""
initial_time: Optional[datetime] = period[0]
ind = 0
while True:
ind += 1
# iterate from oldest to newest
time_created, message_skeletons = get_empty_chat_messages_entries__paginated(
db_session,
period,

View File

@@ -424,7 +424,7 @@ def _validate_curator_status__no_commit(
)
# if the user is a curator in any of their groups, set their role to CURATOR
# otherwise, set their role to BASIC
# otherwise, set their role to BASIC only if they were previously a CURATOR
if curator_relationships:
user.role = UserRole.CURATOR
elif user.role == UserRole.CURATOR:
@@ -631,7 +631,16 @@ def update_user_group(
removed_users = db_session.scalars(
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
).unique()
_validate_curator_status__no_commit(db_session, list(removed_users))
# Filter out admin and global curator users before validating curator status
users_to_validate = [
user
for user in removed_users
if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR]
]
if users_to_validate:
_validate_curator_status__no_commit(db_session, users_to_validate)
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()

View File

@@ -9,12 +9,16 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -342,7 +346,8 @@ def _fetch_all_page_restrictions(
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -354,7 +359,11 @@ def confluence_doc_sync(
confluence_connector = ConfluenceConnector(
**cc_pair.connector.connector_specific_config
)
confluence_connector.load_credentials(cc_pair.credential.credential_json)
provider = OnyxDBCredentialsProvider(
get_current_tenant_id(), "confluence", cc_pair.credential_id
)
confluence_connector.set_credentials_provider(provider)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)

View File

@@ -1,9 +1,11 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.background.error_logging import emit_background_error
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -61,13 +63,27 @@ def _build_group_member_email_map(
def confluence_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
confluence_client = build_confluence_client(
credentials=cc_pair.credential.credential_json,
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
)
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
url = wiki_base.rstrip("/")
probe_kwargs = {
"max_backoff_retries": 6,
"max_backoff_seconds": 10,
}
final_kwargs = {
"max_backoff_retries": 10,
"max_backoff_seconds": 60,
}
confluence_client = OnyxConfluence(is_cloud, url, provider)
confluence_client._probe_connection(**probe_kwargs)
confluence_client._initialize_connection(**final_kwargs)
group_member_email_map = _build_group_member_email_map(
confluence_client=confluence_client,

View File

@@ -32,7 +32,8 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres

View File

@@ -62,12 +62,14 @@ def _fetch_permissions_for_permission_ids(
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# We continue on 404 or 403 because the document may not exist or the user may not have access to it
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
fileId=doc_id,
fields="permissions(id, emailAddress, type, domain)",
supportsAllDrives=True,
continue_on_404_or_403=True,
)
permissions_for_doc_id = []
@@ -104,7 +106,13 @@ def _get_permissions_from_slim_doc(
user_emails: set[str] = set()
group_emails: set[str] = set()
public = False
skipped_permissions = 0
for permission in permissions_list:
if not permission:
skipped_permissions += 1
continue
permission_type = permission["type"]
if permission_type == "user":
user_emails.add(permission["emailAddress"])
@@ -121,6 +129,11 @@ def _get_permissions_from_slim_doc(
elif permission_type == "anyone":
public = True
if skipped_permissions > 0:
logger.warning(
f"Skipped {skipped_permissions} permissions of {len(permissions_list)} for document {slim_doc.id}"
)
drive_id = permission_info.get("drive_id")
group_ids = group_emails | ({drive_id} if drive_id is not None else set())
@@ -132,7 +145,8 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres

View File

@@ -119,6 +119,7 @@ def _build_onyx_groups(
def gdrive_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
# Initialize connector and build credential/service objects

View File

@@ -123,7 +123,8 @@ def _fetch_channel_permissions(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres

View File

@@ -28,6 +28,7 @@ DocSyncFuncType = Callable[
GroupSyncFuncType = Callable[
[
str,
ConnectorCredentialPair,
],
list[ExternalUserGroup],

View File

@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.oauth import router as oauth_router
from ee.onyx.server.oauth.api import router as ee_oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -128,7 +128,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, oauth_router)
include_router_with_global_prefix_prepended(application, ee_oauth_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(
@@ -152,4 +152,8 @@ def get_application() -> FastAPI:
# environment variable. Used to automate deployment for multiple environments.
seed_db()
# for debugging discovered routes
# for route in application.router.routes:
# print(f"Path: {route.path}, Methods: {route.methods}")
return application

View File

@@ -22,7 +22,7 @@ from onyx.onyxbot.slack.blocks import get_restate_blocks
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.utils.logger import OnyxLoggingAdapter
from onyx.utils.logger import setup_logger
@@ -216,7 +216,7 @@ def _handle_standard_answers(
all_blocks = restate_question_blocks + answer_blocks
try:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=message_info.channel_to_respond,
receiver_ids=receiver_ids,
@@ -231,6 +231,7 @@ def _handle_standard_answers(
client=client,
channel=message_info.channel_to_respond,
thread_ts=slack_thread_id,
receiver_ids=receiver_ids,
)
return True

View File

@@ -1,629 +0,0 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.auth.users import current_user
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/oauth")
class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.onyx.app/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class ConfluenceCloudOAuth:
"""work in progress"""
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.onyx.app/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_user),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
tenant_id = get_current_tenant_id()
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str
if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
# elif connector == DocumentSource.CONFLUENCE:
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
# session = ConfluenceCloudOAuth.session_dump_json(
# email=user.email, redirect_on_success=redirect_on_success
# )
# elif connector == DocumentSource.JIRA:
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
else:
oauth_url = None
if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
return JSONResponse(content={"url": oauth_url})
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client()
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": SlackOAuth.REDIRECT_URI,
},
)
response_data = response.json()
if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)
# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")
credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"team_id": team_id,
"authed_user_id": authed_user_id,
"redirect_on_success": session.redirect_on_success,
}
)
# Work in progress
# @router.post("/connector/confluence/callback")
# def handle_confluence_oauth_callback(
# code: str,
# state: str,
# user: User = Depends(current_user),
# db_session: Session = Depends(get_session),
# tenant_id: str | None = Depends(get_current_tenant_id),
# ) -> JSONResponse:
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
# raise HTTPException(
# status_code=500,
# detail="Confluence client ID or client secret is not configured."
# )
# r = get_redis_client(tenant_id=tenant_id)
# # recover the state
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes
# # Convert bytes back to a UUID
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
# oauth_uuid_str = str(oauth_uuid)
# r_key = f"da_oauth:{oauth_uuid_str}"
# result = r.get(r_key)
# if not result:
# raise HTTPException(
# status_code=400,
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
# )
# try:
# session = ConfluenceCloudOAuth.parse_session(result)
# # Exchange the authorization code for an access token
# response = requests.post(
# ConfluenceCloudOAuth.TOKEN_URL,
# headers={"Content-Type": "application/x-www-form-urlencoded"},
# data={
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
# "code": code,
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
# },
# )
# response_data = response.json()
# if not response_data.get("ok"):
# raise HTTPException(
# status_code=400,
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
# )
# # Extract token and team information
# access_token: str = response_data.get("access_token")
# team_id: str = response_data.get("team", {}).get("id")
# authed_user_id: str = response_data.get("authed_user", {}).get("id")
# credential_info = CredentialBase(
# credential_json={"slack_bot_token": access_token},
# admin_public=True,
# source=DocumentSource.CONFLUENCE,
# name="Confluence OAuth",
# )
# logger.info(f"Slack access token: {access_token}")
# credential = create_credential(credential_info, user, db_session)
# logger.info(f"new_credential_id={credential.id}")
# except Exception as e:
# return JSONResponse(
# status_code=500,
# content={
# "success": False,
# "message": f"An error occurred during Slack OAuth: {str(e)}",
# },
# )
# finally:
# r.delete(r_key)
# # return the result
# return JSONResponse(
# content={
# "success": True,
# "message": "Slack OAuth completed successfully.",
# "team_id": team_id,
# "authed_user_id": authed_user_id,
# "redirect_on_success": session.redirect_on_success,
# }
# )
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client()
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
session: GoogleDriveOAuth.OAuthSession
try:
session = GoogleDriveOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"redirect_on_success": session.redirect_on_success,
}
)

View File

@@ -0,0 +1,91 @@
import base64
import uuid
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from ee.onyx.server.oauth.api_router import router
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
from ee.onyx.server.oauth.slack import SlackOAuth
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import DocumentSource
from onyx.db.engine import get_current_tenant_id
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_admin_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str | None = None
if connector == DocumentSource.SLACK:
if not DEV_MODE:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = SlackOAuth.generate_dev_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.CONFLUENCE:
if not DEV_MODE:
oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = ConfluenceCloudOAuth.generate_dev_oauth_url(oauth_state)
session = ConfluenceCloudOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
if not DEV_MODE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
else:
oauth_url = None
if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)
if not session:
raise HTTPException(
status_code=500,
detail=f"The document source type {connector} failed to generate an OAuth session.",
)
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
return JSONResponse(content={"url": oauth_url})

View File

@@ -0,0 +1,3 @@
from fastapi import APIRouter
router: APIRouter = APIRouter(prefix="/oauth")

View File

@@ -0,0 +1,362 @@
import base64
import uuid
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
from onyx.db.credentials import create_credential
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import update_credential_json
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ConfluenceCloudOAuth:
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
class TokenResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
refresh_token: str
scope: str
class AccessibleResources(BaseModel):
id: str
name: str
url: str
scopes: list[str]
avatarUrl: str
CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL
ACCESSIBLE_RESOURCE_URL = (
"https://api.atlassian.com/oauth/token/accessible-resources"
)
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
# classic scope
"read:confluence-space.summary%20"
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence%20"
"search:confluence%20"
# granular scope
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
"read:content-details:confluence%20" # for permission sync
"offline_access"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&redirect_uri={redirect_uri}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
return session
@classmethod
def generate_finalize_url(cls, credential_id: int) -> str:
return f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/finalize?credential={credential_id}"
@router.post("/connector/confluence/callback")
def confluence_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Handles the backend logic for the frontend page that the user is redirected to
after visiting the oauth authorization url."""
if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Confluence Cloud client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Confluence Cloud OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = ConfluenceCloudOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI
else:
redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
ConfluenceCloudOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": ConfluenceCloudOAuth.CLIENT_ID,
"client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
token_response: ConfluenceCloudOAuth.TokenResponse | None = None
try:
token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json(
response.text
)
except Exception:
raise RuntimeError(
"Confluence Cloud OAuth failed during code/token exchange."
)
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=token_response.expires_in)
credential_info = CredentialBase(
credential_json={
"confluence_access_token": token_response.access_token,
"confluence_refresh_token": token_response.refresh_token,
"created_at": now.isoformat(),
"expires_at": expires_at.isoformat(),
"expires_in": token_response.expires_in,
"scope": token_response.scope,
},
admin_public=True,
source=DocumentSource.CONFLUENCE,
name="Confluence Cloud OAuth",
)
credential = create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud OAuth completed successfully.",
"finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id),
"redirect_on_success": session.redirect_on_success,
}
)
@router.get("/connector/confluence/accessible-resources")
def confluence_oauth_accessible_resources(
credential_id: int,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Atlassian's API is weird and does not supply us with enough info to be in a
usable state after authorizing. All API's require a cloud id. We have to list
the accessible resources/sites and let the user choose which site to use."""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if not credential:
raise HTTPException(400, f"Credential {credential_id} not found.")
credential_dict = credential.credential_json
access_token = credential_dict["confluence_access_token"]
try:
# Exchange the authorization code for an access token
response = requests.get(
ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL,
headers={
"Authorization": f"Bearer {access_token}",
"Accept": "application/json",
},
)
response.raise_for_status()
accessible_resources_data = response.json()
# Validate the list of AccessibleResources
try:
accessible_resources = [
ConfluenceCloudOAuth.AccessibleResources(**resource)
for resource in accessible_resources_data
]
except ValidationError as e:
raise RuntimeError(f"Failed to parse accessible resources: {e}")
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}",
},
)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud get accessible resources completed successfully.",
"accessible_resources": [
resource.model_dump() for resource in accessible_resources
],
}
)
@router.post("/connector/confluence/finalize")
def confluence_oauth_finalize(
credential_id: int,
cloud_id: str,
cloud_name: str,
cloud_url: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Saves the info for the selected cloud site to the credential.
This is the final step in the confluence oauth flow where after the traditional
OAuth process, the user has to select a site to associate with the credentials.
After this, the credential is usable."""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if not credential:
raise HTTPException(
status_code=400,
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
)
new_credential_json: dict[str, Any] = dict(credential.credential_json)
new_credential_json["cloud_id"] = cloud_id
new_credential_json["cloud_name"] = cloud_name
new_credential_json["wiki_base"] = cloud_url
try:
update_credential_json(credential_id, new_credential_json, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
},
)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud OAuth finalized successfully.",
"redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence",
}
)

View File

@@ -0,0 +1,229 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.danswer.dev/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = GoogleDriveOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = GoogleDriveOAuth.REDIRECT_URI
else:
redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"finalize_url": None,
"redirect_on_success": session.redirect_on_success,
}
)

View File

@@ -0,0 +1,197 @@
import base64
import uuid
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.danswer.dev/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = SlackOAuth.REDIRECT_URI
else:
redirect_uri = SlackOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
},
)
response_data = response.json()
if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)
# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")
credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"finalize_url": None,
"redirect_on_success": session.redirect_on_success,
"team_id": team_id,
"authed_user_id": authed_user_id,
}
)

View File

@@ -1,10 +1,14 @@
import re
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import AgentAnswer
from ee.onyx.server.query_and_chat.models import AgentSubQuery
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
@@ -14,13 +18,19 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AllCitations
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
@@ -89,6 +99,12 @@ def _convert_packet_stream_to_response(
final_context_docs: list[LlmDoc] = []
answer = ""
# accumulate stream data with these dicts
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
@@ -97,6 +113,15 @@ def _convert_packet_stream_to_response(
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if id in agent_sub_questions:
agent_sub_questions[id].document_ids = [
saved_search_doc.document_id
for saved_search_doc in packet.top_documents
]
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
@@ -113,11 +138,104 @@ def _convert_packet_stream_to_response(
citation.citation_num: citation.document_id
for citation in packet.citations
}
# agentic packets
elif isinstance(packet, SubQuestionPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_sub_questions.get(id) is None:
agent_sub_questions[id] = AgentSubQuestion(
level=packet.level,
level_question_num=packet.level_question_num,
sub_question=packet.sub_question,
document_ids=[],
)
else:
agent_sub_questions[id].sub_question += packet.sub_question
elif isinstance(packet, AgentAnswerPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_answers.get(id) is None:
agent_answers[id] = AgentAnswer(
level=packet.level,
level_question_num=packet.level_question_num,
answer=packet.answer_piece,
answer_type=packet.answer_type,
)
else:
agent_answers[id].answer += packet.answer_piece
elif isinstance(packet, SubQueryPiece):
if packet.level is not None and packet.level_question_num is not None:
sub_query_id = (
packet.level,
packet.level_question_num,
packet.query_id,
)
if agent_sub_queries.get(sub_query_id) is None:
agent_sub_queries[sub_query_id] = AgentSubQuery(
level=packet.level,
level_question_num=packet.level_question_num,
sub_query=packet.sub_query,
query_id=packet.query_id,
)
else:
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
elif isinstance(packet, ExtendedToolResponse):
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
logger.warning(
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
)
elif isinstance(packet, RefinedAnswerImprovement):
response.agent_refined_answer_improvement = (
packet.refined_answer_improvement
)
else:
logger.warning(
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
)
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
# organize / sort agent metadata for output
if len(agent_sub_questions) > 0:
response.agent_sub_questions = cast(
dict[int, list[AgentSubQuestion]],
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
)
if len(agent_answers) > 0:
# return the agent_level_answer from the first level or the last one depending
# on agent_refined_answer_improvement
response.agent_answers = cast(
dict[int, list[AgentAnswer]],
SubQuestionIdentifier.make_dict_by_level(agent_answers),
)
if response.agent_answers:
selected_answer_level = (
0
if not response.agent_refined_answer_improvement
else len(response.agent_answers) - 1
)
level_answers = response.agent_answers[selected_answer_level]
for level_answer in level_answers:
if level_answer.answer_type != "agent_level_answer":
continue
answer = level_answer.answer
break
if len(agent_sub_queries) > 0:
# subqueries are often emitted with trailing whitespace ... clean it up here
# perhaps fix at the source?
for v in agent_sub_queries.values():
v.sub_query = v.sub_query.strip()
response.agent_sub_queries = (
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)

View File

@@ -1,3 +1,5 @@
from collections import OrderedDict
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
@@ -9,6 +11,7 @@ from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@@ -88,6 +91,64 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
sub_query: str
query_id: int
@staticmethod
def make_dict_by_level_and_question_index(
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
returns a dict of level to dict[question num to list of query_id's]
Ordering is asc for readability.
"""
# In this function, when we sort int | None, we deliberately push None to the end
# map entries to the level_question_dict
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
for k1, obj in original_dict.items():
level = k1[0]
question = k1[1]
if level not in level_question_dict:
level_question_dict[level] = {}
if question not in level_question_dict[level]:
level_question_dict[level][question] = []
level_question_dict[level][question].append(obj)
# sort each query_id list and question_index
for key1, obj1 in level_question_dict.items():
for key2, value2 in obj1.items():
# sort the query_id list of each question_index
level_question_dict[key1][key2] = sorted(
value2, key=lambda o: o.query_id
)
# sort the question_index dict of level
level_question_dict[key1] = OrderedDict(
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
)
# sort the top dict of levels
sorted_dict = OrderedDict(
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
@@ -107,6 +168,12 @@ class ChatBasicResponse(BaseModel):
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
agent_answers: dict[int, list[AgentAnswer]] | None = None
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits

View File

@@ -13,7 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.api_key import is_api_key_email_address
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit
@@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
def _check_token_rate_limits(user: User | None, tenant_id: str) -> None:
def _check_token_rate_limits(user: User | None) -> None:
if user is None:
# Unauthenticated users are only rate limited by global settings
_user_is_rate_limited_by_global(tenant_id)
_user_is_rate_limited_by_global()
elif is_api_key_email_address(user.email):
# API keys are only rate limited by global settings
_user_is_rate_limited_by_global(tenant_id)
_user_is_rate_limited_by_global()
else:
run_functions_tuples_in_parallel(
[
(_user_is_rate_limited, (user.id, tenant_id)),
(_user_is_rate_limited_by_group, (user.id, tenant_id)),
(_user_is_rate_limited_by_global, (tenant_id,)),
(_user_is_rate_limited, (user.id,)),
(_user_is_rate_limited_by_group, (user.id,)),
(_user_is_rate_limited_by_global, ()),
]
)
@@ -52,8 +52,8 @@ User rate limits
"""
def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def _user_is_rate_limited(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
user_rate_limits = fetch_all_user_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
@@ -93,8 +93,8 @@ User Group rate limits
"""
def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
if group_rate_limits:

View File

@@ -2,6 +2,7 @@ import csv
import io
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from uuid import UUID
from fastapi import APIRouter
@@ -21,8 +22,10 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.chat.chat_utils import create_chat_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
@@ -35,6 +38,8 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter()
ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"
def fetch_and_process_chat_session_history(
db_session: Session,
@@ -43,10 +48,15 @@ def fetch_and_process_chat_session_history(
feedback_type: QAFeedbackType | None,
limit: int | None = 500,
) -> list[ChatSessionSnapshot]:
# observed to be slow a scale of 8192 sessions and 4 messages per session
# this is a little slow (5 seconds)
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=start, end=end, db_session=db_session, limit=limit
)
# this is VERY slow (80 seconds) due to create_chat_chain being called
# for each session. Needs optimizing.
chat_session_snapshots = [
snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
for chat_session in chat_sessions
@@ -107,6 +117,17 @@ def get_user_chat_sessions(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
# we specifically don't allow this endpoint if "anonymized" since
# this is a direct query on the user id
if ONYX_QUERY_HISTORY_TYPE in [
QueryHistoryType.DISABLED,
QueryHistoryType.ANONYMIZED,
]:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Per user query history has been disabled by the administrator.",
)
try:
chat_sessions = get_chat_sessions_by_user(
user_id=user_id, deleted=False, db_session=db_session, limit=0
@@ -122,6 +143,7 @@ def get_user_chat_sessions(
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
time_updated=chat.time_updated.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
@@ -141,6 +163,12 @@ def get_chat_session_history(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[ChatSessionMinimal]:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
page_of_chat_sessions = get_page_of_chat_sessions(
page_num=page_num,
page_size=page_size,
@@ -157,11 +185,16 @@ def get_chat_session_history(
feedback_filter=feedback_type,
)
minimal_chat_sessions: list[ChatSessionMinimal] = []
for chat_session in page_of_chat_sessions:
minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
minimal_chat_sessions.append(minimal_chat_session)
return PaginatedReturn(
items=[
ChatSessionMinimal.from_chat_session(chat_session)
for chat_session in page_of_chat_sessions
],
items=minimal_chat_sessions,
total_items=total_filtered_chat_sessions_count,
)
@@ -172,6 +205,12 @@ def get_chat_session_admin(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionSnapshot:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
try:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
@@ -193,6 +232,9 @@ def get_chat_session_admin(
f"Could not create snapshot for chat session with id '{chat_session_id}'",
)
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
return snapshot
@@ -203,6 +245,14 @@ def get_query_history_as_csv(
end: datetime | None = None,
db_session: Session = Depends(get_session),
) -> StreamingResponse:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
# this call is very expensive and is timing out via endpoint
# TODO: optimize call and/or generate via background task
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
@@ -213,6 +263,9 @@ def get_query_history_as_csv(
question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
for chat_session_snapshot in complete_chat_session_history:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
question_answer_pairs.extend(
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
)

View File

@@ -0,0 +1,45 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response

View File

@@ -0,0 +1,98 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response

View File

@@ -1,269 +1,24 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
)
from ee.onyx.server.tenants.user_invitations_api import (
router as user_invitations_router,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Create a main router to include all sub-routers
# Note: We don't add a prefix here as each router already has the /tenants prefix
router = APIRouter()
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)
# Include all the individual routers
router.include_router(admin_router)
router.include_router(anonymous_users_router)
router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)

View File

@@ -7,6 +7,7 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
@@ -41,7 +42,9 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()
def fetch_billing_information(tenant_id: str) -> BillingInformation:
def fetch_billing_information(
tenant_id: str,
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
@@ -52,8 +55,19 @@ def fetch_billing_information(tenant_id: str) -> BillingInformation:
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
billing_info = BillingInformation(**response.json())
return billing_info
response_data = response.json()
# Check if the response indicates no subscription
if (
isinstance(response_data, dict)
and "subscribed" in response_data
and not response_data["subscribed"]
):
return SubscriptionStatusResponse(**response_data)
# Otherwise, parse as BillingInformation
return BillingInformation(**response_data)
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:

View File

@@ -0,0 +1,96 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -67,3 +67,30 @@ class ProductGatingResponse(BaseModel):
class SubscriptionSessionResponse(BaseModel):
sessionId: str
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int
creator_email: str
class TenantByDomainRequest(BaseModel):
email: str
class RequestInviteRequest(BaseModel):
tenant_id: str
class RequestInviteResponse(BaseModel):
success: bool
message: str
class PendingUserSnapshot(BaseModel):
email: str
class ApproveUserRequest(BaseModel):
email: str

View File

@@ -48,4 +48,5 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}

View File

@@ -4,6 +4,7 @@ import uuid
import aiohttp # Async HTTP client
import httpx
import requests
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
@@ -14,6 +15,7 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
@@ -55,7 +57,11 @@ logger = logging.getLogger(__name__)
async def get_or_provision_tenant(
email: str, referral_source: str | None = None, request: Request | None = None
) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
"""
Get existing tenant ID for an email or create a new tenant if none exists.
This function should only be called after we have verified we want this user's tenant to exist.
It returns the tenant ID associated with the email, creating a new tenant if necessary.
"""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
@@ -104,14 +110,14 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
status_code=409, detail="User already belongs to an organization"
)
logger.info(f"Provisioning tenant: {tenant_id}")
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
token = None
try:
if not create_schema_if_not_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
logger.debug(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")
logger.debug(f"Schema already exists for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -200,33 +206,15 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
def configure_default_api_keys(db_session: Session) -> None:
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20241022",
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
@@ -238,6 +226,26 @@ def configure_default_api_keys(db_session: Session) -> None:
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
@@ -347,3 +355,47 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)
def get_tenant_by_domain_from_control_plane(
domain: str,
tenant_id: str,
) -> TenantByDomainResponse | None:
"""
Fetches tenant information from the control plane based on the email domain.
Args:
domain: The email domain to search for (e.g., "example.com")
Returns:
A dictionary containing tenant information if found, None otherwise
"""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
try:
response = requests.get(
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
headers=headers,
json={"domain": domain, "tenant_id": tenant_id},
)
if response.status_code != 200:
logger.error(f"Control plane tenant lookup failed: {response.text}")
return None
response_data = response.json()
if not response_data:
return None
return TenantByDomainResponse(
tenant_id=response_data.get("tenant_id"),
number_of_users=response_data.get("number_of_users"),
creator_email=response_data.get("creator_email"),
)
except Exception as e:
logger.error(f"Error fetching tenant by domain: {str(e)}")
return None

View File

@@ -0,0 +1,67 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/leave-team")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -0,0 +1,39 @@
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
"gmail",
"outlook",
"yahoo",
"hotmail",
"icloud",
"msn",
"hotmail",
"hotmail.co.uk",
]
@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
user: User | None = Depends(current_user),
) -> TenantByDomainResponse | None:
if not user:
return None
domain = user.email.split("@")[1]
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
return None
tenant_id = get_current_tenant_id()
return get_tenant_by_domain_from_control_plane(domain, tenant_id)

View File

@@ -0,0 +1,90 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
from ee.onyx.server.tenants.models import RequestInviteRequest
from ee.onyx.server.tenants.user_mapping import accept_user_invite
from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/users/invite/request")
async def request_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_admin_user),
) -> None:
if user is None:
raise HTTPException(status_code=401, detail="User not authenticated")
try:
invite_self_to_tenant(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/users/pending")
def list_pending_users(
_: User | None = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
pending_emails = get_pending_users()
return [PendingUserSnapshot(email=email) for email in pending_emails]
@router.post("/users/invite/approve")
async def approve_user(
approve_user_request: ApproveUserRequest,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
approve_user_invite(approve_user_request.email, tenant_id)
@router.post("/users/invite/accept")
async def accept_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Accept an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
accept_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to accept invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to accept invitation")
@router.post("/users/invite/deny")
async def deny_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Deny an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
deny_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to deny invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to deny invitation")

View File

@@ -1,34 +1,63 @@
import logging
from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.setup import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = logging.getLogger(__name__)
logger = setup_logger()
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
try:
with get_session_with_shared_schema() as db_session:
# First try to get an active tenant
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
tenant_id = mapping.tenant_id if mapping else None
# If no active tenant found, try to get the first inactive one
if tenant_id is None:
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
if mapping:
# Mark this mapping as active
mapping.active = True
db_session.commit()
tenant_id = mapping.tenant_id
except Exception as e:
logger.exception(f"Error getting tenant id for email {email}: {e}")
raise exceptions.UserNotExists()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
@@ -38,17 +67,19 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
db_session.add(
UserTenantMapping(email=email, tenant_id=tenant_id, active=False)
)
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.commit()
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
@@ -71,8 +102,192 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
def remove_all_users_from_tenant(tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_pending_users()
if email in pending_users:
return
write_pending_users(pending_users + [email])
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def approve_user_invite(email: str, tenant_id: str) -> None:
"""
Approve a user invite to a tenant.
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete all existing records for this email
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email
).delete()
# Create a new mapping entry for the user in this tenant
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
db_session.add(new_mapping)
db_session.commit()
# Also remove the user from pending users list
# Remove from pending users
pending_users = get_pending_users()
if email in pending_users:
pending_users.remove(email)
write_pending_users(pending_users)
# Add to invited users
invited_users = get_invited_users()
if email not in invited_users:
invited_users.append(email)
write_invited_users(invited_users)
def accept_user_invite(email: str, tenant_id: str) -> None:
"""
Accept an invitation to join a tenant.
This activates the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
try:
# First check if there's an active mapping for this user and tenant
active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
# If an active mapping exists, delete it
if active_mapping:
db_session.delete(active_mapping)
logger.info(
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
)
# Find the inactive mapping for this user and tenant
mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if mapping:
# Set all other mappings for this user to inactive
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
).update({"active": False})
# Activate this mapping
mapping.active = True
db_session.commit()
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
logger.exception(
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
)
raise
def deny_user_invite(email: str, tenant_id: str) -> None:
"""
Deny an invitation to join a tenant.
This removes the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete the mapping for this user and tenant
result = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.delete()
)
db_session.commit()
if result:
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_invited_users()
if email in pending_users:
pending_users.remove(email)
write_invited_users(pending_users)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def get_tenant_count(tenant_id: str) -> int:
"""
Get the number of active users for this tenant
"""
with get_session_with_shared_schema() as db_session:
# Count the number of active users for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return user_count
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
"""
Get the first tenant invitation for this user
"""
with get_session_with_shared_schema() as db_session:
# Get the first tenant invitation for this user
invitation = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if invitation:
# Get the user count for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == invitation.tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return TenantSnapshot(
tenant_id=invitation.tenant_id, number_of_users=user_count
)
return None

View File

@@ -6,7 +6,7 @@ MODEL_WARM_UP_STRING = "hi " * 512
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-004"
DEFAULT_VERTEX_MODEL = "text-embedding-005"
class EmbeddingModelTextType:

View File

@@ -5,6 +5,7 @@ from types import TracebackType
from typing import cast
from typing import Optional
import aioboto3 # type: ignore
import httpx
import openai
import vertexai # type: ignore
@@ -28,11 +29,13 @@ from model_server.constants import DEFAULT_VERTEX_MODEL
from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
from model_server.utils import pass_aws_key
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
@@ -59,6 +62,60 @@ _OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding calling
_COHERE_MAX_INPUT_LEN = 96
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
class CloudEmbedding:
def __init__(
@@ -78,7 +135,7 @@ class CloudEmbedding:
self._closed = False
async def _embed_openai(
self, texts: list[str], model: str | None
self, texts: list[str], model: str | None, reduced_dimension: int | None
) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL
@@ -89,22 +146,17 @@ class CloudEmbedding:
)
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(input=text_batch, model=model)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
except Exception as e:
error_string = (
f"Error embedding text with OpenAI: {str(e)} \n"
f"Model: {model} \n"
f"Provider: {self.provider} \n"
f"Texts: {texts}"
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
logger.error(error_string)
raise RuntimeError(error_string)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
@@ -143,7 +195,6 @@ class CloudEmbedding:
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
@@ -173,17 +224,24 @@ class CloudEmbedding:
vertexai.init(project=project_id, credentials=credentials)
client = TextEmbeddingModel.from_pretrained(model)
embeddings = await client.get_embeddings_async(
[
TextEmbeddingInput(
text,
embedding_type,
)
for text in texts
],
auto_truncate=True, # This is the default
)
return [embedding.values for embedding in embeddings]
inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
# Split into batches of 25 texts
max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
batches = [
inputs[i : i + max_texts_per_batch]
for i in range(0, len(inputs), max_texts_per_batch)
]
# Dispatch all embedding calls asynchronously at once
tasks = [
client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
]
# Wait for all tasks to complete in parallel
results = await asyncio.gather(*tasks)
return [embedding.values for batch in results for embedding in batch]
async def _embed_litellm_proxy(
self, texts: list[str], model_name: str | None
@@ -218,23 +276,53 @@ class CloudEmbedding:
text_type: EmbedTextType,
model_name: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e, str(self.provider), model_name or deployment_name, self.provider
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
@staticmethod
def create(
@@ -321,6 +409,7 @@ async def embed_text(
prefix: str | None,
api_url: str | None,
api_version: str | None,
reduced_dimension: int | None,
gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
if not all(texts):
@@ -364,6 +453,7 @@ async def embed_text(
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
reduced_dimension=reduced_dimension,
)
if any(embedding is None for embedding in embeddings):
@@ -435,7 +525,7 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
)
async def cohere_rerank(
async def cohere_rerank_api(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereAsyncClient(api_key=api_key)
@@ -445,6 +535,45 @@ async def cohere_rerank(
return [result.relevance_score for result in sorted_results]
async def cohere_rerank_aws(
query: str,
docs: list[str],
model_name: str,
region_name: str,
aws_access_key_id: str,
aws_secret_access_key: str,
) -> list[float]:
session = aioboto3.Session(
aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
)
async with session.client(
"bedrock-runtime", region_name=region_name
) as bedrock_client:
body = json.dumps(
{
"query": query,
"documents": docs,
"api_version": 2,
}
)
# Invoke the Bedrock model asynchronously
response = await bedrock_client.invoke_model(
modelId=model_name,
accept="application/json",
contentType="application/json",
body=body,
)
# Read the response asynchronously
response_body = json.loads(await response["body"].read())
# Extract and sort the results
results = response_body.get("results", [])
sorted_results = sorted(results, key=lambda item: item["index"])
return [result["relevance_score"] for result in sorted_results]
async def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
@@ -503,10 +632,18 @@ async def process_embed_request(
text_type=embed_request.text_type,
api_url=embed_request.api_url,
api_version=embed_request.api_version,
reduced_dimension=embed_request.reduced_dimension,
prefix=prefix,
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except AuthenticationError as e:
# Handle authentication errors consistently
logger.error(f"Authentication error: {e.provider}")
raise HTTPException(
status_code=401,
detail=f"Authentication failed: {e.message}",
)
except RateLimitError as e:
raise HTTPException(
status_code=429,
@@ -559,15 +696,32 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
sim_scores = await cohere_rerank(
sim_scores = await cohere_rerank_api(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.BEDROCK:
if rerank_request.api_key is None:
raise RuntimeError("Bedrock Rerank Requires an API Key")
aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
rerank_request.api_key
)
sim_scores = await cohere_rerank_aws(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
region_name=aws_region,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
return RerankResponse(scores=sim_scores)
else:
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
except Exception as e:
logger.exception(f"Error during reranking process:\n{str(e)}")
raise HTTPException(

View File

@@ -70,3 +70,32 @@ def get_gpu_type() -> str:
return GPUStatus.MAC_MPS
return GPUStatus.NONE
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
"""Parse AWS API key string into components.
Args:
api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'
Returns:
Tuple of (access_key, secret_key, region)
Raises:
ValueError: If key format is invalid
"""
if not api_key.startswith("aws"):
raise ValueError("API key must start with 'aws' prefix")
parts = api_key.split("_")
if len(parts) != 4:
raise ValueError(
f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts) - 1} parts"
"this is an onyx specific format for formatting the aws secrets for bedrock"
)
try:
_, aws_access_key_id, aws_secret_access_key, aws_region = parts
return aws_access_key_id, aws_secret_access_key, aws_region
except Exception as e:
raise ValueError(f"Failed to parse AWS key components: {str(e)}")

View File

@@ -31,6 +31,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
@@ -92,6 +93,7 @@ def check_sub_answer(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
quality_str: str = cast(str, response.content)

View File

@@ -46,6 +46,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -119,6 +120,7 @@ def generate_sub_answer(
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -62,6 +63,7 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
@@ -153,8 +155,9 @@ def generate_initial_answer(
)
for tool_response in yield_search_responses(
query=question,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -278,6 +281,9 @@ def generate_initial_answer(
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -34,6 +34,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
@@ -141,6 +142,7 @@ def decompose_orig_question(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),

View File

@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
@@ -112,6 +113,7 @@ def compare_answers(
model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
except (LLMTimeoutError, TimeoutError):

View File

@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
)
@@ -144,6 +145,7 @@ def create_refined_sub_questions(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),

View File

@@ -50,13 +50,7 @@ def decide_refinement_need(
)
]
if graph_config.behavior.allow_refinement:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=decision,
log_messages=log_messages,
)
else:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=log_messages,
)
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
log_messages=log_messages,
)

View File

@@ -21,6 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
@@ -96,6 +97,7 @@ def extract_entities_terms(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (

View File

@@ -46,6 +46,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -68,6 +69,8 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
@@ -179,8 +182,9 @@ def generate_validate_refined_answer(
)
for tool_response in yield_search_responses(
query=question,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -302,7 +306,11 @@ def generate_validate_refined_answer(
def stream_refined_answer() -> list[str]:
for message in model.stream(
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -409,6 +417,7 @@ def generate_validate_refined_answer(
validation_model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),

View File

@@ -13,7 +13,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import IndexFilters
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
@@ -144,8 +143,6 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
if result.query_info is not None:
query_info = result.query_info
break
return query_info or SearchQueryInfo(
predicted_search=None,
final_filters=IndexFilters(access_control_list=None),
recency_bias_multiplier=1.0,
)
assert query_info is not None, "must have query info"
return query_info

View File

@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
)
@@ -96,6 +97,7 @@ def expand_queries(
model.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)

View File

@@ -56,8 +56,9 @@ def format_results(
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
reranked_sections=state.retrieved_documents,
final_context_sections=reranked_documents,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,

View File

@@ -91,7 +91,7 @@ def retrieve_documents(
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0]
pre_rerank_docs = callback_container[0] if callback_container else []
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,

View File

@@ -25,6 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -93,6 +94,7 @@ def verify_documents(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
assert isinstance(response.content, str)

View File

@@ -44,7 +44,9 @@ def call_tool(
tool = tool_choice.tool
tool_args = tool_choice.tool_args
tool_id = tool_choice.id
tool_runner = ToolRunner(tool, tool_args)
tool_runner = ToolRunner(
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
)
tool_kickoff = tool_runner.kickoff()
emit_packet(tool_kickoff, writer)

View File

@@ -15,8 +15,17 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -25,6 +34,7 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
@log_function_time(print_only=True)
def choose_tool(
state: ToolChoiceState,
config: RunnableConfig,
@@ -37,6 +47,31 @@ def choose_tool(
should_stream_answer = state.should_stream_answer
agent_config = cast(GraphConfig, config["metadata"]["config"])
force_use_tool = agent_config.tooling.force_use_tool
embedding_thread: TimeoutThread[Embedding] | None = None
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
override_kwargs: SearchToolOverrideKwargs | None = None
if (
not agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and (
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
)
):
override_kwargs = SearchToolOverrideKwargs()
# Run in a background thread to avoid blocking the main thread
embedding_thread = run_in_background(
get_query_embedding,
agent_config.inputs.search_request.query,
agent_config.persistence.db_session,
)
keyword_thread = run_in_background(
query_analysis,
agent_config.inputs.search_request.query,
)
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
@@ -47,7 +82,6 @@ def choose_tool(
tools = [
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
]
force_use_tool = agent_config.tooling.force_use_tool
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
@@ -71,11 +105,22 @@ def choose_tool(
# If we have a tool and tool args, we are ready to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
if embedding_thread and tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
search_tool_override_kwargs=override_kwargs,
),
)
@@ -98,8 +143,16 @@ def choose_tool(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=built_prompt,
tools=[tool.tool_definition() for tool in tools] or None,
tool_choice=("required" if tools and force_use_tool.force_use else None),
tools=(
[tool.tool_definition() for tool in tools] or None
if using_tool_calling_llm
else None
),
tool_choice=(
"required"
if tools and force_use_tool.force_use and using_tool_calling_llm
else None
),
structured_response_format=structured_response_format,
)
@@ -145,10 +198,22 @@ def choose_tool(
logger.debug(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
if embedding_thread and selected_tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and selected_tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
search_tool_override_kwargs=override_kwargs,
),
)

View File

@@ -9,18 +9,23 @@ from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_DOC_CONTENT_ID,
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
@log_function_time(print_only=True)
def basic_use_tool_response(
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BasicOutput:
@@ -50,11 +55,13 @@ def basic_use_tool_response(
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = cast(OnyxContexts, yield_item.response).contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
initial_search_results.append(doc)
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(
context_from_inference_section(section)
)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -2,6 +2,7 @@ from pydantic import BaseModel
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
@@ -35,6 +36,7 @@ class ToolChoice(BaseModel):
tool: Tool
tool_args: dict
id: str | None
search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -13,6 +13,11 @@ AGENT_NEGATIVE_VALUE_STR = "no"
AGENT_ANSWER_SEPARATOR = "Answer:"
EMBEDDING_KEY = "embedding"
IS_KEYWORD_KEY = "is_keyword"
KEYWORDS_KEY = "keywords"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"

View File

@@ -42,6 +42,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
)
@@ -61,6 +62,7 @@ from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
@@ -402,6 +404,7 @@ def summarize_history(
llm.invoke,
history_context_prompt,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
max_tokens=AGENT_MAX_TOKENS_HISTORY_SUMMARY,
)
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - summarize history")
@@ -505,3 +508,9 @@ def get_deduplicated_structured_subquestion_documents(
cited_documents=dedup_inference_section_list(cited_docs),
context_documents=dedup_inference_section_list(context_docs),
)
def _should_restrict_tokens(llm_config: LLMConfig) -> bool:
return not (
llm_config.model_provider == "openai" and llm_config.model_name.startswith("o")
)

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
from shared_configs.configs import MULTI_TENANT
_API_KEY_HEADER_NAME = "Authorization"
@@ -35,8 +36,7 @@ class ApiKeyDescriptor(BaseModel):
def generate_api_key(tenant_id: str | None = None) -> str:
# For backwards compatibility, if no tenant_id, generate old style key
if not tenant_id:
if not MULTI_TENANT or not tenant_id:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
encoded_tenant = quote(tenant_id) # URL encode the tenant ID

View File

@@ -2,6 +2,8 @@ import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
from email.utils import make_msgid
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
@@ -10,8 +12,10 @@ from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
from shared_configs.configs import MULTI_TENANT
HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
@@ -151,6 +155,8 @@ def send_email(
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")
@@ -172,7 +178,7 @@ def send_subscription_cancellation_email(user_email: str) -> None:
subject = "Your Onyx Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>Were sorry to see you go.</p>"
"<p>We're sorry to see you go.</p>"
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
"<p>If you change your mind, you can always come back!</p>"
)
@@ -187,36 +193,64 @@ def send_subscription_cancellation_email(user_email: str) -> None:
send_email(user_email, subject, html_content, text_content)
def send_user_email_invite(user_email: str, current_user: User) -> None:
def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
subject = "Invitation to Join Onyx Organization"
heading = "You've Been Invited!"
message = (
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
# the exact action taken by the user, and thus the message, depends on the auth type
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
if auth_type == AuthType.CLOUD:
message += (
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
elif auth_type == AuthType.BASIC:
message += (
"<p>To join the organization, please click the button below to set a password "
"and complete your registration.</p>"
)
elif auth_type == AuthType.GOOGLE_OAUTH:
message += (
"<p>To join the organization, please click the button below to login with Google "
"and complete your registration.</p>"
)
elif auth_type == AuthType.OIDC or auth_type == AuthType.SAML:
message += (
"<p>To join the organization, please click the button below to"
" complete your registration.</p>"
)
else:
raise ValueError(f"Invalid auth type: {auth_type}")
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
# text content is the fallback for clients that don't support HTML
# not as critical, so not having special cases for each auth type
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
"You'll be asked to set a password or login with Google to complete your registration."
)
if auth_type == AuthType.CLOUD:
text_content += "You'll be asked to set a password or login with Google to complete your registration."
send_email(user_email, subject, html_content, text_content)
def send_forgot_password_email(
user_email: str,
token: str,
tenant_id: str,
mail_from: str = EMAIL_FROM,
tenant_id: str | None = None,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if tenant_id:
if MULTI_TENANT:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)

View File

@@ -1,5 +1,6 @@
from typing import cast
from onyx.configs.constants import KV_PENDING_USERS_KEY
from onyx.configs.constants import KV_USER_STORE_KEY
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -18,3 +19,17 @@ def write_invited_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)
def get_pending_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_PENDING_USERS_KEY))
except KvKeyNotFoundError:
return list()
def write_pending_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -100,6 +100,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
@@ -214,7 +215,7 @@ def verify_email_is_invited(email: str) -> None:
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
@@ -411,7 +412,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"refresh_token": refresh_token,
}
user: User
user: User | None = None
try:
# Attempt to get user by OAuth account
@@ -420,15 +421,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except exceptions.UserNotExists:
try:
# Attempt to get user by email
user = await self.get_by_email(account_email)
user = await self.user_db.get_by_email(account_email)
if not associate_by_email:
raise exceptions.UserAlreadyExists()
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
# Make sure user is not None before adding OAuth account
if user is not None:
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
else:
# This shouldn't happen since get_by_email would raise UserNotExists
# but adding as a safeguard
raise exceptions.UserNotExists()
# If user not found by OAuth account or email, create a new user
except exceptions.UserNotExists:
password = self.password_helper.generate()
user_dict = {
@@ -439,26 +445,36 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user = await self.user_db.create(user_dict)
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
# Add OAuth account only if user creation was successful
if user is not None:
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
raise HTTPException(
status_code=500, detail="Failed to create user account"
)
else:
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# User exists, update OAuth account if needed
if user is not None: # Add explicit check
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# Ensure user is not None before proceeding
if user is None:
raise HTTPException(
status_code=500, detail="Failed to authenticate or create user"
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
@@ -508,6 +524,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
user_count = await get_user_count()
logger.debug(f"Current tenant user count: {user_count}")
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if user_count == 1:
@@ -529,7 +546,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
logger.notice(f"User {user.id} has registered.")
logger.debug(f"User {user.id} has registered.")
optional_telemetry(
record_type=RecordType.SIGN_UP,
data={"action": "create"},
@@ -553,7 +570,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async_return_default_schema,
)(email=user.email)
send_forgot_password_email(user.email, token, tenant_id=tenant_id)
send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
@@ -571,14 +588,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
) -> Optional[User]:
email = credentials.username
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=email,
)
tenant_id: str | None = None
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_tenant_id_for_email",
None,
)(
email=email,
)
except Exception as e:
logger.warning(
f"User attempted to login with invalid credentials: {str(e)}"
)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
@@ -872,7 +895,7 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accesssible_user(
async def current_chat_accessible_user(
user: User | None = Depends(optional_user),
) -> User | None:
tenant_id = get_current_tenant_id()
@@ -1073,6 +1096,12 @@ def get_oauth_router(
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(account_email)
except exceptions.UserNotExists:
tenant_id = None
request.state.referral_source = referral_source
@@ -1104,9 +1133,14 @@ def get_oauth_router(
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
redirect_response = RedirectResponse(next_url, status_code=302)
if tenant_id is None:
# Use URL utility to add parameters
redirect_url = add_url_params(next_url, {"new_team": "true"})
redirect_response = RedirectResponse(redirect_url, status_code=302)
else:
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
@@ -1118,6 +1152,7 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -2,6 +2,7 @@ import logging
import multiprocessing
import time
from typing import Any
from typing import cast
import sentry_sdk
from celery import Task
@@ -131,9 +132,9 @@ def on_task_postrun(
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = None
tenant_id = POSTGRES_DEFAULT_SCHEMA
else:
tenant_id = kwargs.get("tenant_id")
tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "

View File

@@ -111,5 +111,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.indexing",
]
)

View File

@@ -92,7 +92,8 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
"""This is a redis specific way to build a list of tasks in a queue.
"""This is a redis specific way to build a list of tasks in a queue and return them
as a set.
This helps us read the queue once and then efficiently look for missing tasks
in the queue.

View File

@@ -34,7 +34,7 @@ def _get_deletion_status(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
tenant_id: str,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
@@ -67,7 +67,7 @@ def get_deletion_attempt_snapshot(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
tenant_id: str,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id

View File

@@ -0,0 +1,73 @@
# backend/onyx/background/celery/memory_monitoring.py
import logging
import os
from logging.handlers import RotatingFileHandler
import psutil
from onyx.utils.logger import is_running_in_container
from onyx.utils.logger import setup_logger
# Regular application logger
logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE,
maxBytes=MEMORY_LOG_MAX_BYTES,
backupCount=MEMORY_LOG_BACKUP_COUNT,
)
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
else:
# Create a null logger when not in container
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.addHandler(logging.NullHandler())
def emit_process_memory(
pid: int, process_name: str, additional_metadata: dict[str, str | int]
) -> None:
# Skip memory monitoring if not in container
if not is_running_in_container():
return
try:
process = psutil.Process(pid)
memory_info = process.memory_info()
cpu_percent = process.cpu_percent(interval=0.1)
# Build metadata string from additional_metadata dictionary
metadata_str = " ".join(
[f"{key}={value}" for key, value in additional_metadata.items()]
)
metadata_str = f" {metadata_str}" if metadata_str else ""
memory_logger.info(
f"PROCESS_MEMORY process_name={process_name} pid={pid} "
f"rss_mb={memory_info.rss / (1024 * 1024):.2f} "
f"vms_mb={memory_info.vms / (1024 * 1024):.2f} "
f"cpu={cpu_percent:.2f}{metadata_str}"
)
except Exception:
logger.exception("Error monitoring process memory.")

View File

@@ -8,16 +8,21 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
@@ -52,6 +57,51 @@ class TaskDependencyError(RuntimeError):
with connector deletion."""
def revoke_tasks_blocking_deletion(
redis_connector: RedisConnector, db_session: Session, app: Celery
) -> None:
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
try:
index_payload = redis_connector_index.payload
if index_payload and index_payload.celery_task_id:
app.control.revoke(index_payload.celery_task_id)
task_logger.info(
f"Revoked indexing task {index_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking indexing task")
try:
permissions_sync_payload = redis_connector.permissions.payload
if permissions_sync_payload and permissions_sync_payload.celery_task_id:
app.control.revoke(permissions_sync_payload.celery_task_id)
task_logger.info(
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking pruning task")
try:
prune_payload = redis_connector.prune.payload
if prune_payload and prune_payload.celery_task_id:
app.control.revoke(prune_payload.celery_task_id)
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
except Exception:
task_logger.exception("Exception while revoking permissions sync task")
try:
external_group_sync_payload = redis_connector.external_group_sync.payload
if external_group_sync_payload and external_group_sync_payload.celery_task_id:
app.control.revoke(external_group_sync_payload.celery_task_id)
task_logger.info(
f"Revoked external group sync task {external_group_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking external group sync task")
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
ignore_result=True,
@@ -59,22 +109,36 @@ class TaskDependencyError(RuntimeError):
trail=False,
bind=True,
)
def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
# Prevent this task from overlapping with itself
if not lock_beat.acquire(blocking=False):
return None
try:
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
# clear fences that don't have associated celery tasks in progress
try:
validate_connector_deletion_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)
except Exception:
task_logger.exception(
"Exception while validating connector deletion fences"
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
@@ -92,9 +156,38 @@ def check_for_connector_deletion_task(
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
# Leave a stop signal to clear indexing and pruning tasks more quickly
# on the first error, we set a stop signal and revoke the dependent tasks
# on subsequent errors, we hard reset blocking fences after our specified timeout
# is exceeded
task_logger.info(str(e))
redis_connector.stop.set_fence(True)
if not redis_connector.stop.fenced:
# one time revoke of celery tasks
task_logger.info("Revoking any tasks blocking deletion.")
revoke_tasks_blocking_deletion(
redis_connector, db_session, self.app
)
redis_connector.stop.set_fence(True)
redis_connector.stop.set_timeout()
else:
# stop signal already set
if redis_connector.stop.timed_out:
# waiting too long, just reset blocking fences
task_logger.info(
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
)
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
redis_connector_index.reset()
redis_connector.prune.reset()
redis_connector.permissions.reset()
redis_connector.external_group_sync.reset()
else:
# just wait
pass
else:
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)
@@ -129,7 +222,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
cc_pair_id: int,
db_session: Session,
lock_beat: RedisLock,
tenant_id: str | None,
tenant_id: str,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
@@ -169,6 +262,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
redis_connector.delete.set_active()
fence_payload = RedisConnectorDeletePayload(
num_tasks=None,
submitted=datetime.now(timezone.utc),
@@ -249,7 +343,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
def monitor_connector_deletion_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis
tenant_id: str, key_bytes: bytes, r: Redis
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -401,3 +495,171 @@ def monitor_connector_deletion_taskset(
)
redis_connector.delete.reset()
def validate_connector_deletion_fences(
tenant_id: str,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN:
return
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
# validate all existing connector deletion jobs
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
continue
validate_connector_deletion_fence(
tenant_id,
key_bytes,
queued_upsert_tasks,
r,
)
lock_beat.reacquire()
return
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_connector_deletion_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.delete.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.delete.payload
except ValidationError:
task_logger.exception(
"validate_connector_deletion_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return
if not payload:
return
# OK, there's actually something for us to validate
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own permissions taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
# TODO: if the number of tasks in celery is much lower than than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.delete.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_connector_deletion_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
# we're active if there are still tasks to run and those tasks all exist in celery
if tasks_scanned > 0 and tasks_not_in_celery == 0:
redis_connector.delete.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.delete.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_connector_deletion_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return

View File

@@ -30,6 +30,7 @@ from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -42,8 +43,10 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
@@ -63,6 +66,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
@@ -193,12 +197,19 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
task_logger.info(f"check_for_doc_permissions_sync finished: tenant={tenant_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id} {error_msg}"
)
task_logger.exception(
f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id}"
)
finally:
if lock_beat.owned():
lock_beat.release()
@@ -210,7 +221,7 @@ def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
tenant_id: str,
) -> str | None:
"""Returns a randomized payload id on success.
Returns None if no syncing is required."""
@@ -282,13 +293,19 @@ def try_creating_permissions_sync_task(
redis_connector.permissions.set_fence(payload)
payload_id = payload.id
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_permissions_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
)
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"try_creating_permissions_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
return payload_id
@@ -303,7 +320,7 @@ def try_creating_permissions_sync_task(
def connector_permission_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str | None,
tenant_id: str,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
@@ -388,6 +405,29 @@ def connector_permission_sync_generator_task(
f"No connector credential pair found for id: {cc_pair_id}"
)
try:
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
enforce_creation=False,
)
if not created:
task_logger.warning(
f"Unable to create connector credential pair for id: {cc_pair_id}"
)
except Exception:
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
@@ -439,6 +479,10 @@ def connector_permission_sync_generator_task(
redis_connector.permissions.generator_complete = tasks_generated
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id} {error_msg}"
)
task_logger.exception(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
)
@@ -465,7 +509,7 @@ def connector_permission_sync_generator_task(
)
def update_external_document_permissions_task(
self: Task,
tenant_id: str | None,
tenant_id: str,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,
@@ -473,6 +517,8 @@ def update_external_document_permissions_task(
) -> bool:
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
@@ -512,18 +558,33 @@ def update_external_document_permissions_task(
f"elapsed={elapsed:.2f}"
)
except Exception:
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
)
task_logger.exception(
f"Exception in update_external_document_permissions_task: "
f"update_external_document_permissions_task exceptioned: "
f"connector_id={connector_id} doc_id={doc_id}"
)
completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
finally:
task_logger.info(
f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
task_logger.info(
f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
)
return True
def validate_permission_sync_fences(
tenant_id: str | None,
tenant_id: str,
r: Redis,
r_replica: Redis,
r_celery: Redis,
@@ -570,7 +631,7 @@ def validate_permission_sync_fences(
def validate_permission_sync_fence(
tenant_id: str | None,
tenant_id: str,
key_bytes: bytes,
queued_tasks: set[str],
reserved_tasks: set[str],
@@ -780,7 +841,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)

View File

@@ -37,8 +37,11 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -55,6 +58,7 @@ from onyx.redis.redis_connector_ext_group_sync import (
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -119,7 +123,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client()
@@ -148,7 +152,10 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
# These are ordered by cc_pair id so the first one is the one we want
cc_pairs_to_dedupe = get_cc_pairs_by_source(
db_session, source, only_sync=True
db_session,
source,
access_type=AccessType.SYNC,
status=ConnectorCredentialPairStatus.ACTIVE,
)
# We only want to sync one cc_pair per source type
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
@@ -195,12 +202,17 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected check_for_external_group_sync exception: tenant={tenant_id} {error_msg}"
)
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
task_logger.info(f"check_for_external_group_sync finished: tenant={tenant_id}")
return True
@@ -208,7 +220,7 @@ def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
tenant_id: str,
) -> str | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
@@ -267,12 +279,19 @@ def try_creating_external_group_sync_task(
redis_connector.external_group_sync.set_fence(payload)
payload_id = payload.id
except Exception:
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_external_group_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
)
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
)
return None
task_logger.info(
f"try_creating_external_group_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
return payload_id
@@ -287,7 +306,7 @@ def try_creating_external_group_sync_task(
def connector_external_group_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str | None,
tenant_id: str,
) -> None:
"""
External group sync task for a given connector credential pair
@@ -361,12 +380,36 @@ def connector_external_group_sync_generator_task(
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
try:
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
enforce_creation=False,
)
if not created:
task_logger.warning(
f"Unable to create connector credential pair for id: {cc_pair_id}"
)
except Exception:
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise
source_type = cc_pair.connector.source
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
@@ -378,8 +421,18 @@ def connector_external_group_sync_generator_task(
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
external_user_groups: list[ExternalUserGroup] = []
try:
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
except ConnectorValidationError as e:
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise e
logger.info(
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
@@ -405,6 +458,14 @@ def connector_external_group_sync_generator_task(
sync_status=SyncStatus.SUCCESS,
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id} {error_msg}"
)
task_logger.exception(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
)
msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
task_logger.exception(msg)
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)
@@ -432,7 +493,7 @@ def connector_external_group_sync_generator_task(
def validate_external_group_sync_fences(
tenant_id: str | None,
tenant_id: str,
celery_app: Celery,
r: Redis,
r_replica: Redis,
@@ -464,7 +525,7 @@ def validate_external_group_sync_fences(
def validate_external_group_sync_fence(
tenant_id: str | None,
tenant_id: str,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,

View File

@@ -23,9 +23,10 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.indexing.utils import _should_index
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import should_index
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
@@ -48,7 +49,7 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -61,7 +62,7 @@ from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.search_settings import get_current_search_settings
from onyx.db.swap_index import check_index_swap
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from onyx.redis.redis_connector import RedisConnector
@@ -182,7 +183,7 @@ class SimpleJobResult:
class ConnectorIndexingContext(BaseModel):
tenant_id: str | None
tenant_id: str
cc_pair_id: int
search_settings_id: int
index_attempt_id: int
@@ -210,7 +211,7 @@ class ConnectorIndexingLogBuilder:
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
@@ -358,7 +359,7 @@ def monitor_ccpair_indexing_taskset(
soft_time_limit=300,
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occcasionally does some validation of existing state to clear up error conditions"""
@@ -406,7 +407,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# check for search settings swap
with get_session_with_current_tenant() as db_session:
old_search_settings = check_index_swap(db_session=db_session)
old_search_settings = check_and_perform_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
@@ -439,6 +440,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
with get_session_with_current_tenant() as db_session:
search_settings_list = get_active_search_settings_list(db_session)
for search_settings_instance in search_settings_list:
# skip non-live search settings that don't have background reindex enabled
# those should just auto-change to live shortly after creation without
# requiring any indexing till that point
if (
not search_settings_instance.status.is_current()
and not search_settings_instance.background_reindex_enabled
):
continue
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -456,23 +466,18 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
cc_pair.id, search_settings_instance.id, db_session
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
if not should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if search_settings_instance.status.is_current():
# the indexing trigger is only checked and cleared with the current search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True
@@ -598,7 +603,7 @@ def connector_indexing_task(
cc_pair_id: int,
search_settings_id: int,
is_ee: bool,
tenant_id: str | None,
tenant_id: str,
) -> int | None:
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -890,7 +895,7 @@ def connector_indexing_proxy_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
tenant_id: str,
) -> None:
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
and forking is inherently unstable.
@@ -899,6 +904,9 @@ def connector_indexing_proxy_task(
TODO(rkuo): refactor this so that there is a single return path where we canonically
log the result of running this function.
NOTE: we try/except all db access in this function because as a watchdog, this function
needs to be extremely stable.
"""
start = time.monotonic()
@@ -924,6 +932,7 @@ def connector_indexing_proxy_task(
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
task_logger.info(f"submitting connector_indexing_task with tenant_id={tenant_id}")
job = client.submit(
connector_indexing_task,
@@ -976,6 +985,9 @@ def connector_indexing_proxy_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
@@ -1016,7 +1028,24 @@ def connector_indexing_proxy_task(
job.release()
break
# if a termination signal is detected, clean up and break
# log the memory usage for tracking down memory leaks / connector-specific memory issues
pid = job.process.pid
if pid is not None:
# Only emit memory info once per minute (60 seconds)
current_time = time.monotonic()
if current_time - last_memory_emit_time >= 60.0:
emit_process_memory(
pid,
"indexing_worker",
{
"cc_pair_id": cc_pair_id,
"search_settings_id": search_settings_id,
"index_attempt_id": index_attempt_id,
},
)
last_memory_emit_time = current_time
# if a termination signal is detected, break (exit point will clean up)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
log_builder.build("Indexing watchdog - termination signal detected")
@@ -1025,6 +1054,7 @@ def connector_indexing_proxy_task(
result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
break
# if activity timeout is detected, break (exit point will clean up)
if not redis_connector_index.connector_active():
task_logger.warning(
log_builder.build(
@@ -1033,25 +1063,6 @@ def connector_indexing_proxy_task(
)
)
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
result.status = (
IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
)
@@ -1070,15 +1081,15 @@ def connector_indexing_proxy_task(
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception looking up index attempt"
)
)
continue
except Exception as e:
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
if isinstance(e, ConnectorValidationError):
@@ -1139,8 +1150,6 @@ def connector_indexing_proxy_task(
"Connector termination signal detected",
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as canceled"
@@ -1148,6 +1157,25 @@ def connector_indexing_proxy_task(
)
job.cancel()
elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
except Exception:
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
else:
pass
task_logger.info(
log_builder.build(
@@ -1163,11 +1191,12 @@ def connector_indexing_proxy_task(
return
# primary
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
)
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
"""Clean up old checkpoints that are older than 7 days."""
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)
@@ -1210,6 +1239,7 @@ def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
)
# light worker
@shared_task(
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
bind=True,

View File

@@ -187,7 +187,7 @@ class IndexingCallback(IndexingCallbackBase):
def validate_indexing_fence(
tenant_id: str | None,
tenant_id: str,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
@@ -311,7 +311,7 @@ def validate_indexing_fence(
def validate_indexing_fences(
tenant_id: str | None,
tenant_id: str,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
@@ -346,11 +346,10 @@ def validate_indexing_fences(
return
def _should_index(
def should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -415,9 +414,9 @@ def _should_index(
):
return False
if search_settings_primary:
if search_settings_instance.status.is_current():
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
# if a manual indexing trigger is on the cc pair, honor it for live search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
@@ -442,7 +441,7 @@ def try_creating_indexing_task(
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str | None,
tenant_id: str,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.

View File

@@ -59,7 +59,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
trail=False,
bind=True,
)
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
if not LLM_MODEL_UPDATE_API_URL:
raise ValueError("LLM model update API URL not configured")

View File

@@ -91,7 +91,7 @@ class Metric(BaseModel):
}
task_logger.info(json.dumps(data))
def emit(self, tenant_id: str | None) -> None:
def emit(self, tenant_id: str) -> None:
# Convert value to appropriate type based on the input value
bool_value = None
float_value = None
@@ -656,7 +656,7 @@ def build_job_id(
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
"""Collect and emit metrics about background processes.
This task runs periodically to gather metrics about:
- Queue lengths for different Celery queues
@@ -864,7 +864,7 @@ def cloud_monitor_celery_queues(
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
def monitor_celery_queues(self: Task, *, tenant_id: str) -> None:
return monitor_celery_queues_helper(self)

View File

@@ -24,7 +24,7 @@ from onyx.db.engine import get_session_with_current_tenant
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up

View File

@@ -55,6 +55,7 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
@@ -113,7 +114,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
@@ -194,12 +195,14 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(f"Unexpected pruning check exception: {error_msg}")
task_logger.exception("Unexpected exception during pruning check")
finally:
if lock_beat.owned():
lock_beat.release()
task_logger.info(f"check_for_pruning finished: tenant={tenant_id}")
return True
@@ -208,7 +211,7 @@ def try_creating_prune_generator_task(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
tenant_id: str | None,
tenant_id: str,
) -> str | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
@@ -301,13 +304,19 @@ def try_creating_prune_generator_task(
redis_connector.prune.set_fence(payload)
payload_id = payload.id
except Exception:
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_prune_generator_task exception: cc_pair={cc_pair.id} {error_msg}"
)
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"try_creating_prune_generator_task finished: cc_pair={cc_pair.id} payload_id={payload_id}"
)
return payload_id
@@ -324,7 +333,7 @@ def connector_pruning_generator_task(
cc_pair_id: int,
connector_id: int,
credential_id: int,
tenant_id: str | None,
tenant_id: str,
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -512,7 +521,7 @@ def connector_pruning_generator_task(
def monitor_ccpair_pruning_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -558,7 +567,7 @@ def monitor_ccpair_pruning_taskset(
def validate_pruning_fences(
tenant_id: str | None,
tenant_id: str,
r: Redis,
r_replica: Redis,
r_celery: Redis,
@@ -606,7 +615,7 @@ def validate_pruning_fences(
def validate_pruning_fence(
tenant_id: str | None,
tenant_id: str,
key_bytes: bytes,
reserved_tasks: set[str],
queued_tasks: set[str],

View File

@@ -32,7 +32,7 @@ class RetryDocumentIndex:
self,
doc_id: str,
*,
tenant_id: str | None,
tenant_id: str,
chunk_count: int | None,
) -> int:
return self.index.delete_single(
@@ -50,7 +50,7 @@ class RetryDocumentIndex:
self,
doc_id: str,
*,
tenant_id: str | None,
tenant_id: str,
chunk_count: int | None,
fields: VespaDocumentFields,
) -> int:

View File

@@ -1,4 +1,5 @@
import time
from enum import Enum
from http import HTTPStatus
import httpx
@@ -45,6 +46,24 @@ LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
class OnyxCeleryTaskCompletionStatus(str, Enum):
"""The different statuses the watchdog can finish with.
TODO: create broader success/failure/abort categories
"""
UNDEFINED = "undefined"
SUCCEEDED = "succeeded"
SKIPPED = "skipped"
SOFT_TIME_LIMIT = "soft_time_limit"
NON_RETRYABLE_EXCEPTION = "non_retryable_exception"
RETRYABLE_EXCEPTION = "retryable_exception"
@shared_task(
name=OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
@@ -57,7 +76,7 @@ def document_by_cc_pair_cleanup_task(
document_id: str,
connector_id: int,
credential_id: int,
tenant_id: str | None,
tenant_id: str,
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
Created by connection deletion and connector pruning parent tasks."""
@@ -78,6 +97,8 @@ def document_by_cc_pair_cleanup_task(
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
action = "skip"
@@ -110,6 +131,9 @@ def document_by_cc_pair_cleanup_task(
db_session=db_session,
document_ids=[document_id],
)
db_session.commit()
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
elif count > 1:
action = "update"
@@ -153,10 +177,11 @@ def document_by_cc_pair_cleanup_task(
)
mark_document_as_synced(document_id, db_session)
else:
pass
db_session.commit()
db_session.commit()
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
else:
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
elapsed = time.monotonic() - start
task_logger.info(
@@ -168,57 +193,79 @@ def document_by_cc_pair_cleanup_task(
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
task_logger.exception(
f"document_by_cc_pair_cleanup_task exceptioned: doc={document_id}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
)
return False
with get_session_with_current_tenant() as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_modified(document_id, db_session)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
task_logger.exception(f"Unexpected exception: doc={document_id}")
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
else:
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
)
with get_session_with_current_tenant() as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_modified(document_id, db_session)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"document_by_cc_pair_cleanup_task completed: status={completion_status.value} doc={document_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
task_logger.info(f"document_by_cc_pair_cleanup_task finished: doc={document_id}")
return True
@@ -250,7 +297,8 @@ def cloud_beat_task_generator(
return None
last_lock_time = time.monotonic()
tenant_ids: list[str] | list[None] = []
tenant_ids: list[str] = []
num_processed_tenants = 0
try:
tenant_ids = get_all_tenant_ids()
@@ -278,6 +326,8 @@ def cloud_beat_task_generator(
expires=expires,
ignore_result=True,
)
num_processed_tenants += 1
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -297,6 +347,7 @@ def cloud_beat_task_generator(
task_logger.info(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_processed_tenants={num_processed_tenants} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)

View File

@@ -19,6 +19,7 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -75,7 +76,7 @@ logger = setup_logger()
trail=False,
bind=True,
)
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
@@ -207,7 +208,7 @@ def try_generate_stale_document_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
tenant_id: str,
) -> int | None:
# the fence is up, do nothing
@@ -283,7 +284,7 @@ def try_generate_document_set_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
tenant_id: str,
) -> int | None:
lock_beat.reacquire()
@@ -360,7 +361,7 @@ def try_generate_user_group_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
tenant_id: str,
) -> int | None:
lock_beat.reacquire()
@@ -447,7 +448,7 @@ def monitor_connector_taskset(r: Redis) -> None:
def monitor_document_set_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
@@ -522,11 +523,11 @@ def monitor_document_set_taskset(
time_limit=LIGHT_TIME_LIMIT,
max_retries=3,
)
def vespa_metadata_sync_task(
self: Task, document_id: str, *, tenant_id: str | None
) -> bool:
def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool:
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
active_search_settings = get_active_search_settings(db_session)
@@ -540,75 +541,103 @@ def vespa_metadata_sync_task(
doc = get_document(document_id, db_session)
if not doc:
return False
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=no_operation "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
else:
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
task_logger.exception(
f"vespa_metadata_sync_task exceptioned: doc={document_id}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
return False
task_logger.exception(
f"Unexpected exception during vespa metadata sync: doc={document_id}"
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"vespa_metadata_sync_task completed: status={completion_status.value} doc={document_id}"
)
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
return True

View File

@@ -1,3 +1,5 @@
from sqlalchemy.exc import IntegrityError
from onyx.db.background_error import create_background_error
from onyx.db.engine import get_session_with_current_tenant
@@ -9,5 +11,27 @@ def emit_background_error(
"""Currently just saves a row in the background_errors table.
In the future, could create notifications based on the severity."""
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)
error_message = ""
# try to write to the db, but handle IntegrityError specifically
try:
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)
except IntegrityError as e:
# Log an error if the cc_pair_id was deleted or any other exception occurs
error_message = (
f"Failed to create background error: {str(e)}. Original message: {message}"
)
except Exception:
pass
if not error_message:
return
# if we get here from an IntegrityError, try to write the error message to the db
# we need a new session because the first session is now invalid
try:
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, error_message, None)
except Exception:
pass

View File

@@ -16,7 +16,10 @@ from typing import Optional
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from onyx.setup import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -54,6 +57,15 @@ def _initializer(
kwargs = {}
logger.info("Initializing spawned worker child process.")
# 1. Get tenant_id from args or fallback to default
tenant_id = POSTGRES_DEFAULT_SCHEMA
for arg in reversed(args):
if isinstance(arg, str) and arg.startswith(TENANT_ID_PREFIX):
tenant_id = arg
break
# 2. Set the tenant context before running anything
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Reset the engine in the child process
SqlEngine.reset_engine()
@@ -81,6 +93,8 @@ def _initializer(
queue.put(error_msg) # Send the exception to the parent process
sys.exit(255) # use 255 to indicate a generic exception
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _run_in_process(

Some files were not shown because too many files have changed in this diff Show More