Compare commits

...

144 Commits

Author SHA1 Message Date
Richard Kuo (Danswer)
91e32e801d hope this env var works. 2025-01-09 13:51:58 -08:00
hagen-danswer
d40fd82803 Conf doc sync improvements (#3643)
* Reduce number of requests to Confluence

* undo

* added a way to dynamically adjust the pagination limit

* undo
2025-01-09 12:56:56 -08:00
rkuo-danswer
97a963b4bf add index to speed up get last attempt (#3636)
* add index to speed up get last attempt

* use descending order

* put back unique param

* how did this not get formatted?

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-09 00:56:55 +00:00
pablonyx
7f6ef1ff57 Remove unnecessary logspam
Remove unnecessary logs
2025-01-08 17:03:52 -08:00
pablodanswer
d98746b988 remove unnecessary logs 2025-01-08 17:03:15 -08:00
rkuo-danswer
a76f1b4c1b Merge pull request #3628 from onyx-dot-app/bugfix/debug_tenant
add more debug logging for locking issue
2025-01-08 15:14:37 -08:00
hagen-danswer
4c4ff46fe3 Fixing google drive tests (#3634)
* Fixing google drive texts

* Update conftest.py
2025-01-08 22:34:38 +00:00
hagen-danswer
0f9842064f Added env var to skip warm up (#3633) 2025-01-08 14:29:15 -08:00
pablonyx
d7bc32c0ec Fully remove visit API (#3621)
* v1

* update indexing logic

* update updates

* nit

* clean up args

* update for clarity + best practices

* nit + logs

* fix

* minor clean up

* remove logs

* quick nit
2025-01-08 13:49:01 -08:00
Richard Kuo (Danswer)
1f48de9731 more logging 2025-01-08 12:49:24 -08:00
Richard Kuo (Danswer)
a22d02ff70 add another log line 2025-01-08 10:01:24 -08:00
Richard Kuo (Danswer)
dcfc621a66 add more debug logging for locking issue 2025-01-08 09:43:47 -08:00
Chris Weaver
eac73a1bf1 Improve egnyte connector (#3626) 2025-01-08 03:09:46 +00:00
pablonyx
717560872f Merge pull request #3627 from onyx-dot-app/whitelabeling_name
Whitelabelling
2025-01-07 19:16:01 -08:00
pablodanswer
ce2572134c k 2025-01-07 19:06:52 -08:00
rkuo-danswer
02f72a5c86 Multiple cloud/indexing fixes (#3609)
* more debugging

* test reacquire outside of loop

* more logging

* move lock_beat test outside the try catch so that we don't worry about testing locks we never took

* use a larger scan_iter value for performance

* batch stale document sync batches

* add debug logging for a particular timeout issue

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-08 01:30:29 +00:00
hagen-danswer
eb916df139 added debugger step 2025-01-07 16:18:46 -08:00
hagen-danswer
fafad5e119 Improve contributing guide (#3625)
* Improve contributing guide

* more improvements to contributing guide
2025-01-07 16:16:17 -08:00
pablonyx
a314a08309 Speed up admin pages (#3623)
* ni

* speed up pages

* minor nit

* nit
2025-01-07 15:40:26 -08:00
hagen-danswer
4ce24d68f7 prevent other tests from interfering with existing google drive tests (#3624)
* prevent other tests from interfering with existing google drive tests

* cleanup gdrive tests

* finished

* done
2025-01-07 15:32:36 -08:00
hagen-danswer
a95f4298ad Improved logging for confluence calls (#3622)
* Improved logging for confluence calls

* cleanup

* idk

* combined logging
2025-01-07 21:53:08 +00:00
rkuo-danswer
7cd76ec404 comment out the per doc sync hack (#3620)
* comment out the per doc sync hack

* fix commented code

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-07 19:44:15 +00:00
pablonyx
5b5c1166ca Async Redis (#3618)
* k

* update configs for clarity

* typing

* update
2025-01-07 19:34:57 +00:00
pablonyx
d9e9c6973d Multitenant anonymous (#3595)
* anonymous users for multi tenant setting

* nit

* k
2025-01-07 02:57:20 +00:00
pablonyx
91903141cd Built in tool cache with tool call id (#3617)
* k

* improved

* k

* nit

* nit

* nit
2025-01-07 01:03:52 +00:00
hagen-danswer
e329b63b89 Added Permission Syncing for Salesforce (#3551)
* Added Permission Syncing for Salesforce

* cleanup

* updated connector doc conversion

* finished salesforce permission syncing

* fixed connector to batch Salesforce queries

* tests!

* k

* Added error handling and check for ee and sync type for postprocessing

* comments

* minor touchups

* tested to work!

* done

* my pie

* lil cleanup

* minor comment
2025-01-07 00:37:03 +00:00
hagen-danswer
71c2559ea9 Discord cleanup (#3615)
* Discord cleanup

* fix case discrepancy
2025-01-06 15:11:03 -08:00
Ishankoradia
ceb34a41d9 discord connector (#3023)
* discord: frontend and backend poll connector

* added requirements for discord installation

* fixed the mypy errors

* process messages not part of any thread

* minor change

* updated the connector; this logic works & am able to docs when i print

* minor change

* ability to enter a start date to pull docs from and refactor

* added the load connector and fixed mypy errors

* local commit test

done!

* minor refactor and properly commented everything

* updated the logic to handle permissions and index active/archived threads

* basic discord test template

* cleanup

* going away with the danswer discord client class ; using an async context manager

* moved to proper folder

* minor fixes

* needs improvement

* fixed discord icon

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2025-01-06 14:54:22 -08:00
pablonyx
82eab9d704 Doc explore fix (#3614)
* k

* k

* add comment
2025-01-06 19:42:07 +00:00
pablonyx
2b8d3a6ef5 fix white labelling empty string (#3603) 2025-01-06 19:26:55 +00:00
pablonyx
4fb129e77b Increase timeout + revert changes for clarity (#3604)
* increase timeout + revert changes for clarity

* quick nit

* k
2025-01-06 18:20:53 +00:00
pablonyx
f16ca1b735 minor auth fix (#3613) 2025-01-06 17:51:02 +00:00
pablonyx
e3b2c9d944 Tracking update (#3605)
* tracking update

* k
2025-01-06 17:17:00 +00:00
pablodanswer
6c9c25642d remove empty files on main 2025-01-06 09:01:33 -08:00
hagen-danswer
2862d8bbd3 Minor opensource cleanup (#3610) 2025-01-06 07:26:07 -08:00
skylares
143be6a524 Add assertions to Zendesk connector tests (#3600)
Co-authored-by: hagen-danswer <hagen@danswer.ai>
2025-01-06 06:43:23 -08:00
SubashMohan
c2444a5cff Slim connector for Zendesk (#3367)
* Add SlimConnector support for Zendesk

* ZenDesk format changes

* code formating

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2025-01-06 06:41:41 -08:00
rkuo-danswer
7f8194798a Merge pull request #3608 from onyx-dot-app/bugfix/locking_4
fix timing calculations and don't spam the queue lengths check from e…
2025-01-05 21:13:47 -08:00
Richard Kuo (Danswer)
e3947e4b64 fix timing calculations and don't spam the queue lengths check from every task 2025-01-05 21:13:08 -08:00
rkuo-danswer
98005510ad Merge pull request #3607 from onyx-dot-app/bugfix/locking_redux
add detailed timings to monitor vespa sync
2025-01-05 20:03:17 -08:00
Richard Kuo
ca54bd0b21 add logging 2025-01-05 20:02:05 -08:00
Richard Kuo
d26f8ce852 add detailed timings to monitor vespa sync 2025-01-05 19:24:58 -08:00
pablonyx
c8090ab75b Auth fix + Registration Clarity (#3590)
* clarify auth flow

* k

* nit

* k

* fix typing
2025-01-06 02:17:45 +00:00
hagen-danswer
e100a5e965 Properly account for anonymous access in Confluence 2025-01-05 18:01:39 -08:00
pablonyx
ddec239fef Improved indexing (#3594)
* nit

* k

* add steps

* main util functions

* functioning fully

* quick nit

* k

* typing fix

* k

* address comments
2025-01-05 23:31:53 +00:00
Chris Weaver
e83542f572 Add support for auto-refreshing available models based on an API call (#3576) 2025-01-05 15:45:49 -08:00
joachim-danswer
8750f14647 alignment & renaming of objects for initial (displayed) ranking and re-ranking/validation citations
- renamed post-reranking/validation citation information consistently to final_... (example: doc_id_to_rank_map -> final_doc_id_to_rank_map)
 - changed and renamed objects containing initial ranking information (now: display_...) consistent with final rankings (final_...). Specifically, {} to [] for displayed_search_results
 - for CitationInfo, changed citation_num from 'x-th citation in response stream' to the initial position of the doc [NOTE: test implications]
-  changed tests:
    onyx/backend/tests/unit/onyx/chat/stream_processing/test_citation_processing.py
    onyx/backend/tests/unit/onyx/chat/stream_processing/test_citation_substitution.py
2025-01-05 15:44:34 -08:00
rkuo-danswer
27699c8216 Merge pull request #3602 from onyx-dot-app/bugfix/lock_not_owned
various lock diagnostics and timing adjustments
2025-01-05 15:12:01 -08:00
Richard Kuo (Danswer)
6fcd712a00 comment for lock and usage 2025-01-05 15:11:39 -08:00
Richard Kuo (Danswer)
b027a08698 various lock diagnostics and timing adjustments 2025-01-05 13:59:36 -08:00
Chris Weaver
1db778baa8 Small airtable refactor + handle files with uppercase extensions (#3598)
* Small airtable refactor + handle files with uppercase extensions

* Fix mypy
2025-01-05 11:27:50 -08:00
Chris Weaver
f895e5f7d0 Speedup orphan doc cleanup script (#3596)
* Speedup orphan doc cleanup script

* Fix mypy
2025-01-05 14:28:25 +00:00
rkuo-danswer
2fc58252f4 Merge pull request #3599 from onyx-dot-app/bugfix/doc-sync
quick hack to prevent resyncing the same doc
2025-01-05 04:14:28 -08:00
Richard Kuo (Danswer)
371d1ccd8f move to just setting keys 2025-01-05 03:26:33 -08:00
Richard Kuo (Danswer)
7fb92d42a0 quick hack to prevent resyncing the same doc 2025-01-05 03:05:32 -08:00
Weves
af2061c4db Add Linear OAuth env variables to dev compose 2025-01-04 16:02:04 -08:00
pablonyx
ffec19645b JWT -> Redis (#3574)
* functional v1

* functional logout

* minor clean up

* quick clean up

* update configuration

* ni

* nit

* finalize

* update login page

* delete unused import

* quick nit

* updates

* clean up

* ni

* k

* k
2025-01-04 19:35:43 +00:00
pablonyx
67d2c86250 Remove Exclamation marks + comments (#3586)
* remove explanation marks + comments

* nit
2025-01-04 18:39:16 +00:00
pablonyx
6c018cb53f add personal assistant usage stats (#3543) 2025-01-04 18:38:41 +00:00
pablonyx
62302e3faf Latex for $10 and $100 (#3585)
* nit

* k
2025-01-04 17:37:04 +00:00
rkuo-danswer
0460531c72 Merge pull request #3593 from onyx-dot-app/bugfix/primary_worker_lock
the primary worker lock doesn't always exist
2025-01-03 22:01:11 -08:00
Richard Kuo
6af07a888b the primary worker lock doesn't always exist 2025-01-03 21:12:11 -08:00
Weves
ea75f5cd5d Move google-cloud-aiplatform to default requirements to support vertex AI llama 2025-01-03 15:20:59 -08:00
rkuo-danswer
b92c183022 re-prep user group deletion on the actual deletion (#3588)
* re-prep user group deletion on the actual deletion

* user group needs to be synced to be prepped

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-03 22:35:56 +00:00
skylares
c191e23256 Pagination Hook (#3494)
* Backend changes for pagination hook + Paginated users table

* Frontend changes for pagination hook

* Fix invited users endpoint

* Fix layout shift & add enter to submit user invites

* mypy

* Cleanup

* Resolve PR concerns & remove UserStatus

* Fix errors

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2025-01-03 14:32:55 -08:00
rkuo-danswer
66f9124135 Merge pull request #3584 from onyx-dot-app/bugfix/log_spacing
fix formatting
2025-01-03 00:43:36 -08:00
Richard Kuo
8f0fb70bbf fix formatting 2025-01-02 23:21:54 -08:00
rkuo-danswer
ef5e5c80bb Merge pull request #3577 from onyx-dot-app/bugfix/model_server_exception_logging
fix response logging
2025-01-02 23:08:46 -08:00
rkuo-danswer
03acb6587a Feature/model server logging (#3579)
* improve model server logging

* improve exception logging with provider/model names

* get everything into one log line

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-01-03 01:40:29 +00:00
hagen-danswer
d1ec72b5e5 Reworked salesforce connector to use bulk api (#3581) 2025-01-02 18:09:02 -08:00
Weves
3b214133a8 Airtable improvement 2025-01-02 17:56:05 -08:00
rkuo-danswer
2232702e99 retry the individual delete's (#3580)
* retry the individual delete's

* need to raise inside the retry

* just use retry for now

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-01-02 17:39:37 -08:00
hagen-danswer
8108ff0a4b Added logging for permissions upsert queue length 2025-01-02 17:39:01 -08:00
Richard Kuo
f64e78e986 fix response logging 2025-01-02 13:39:19 -08:00
Chris Weaver
08312a4394 Update Slack link in README.md 2025-01-01 10:03:59 -08:00
Weves
92add655e0 Slack fixes 2024-12-31 18:04:12 -08:00
Chris Weaver
d64464ca7c Add support for OAuth connectors that require user input (#3571)
* Add support for OAuth connectors that require user input

* Cleanup

* Fix linear

* Small re-naming

* Remove console.log
2024-12-31 18:03:33 -08:00
Yuhong Sun
ccd3983802 Linear OAuth Connector (#3570) 2024-12-31 16:11:09 -08:00
pablonyx
240f3e4fff Ensure users cannot modify their roles on main
Ensure users cannot modify their roles
2024-12-31 15:59:27 -05:00
pablonyx
1291b3d930 Add anonymous user to main
Anonymous user
2024-12-31 15:58:52 -05:00
rkuo-danswer
d05f1997b5 Merge pull request #3569 from onyx-dot-app/bugfix/alt_index
we didn't want to rename the alt index suffix, reverting
2024-12-31 12:39:00 -08:00
Chris Weaver
aa2e2a62b9 Small Egnyte tweaks (#3568) 2024-12-31 19:28:38 +00:00
Richard Kuo
174e5968f8 we didn't want to rename the alt index suffix, reverting 2024-12-31 11:28:11 -08:00
pablodanswer
1f27606e17 minor clean up 2024-12-31 13:04:02 -05:00
pablodanswer
60355b84c1 quick nits 2024-12-31 13:04:02 -05:00
pablodanswer
680ab9ea30 updated logic 2024-12-31 13:04:02 -05:00
pablodanswer
c2447dbb1c cosmetic updates 2024-12-31 13:04:02 -05:00
pablodanswer
52bad522f8 update for multi-tenant clarity 2024-12-31 13:04:02 -05:00
pablodanswer
63e5e58313 add anonymous user 2024-12-31 13:04:02 -05:00
rkuo-danswer
2643782e30 Merge pull request #3567 from onyx-dot-app/bugfix/revert_vespa
Revert "More efficient Vespa indexing (#3552)"
2024-12-31 09:47:00 -08:00
Richard Kuo
3eb72e5c1d Revert "More efficient Vespa indexing (#3552)"
This reverts commit 2783216781.
2024-12-31 09:40:23 -08:00
rkuo-danswer
9b65c23a7e Merge pull request #3566 from onyx-dot-app/bugfix/primary_task_timings
re-enable celery task execution logging in primary worker
2024-12-31 01:29:05 -08:00
Richard Kuo (Danswer)
b43a8e48c6 add some return types to distinguish when the task is actually performing work 2024-12-31 00:10:33 -08:00
Richard Kuo (Danswer)
1955c1d67b re-enable celery task execution logging in primary worker 2024-12-30 21:53:00 -08:00
Chris Weaver
3f92ed9d29 Airtable connector (#3564)
* Airtable connector

* Improvements

* improve

* Clean up test

* Add secrets

* Fix mypy + add access token

* Mock unstructured call

* Adjust comments

* Fix ID in test
2024-12-31 03:06:28 +00:00
Weves
618369f4a1 Small fix 2024-12-30 19:20:30 -08:00
pablonyx
2783216781 More efficient Vespa indexing (#3552)
---------

Co-authored-by: Chris Weaver <25087905+Weves@users.noreply.github.com>
2024-12-30 18:51:14 -08:00
rkuo-danswer
bec0f9fb23 permission sync in cloud and beat expiry adjustment (#3544)
* try fixing exception in cloud

* raise beat expiry ... 60 seconds might be starving certain tasks completely

* adjust expiry down to 10 min

* raise concurrency overflow for indexing worker.

* parent pid check

* fix comment

* fix parent pid check, also actually raise an exception from the task if the spawned task exit status is bad

* fix pid check

* some cleanup and task wait fixes

* review fixes

* comment some code so we don't change too many things at once

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-12-31 01:05:57 +00:00
pablodanswer
97a03e7fc8 nit 2024-12-29 21:07:12 -05:00
pablodanswer
8d6e8269b7 k 2024-12-29 21:07:12 -05:00
pablodanswer
9ce2c6c517 minor change 2024-12-29 21:07:12 -05:00
pablodanswer
2ad8bdbc65 k 2024-12-29 21:07:12 -05:00
rkuo-danswer
a83c9b40d5 Bugfix/oauth fix (#3507)
* old oauth file left behind

* fix function change that was lost in merge

* fix some testing vars

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-30 01:49:12 +00:00
Chris Weaver
340fab1375 Additional error handling + logging for google drive connector (#3563)
* Additional error handling + logging for google drive connector

* Fix mypy
2024-12-29 17:48:02 -08:00
hagen-danswer
3ec338307f Fixed indexing issues with Salesforce 2024-12-29 16:45:29 -08:00
pablonyx
27acd3387a Auth specific rate limiting (#3463)
* k

* v1

* fully functional

* finalize

* nit

* nit

* nit

* clean up with wrapper + comments

* k

* update

* minor clean
2024-12-29 23:34:23 +00:00
hagen-danswer
d14ef431a7 Improve Salesforce connector 2024-12-29 14:18:40 -08:00
pablonyx
9bffeb65af Eagerly load CCpair connectors (#3531)
* remove left over vim command

* eager loading

* Revert "remove left over vim command"

This reverts commit 184a134ae0.
2024-12-29 15:58:38 +00:00
Yuhong Sun
f4806da653 Fix Null Value in PG (#3559)
* k

* k

* k

* k

* k
2024-12-29 01:53:16 +00:00
pablonyx
e2700b2bbd Remove left over yaml errors (#3527)
* remove left over vim command

* additional misconfigurations

* ensure all regions updated
2024-12-29 01:45:07 +00:00
Yuhong Sun
fc81a3fb12 Zendesk Retries (#3558)
* k

* k

* k

* k
2024-12-28 23:51:49 +00:00
pablonyx
2203cfabea Prevent SSRF risk
Prevent SSRF risk
2024-12-28 15:25:57 -05:00
pablodanswer
f4050306d6 Prevent SSRF risk 2024-12-28 15:25:12 -05:00
Weves
2d960a477f Fix discourse connector 2024-12-24 12:43:10 -08:00
hagen-danswer
8837b8ea79 Curators can now update the curator relationship (#3536)
* Curators can now update the curator relationship

* mypy

* mypy

* whoops haha
2024-12-24 18:49:58 +00:00
hagen-danswer
3dfb214f73 Slackbot polish (#3547) 2024-12-24 16:19:15 +00:00
pablonyx
18d7262608 fix logo rendering (#3542) 2024-12-22 23:00:33 +00:00
pablonyx
09b879ee73 Ensure gmail works for personal accounts (#3541)
* Ensure gmail works for personal accounts

* nit

* minor update
2024-12-22 23:00:14 +00:00
rkuo-danswer
aaa668c963 Merge pull request #3534 from onyx-dot-app/bugfix/validate_ttl
raise activity timeout to one hour
2024-12-22 13:13:57 -08:00
pablonyx
edb877f4bc fix NUL character (#3540) 2024-12-21 23:30:25 +00:00
rkuo-danswer
eb369caefb log attempt id, log elapsed since task execution start, remove log spam (#3539)
* log attempt id, log elapsed since task execution start, remove log spam

* diagnostic lock logs

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-21 23:03:50 +00:00
Chris Weaver
b9567eabd7 Fix bedrock w/ access keys (#3538)
* Fix bedrock w/ access keys

* cleanup

* Remove extra #
2024-12-21 02:24:11 +00:00
Richard Kuo (Danswer)
13bbf67091 raise activity timeout to one hour 2024-12-20 16:18:35 -08:00
hagen-danswer
457a4c73f0 Made sure confluence connector recursive by page includes top level page (#3532)
* Made sure confluence connector by page includes top level page

* surface level change
2024-12-20 21:53:59 +00:00
rkuo-danswer
ce37688b5b allow limited user to create chat session (#3533)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-20 21:36:41 +00:00
pablonyx
4e2c90f4af Proper user deletion / organization leaving (#3460)
* Proper user deletion / organization leaving

* minor nit

* update

* udpate provisioning

* minor cleanup

* typing

* post rebase
2024-12-20 21:01:03 +00:00
pablonyx
513dd8a319 update toggling states (#3519) 2024-12-20 20:27:22 +00:00
hagen-danswer
71c5043832 Added filter to exclude attachments with unsupported file extensions (#3530)
* Added filter to exclude attachments with unsupported file extensions

* extension
2024-12-20 19:48:15 +00:00
pablonyx
64b6f15e95 AWS extraneous error fix (#3529)
* remove left over vim command

* aws fix

* k

* remove double
2024-12-20 19:31:04 +00:00
hagen-danswer
35022f5f09 Fix group table (#3523) 2024-12-20 17:51:26 +00:00
hagen-danswer
0d44014c16 Cleanup PR template to make it more concise (#3524) 2024-12-20 17:49:31 +00:00
Yuhong Sun
1b9e9f48fa Update README.md 2024-12-20 10:26:57 -08:00
Yuhong Sun
05fb5aa27b Update README.md 2024-12-20 10:25:34 -08:00
Yuhong Sun
3b645b72a3 Crop Logo Closer 2024-12-20 10:23:52 -08:00
Yuhong Sun
fe770b5c3a Fix Logo On DarkMode (#3525) 2024-12-20 10:15:48 -08:00
hagen-danswer
1eaf885f50 associating credentials with connectors is not considered editing (#3522)
* associating credentials with connectors is not considered editing

* formatting

* formatting

* Update credentials.py

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-12-20 17:36:25 +00:00
rkuo-danswer
a187aa508c use redis exclusively with active signal renewal in more places to perform indexing fence validation (#3517)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-20 06:54:00 +00:00
pablonyx
aa4bfa2a78 Forgot password feature (#3437)
* forgot password feature

* improved config

* nit

* nit
2024-12-20 04:53:37 +00:00
pablonyx
9011b8a139 Update citations in shared chat display (#3487)
* update shared chat display

* Change Copy

* fix icon

* remove secret!

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-12-20 01:48:29 +00:00
pablonyx
59c774353a Latex formatting (#3499) 2024-12-19 14:48:06 -08:00
pablonyx
b458d504af Sidebar Default Open (#3488) 2024-12-19 14:04:50 -08:00
Yuhong Sun
f83e7bfcd9 Fix Default CC Pair (#3513) 2024-12-19 09:43:12 -08:00
pablonyx
4d2e26ce4b MT Cloud Tracking Fix (#3514) 2024-12-19 08:47:02 -08:00
pablonyx
817fdc1f36 Ensure metadata overrides file contents (#3512)
* ensure metadata overrides file contents

* update more blocks
2024-12-19 04:44:24 +00:00
320 changed files with 13785 additions and 3852 deletions

View File

@@ -6,24 +6,6 @@
[Describe the tests you ran to verify your changes]
## Accepted Risk (provide if relevant)
N/A
## Related Issue(s) (provide if relevant)
N/A
## Mental Checklist:
- All of the automated tests pass
- All PR comments are addressed and marked resolved
- If there are migrations, they have been rebased to latest main
- If there are new dependencies, they are added to the requirements
- If there are new environment variables, they are added to all of the deployment methods
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- Docker images build and basic functionalities work
- Author has done a final read through of the PR right before merge
## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)

View File

@@ -66,6 +66,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -118,6 +118,6 @@ jobs:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
timeout: "10m"

View File

@@ -26,7 +26,19 @@ env:
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
# Zendesk
ZENDESK_SUBDOMAIN: ${{ secrets.ZENDESK_SUBDOMAIN }}
ZENDESK_EMAIL: ${{ secrets.ZENDESK_EMAIL }}
ZENDESK_TOKEN: ${{ secrets.ZENDESK_TOKEN }}
# Salesforce
SF_USERNAME: ${{ secrets.SF_USERNAME }}
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
# Airtable
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/

View File

@@ -5,6 +5,8 @@
# For local dev, often user Authentication is not needed
AUTH_TYPE=disabled
# Skip warm up for dev
SKIP_WARM_UP=True
# Always keep these on for Dev
# Logs all model prompts to stdout

View File

@@ -355,5 +355,20 @@
"PYTHONPATH": "."
},
},
{
"name": "Install Python Requirements",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"-c",
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
]
}

View File

@@ -12,6 +12,10 @@ As an open source project in a rapidly changing space, we welcome all contributi
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
To ensure that your contribution is aligned with the project's direction, please reach out to Hagen (or any other maintainer) on the Onyx team
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
[Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
will be marked with the `approved by maintainers` label.
Issues marked `good first issue` are an especially great place to start.
@@ -23,8 +27,8 @@ If you have a new/different contribution in mind, we'd love to hear about it!
Your input is vital to making sure that Onyx moves in the right direction.
Before starting on implementation, please raise a GitHub issue.
And always feel free to message us (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ) /
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
### Contributing Code
@@ -42,7 +46,7 @@ Our goal is to make contributing as easy as possible. If you run into any issues
That way we can help future contributors and users can avoid the same issue.
We also have support channels and generally interesting discussions on our
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA)
and
[Discord](https://discord.gg/TDJ59cGV2X).
@@ -123,7 +127,47 @@ Once the above is done, navigate to `onyx/web` run:
npm i
```
#### Docker containers for external software
## Formatting and Linting
### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
```bash
pip install pre-commit
```
Then, from the `onyx/backend` directory, run:
```bash
pre-commit install
```
Additionally, we use `mypy` for static type checking.
Onyx is fully type-annotated, and we want to keep it that way!
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
### Web
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
Please double check that prettier passes before creating a pull request.
# Running the application for development
## Developing using VSCode Debugger (recommended)
We highly recommend using VSCode debugger for development.
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
Otherwise, you can follow the instructions below to run the application for development.
## Manually running the application for development
### Docker containers for external software
You will need Docker installed to run these containers.
@@ -135,7 +179,7 @@ docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
#### Running Onyx locally
### Running Onyx locally
To start the frontend, navigate to `onyx/web` and run:
@@ -223,35 +267,6 @@ If you want to make changes to Onyx and run those changes in Docker, you can als
docker compose -f docker-compose.dev.yml -p onyx-stack up -d --build
```
### Formatting and Linting
#### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
```bash
pip install pre-commit
```
Then, from the `onyx/backend` directory, run:
```bash
pre-commit install
```
Additionally, we use `mypy` for static type checking.
Onyx is fully type-annotated, and we want to keep it that way!
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
#### Web
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
Please double check that prettier passes before creating a pull request.
### Release Process

29
CONTRIBUTING_VSCODE.md Normal file
View File

@@ -0,0 +1,29 @@
# VSCode Debugging Setup
This guide explains how to set up and use VSCode's debugging capabilities with this project.
## Initial Setup
1. **Environment Setup**:
- Copy `.vscode/.env.template` to `.vscode/.env`
- Fill in the necessary environment variables in `.vscode/.env`
2. **launch.json**:
- Copy `.vscode/launch.template.jsonc` to `.vscode/launch.json`
## Using the Debugger
Before starting, make sure the Docker Daemon is running.
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
6. Use the debug toolbar to step through code, inspect variables, etc.
## Features
- Hot reload is enabled for the web server and API servers
- Python debugging is configured with debugpy
- Environment variables are loaded from `.vscode/.env`
- Console output is organized in the integrated terminal with labeled tabs

View File

@@ -3,7 +3,7 @@
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
</h2>
<p align="center">
@@ -13,7 +13,7 @@
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -24,7 +24,7 @@
</a>
</p>
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
@@ -133,15 +133,3 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date)
## ✨Contributors
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
</a>
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>

1
backend/.gitignore vendored
View File

@@ -9,3 +9,4 @@ api_keys.py
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*
onyx/connectors/salesforce/data/

View File

@@ -4,7 +4,7 @@ from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import AWS_REGION
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from sqlalchemy import event
@@ -120,7 +120,7 @@ def provide_iam_token_for_alembic(
) -> None:
if USE_IAM_AUTH:
# Database connection settings
region = AWS_REGION
region = AWS_REGION_NAME
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER

View File

@@ -0,0 +1,24 @@
"""add chunk count to document
Revision ID: 2955778aa44c
Revises: c0aab6edb6dd
Create Date: 2025-01-04 11:39:43.268612
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2955778aa44c"
down_revision = "c0aab6edb6dd"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("document", sa.Column("chunk_count", sa.Integer(), nullable=True))
def downgrade() -> None:
op.drop_column("document", "chunk_count")

View File

@@ -0,0 +1,35 @@
"""add composite index for index attempt time updated
Revision ID: 369644546676
Revises: 2955778aa44c
Create Date: 2025-01-08 15:38:17.224380
"""
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = "369644546676"
down_revision = "2955778aa44c"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
"ix_index_attempt_ccpair_search_settings_time_updated",
"index_attempt",
[
"connector_credential_pair_id",
"search_settings_id",
text("time_updated DESC"),
],
unique=False,
)
def downgrade() -> None:
op.drop_index(
"ix_index_attempt_ccpair_search_settings_time_updated",
table_name="index_attempt",
)

View File

@@ -0,0 +1,31 @@
"""mapping for anonymous user path
Revision ID: a4f6ee863c47
Revises: 14a83a331951
Create Date: 2025-01-04 14:16:58.697451
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "a4f6ee863c47"
down_revision = "14a83a331951"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"tenant_anonymous_user_path",
sa.Column("tenant_id", sa.String(), primary_key=True, nullable=False),
sa.Column("anonymous_user_path", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("tenant_id"),
sa.UniqueConstraint("anonymous_user_path"),
)
def downgrade() -> None:
op.drop_table("tenant_anonymous_user_path")

View File

@@ -3,6 +3,10 @@ from sqlalchemy.orm import Session
from ee.onyx.db.external_perm import fetch_external_groups_for_user
from ee.onyx.db.user_group import fetch_user_groups_for_documents
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.external_permissions.post_query_censoring import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from onyx.access.access import (
_get_access_for_documents as get_access_for_documents_without_groups,
)
@@ -10,6 +14,7 @@ from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_gro
from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_group
from onyx.db.document import get_document_sources
from onyx.db.document import get_documents_by_ids
from onyx.db.models import User
@@ -52,9 +57,20 @@ def _get_access_for_documents(
)
doc_id_map = {doc.id: doc for doc in documents}
# Get all sources in one batch
doc_id_to_source_map = get_document_sources(
db_session=db_session,
document_ids=document_ids,
)
access_map = {}
for document_id, non_ee_access in non_ee_access_dict.items():
document = doc_id_map[document_id]
source = doc_id_to_source_map.get(document_id)
is_only_censored = (
source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
and source not in DOC_PERMISSIONS_FUNC_MAP
)
ext_u_emails = (
set(document.external_user_emails)
@@ -70,7 +86,11 @@ def _get_access_for_documents(
# If the document is determined to be "public" externally (through a SYNC connector)
# then it's given the same access level as if it were marked public within Onyx
is_public_anywhere = document.is_public or non_ee_access.is_public
# If its censored, then it's public anywhere during the search and then permissions are
# applied after the search
is_public_anywhere = (
document.is_public or non_ee_access.is_public or is_only_censored
)
# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess(

View File

@@ -1,5 +1,7 @@
from datetime import datetime
from functools import lru_cache
import jwt
import requests
from fastapi import Depends
from fastapi import HTTPException
@@ -20,6 +22,7 @@ from ee.onyx.server.seeding import get_seed_config
from ee.onyx.utils.secrets import extract_hashed_cookie
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.constants import AuthType
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -118,3 +121,17 @@ async def current_cloud_superuser(
detail="Access denied. User must be a cloud superuser to perform this action.",
)
return user
def generate_anonymous_user_jwt_token(tenant_id: str) -> str:
payload = {
"tenant_id": tenant_id,
# Token does not expire
"iat": datetime.utcnow(), # Issued at time
}
return jwt.encode(payload, USER_AUTH_SECRET, algorithm="HS256")
def decode_anonymous_user_jwt_token(token: str) -> dict:
return jwt.decode(token, USER_AUTH_SECRET, algorithms=["HS256"])

View File

@@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
from ee.onyx.db.user_group import delete_user_group
from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import mark_user_group_as_synced
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from onyx.background.celery.apps.app_base import task_logger
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
@@ -46,11 +47,20 @@ def monitor_usergroup_taskset(
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
if user_group:
usergroup_name = user_group.name
if user_group.is_up_for_deletion:
# this prepare should have been run when the deletion was scheduled,
# but run it again to be sure we're ready to go
mark_user_group_as_synced(db_session, user_group)
prepare_user_group_for_deletion(db_session, usergroup_id)
delete_user_group(db_session=db_session, user_group=user_group)
task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
task_logger.info(
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
)
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
task_logger.info(
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
)
rug.reset()

View File

@@ -15,6 +15,12 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
# This is a boolean that determines if anonymous access is public
# Default behavior is to not make the page public and instead add a group
# that contains all the users that we found in Confluence
CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
os.environ.get("CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC", "").lower() == "true"
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
@@ -55,3 +61,5 @@ POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"

View File

@@ -2,6 +2,7 @@ import datetime
from collections.abc import Sequence
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import case
from sqlalchemy import cast
from sqlalchemy import Date
@@ -14,6 +15,9 @@ from onyx.configs.constants import MessageType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessageFeedback
from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import UserRole
def fetch_query_analytics(
@@ -234,3 +238,121 @@ def fetch_persona_unique_users(
)
return [tuple(row) for row in db_session.execute(query).all()]
def fetch_assistant_message_analytics(
db_session: Session,
assistant_id: int,
start: datetime.datetime,
end: datetime.datetime,
) -> list[tuple[int, datetime.date]]:
"""
Gets the daily message counts for a specific assistant in the given time range.
"""
query = (
select(
func.count(ChatMessage.id),
cast(ChatMessage.time_sent, Date),
)
.join(
ChatSession,
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
)
.group_by(cast(ChatMessage.time_sent, Date))
.order_by(cast(ChatMessage.time_sent, Date))
)
return [tuple(row) for row in db_session.execute(query).all()]
def fetch_assistant_unique_users(
db_session: Session,
assistant_id: int,
start: datetime.datetime,
end: datetime.datetime,
) -> list[tuple[int, datetime.date]]:
"""
Gets the daily unique user counts for a specific assistant in the given time range.
"""
query = (
select(
func.count(func.distinct(ChatSession.user_id)),
cast(ChatMessage.time_sent, Date),
)
.join(
ChatSession,
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
)
.group_by(cast(ChatMessage.time_sent, Date))
.order_by(cast(ChatMessage.time_sent, Date))
)
return [tuple(row) for row in db_session.execute(query).all()]
def fetch_assistant_unique_users_total(
db_session: Session,
assistant_id: int,
start: datetime.datetime,
end: datetime.datetime,
) -> int:
"""
Gets the total number of distinct users who have sent or received messages from
the specified assistant in the given time range.
"""
query = (
select(func.count(func.distinct(ChatSession.user_id)))
.select_from(ChatMessage)
.join(
ChatSession,
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
)
)
result = db_session.execute(query).scalar()
return result if result else 0
# Users can view assistant stats if they created the persona,
# or if they are an admin
def user_can_view_assistant_stats(
db_session: Session, user: User | None, assistant_id: int
) -> bool:
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
return True
# Check if the user created the persona
stmt = select(Persona).where(
and_(Persona.id == assistant_id, Persona.user_id == user.id)
)
persona = db_session.execute(stmt).scalar_one_or_none()
return persona is not None

View File

@@ -10,6 +10,7 @@ from onyx.access.utils import prefix_group_w_source
from onyx.configs.constants import DocumentSource
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -106,3 +107,21 @@ def fetch_external_groups_for_user(
User__ExternalUserGroupId.user_id == user_id
)
).all()
def fetch_external_groups_for_user_email_and_group_ids(
db_session: Session,
user_email: str,
group_ids: list[str],
) -> list[User__ExternalUserGroupId]:
user = get_user_by_email(db_session=db_session, email=user_email)
if user is None:
return []
user_id = user.id
user_ext_groups = db_session.scalars(
select(User__ExternalUserGroupId).where(
User__ExternalUserGroupId.user_id == user_id,
User__ExternalUserGroupId.external_user_group_id.in_(group_ids),
)
).all()
return list(user_ext_groups)

View File

@@ -122,7 +122,7 @@ def _cleanup_document_set__user_group_relationships__no_commit(
)
def validate_user_creation_permissions(
def validate_object_creation_for_user(
db_session: Session,
user: User | None,
target_group_ids: list[int] | None = None,
@@ -440,32 +440,108 @@ def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
_validate_curator_status__no_commit(db_session, [user])
def update_user_curator_relationship(
def _validate_curator_relationship_update_requester(
db_session: Session,
user_group_id: int,
set_curator_request: SetCuratorRequest,
user_making_change: User | None = None,
) -> None:
user = fetch_user_by_id(db_session, set_curator_request.user_id)
if not user:
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
"""
This function validates that the user making the change has the necessary permissions
to update the curator relationship for the target user in the given user group.
"""
if user.role == UserRole.ADMIN:
if user_making_change is None or user_making_change.role == UserRole.ADMIN:
return
# check if the user making the change is a curator in the group they are changing the curator relationship for
user_making_change_curator_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user_making_change.id,
# only check if the user making the change is a curator if they are a curator
# otherwise, they are a global_curator and can update the curator relationship
# for any group they are a member of
only_curator_groups=user_making_change.role == UserRole.CURATOR,
)
requestor_curator_group_ids = [
group.id for group in user_making_change_curator_groups
]
if user_group_id not in requestor_curator_group_ids:
raise ValueError(
f"User '{user.email}' is an admin and therefore has all permissions "
f"user making change {user_making_change.email} is not a curator,"
f" admin, or global_curator for group '{user_group_id}'"
)
def _validate_curator_relationship_update_request(
db_session: Session,
user_group_id: int,
target_user: User,
) -> None:
"""
This function validates that the curator_relationship_update request itself is valid.
"""
if target_user.role == UserRole.ADMIN:
raise ValueError(
f"User '{target_user.email}' is an admin and therefore has all permissions "
"of a curator. If you'd like this user to only have curator permissions, "
"you must update their role to BASIC then assign them to be CURATOR in the "
"appropriate groups."
)
elif target_user.role == UserRole.GLOBAL_CURATOR:
raise ValueError(
f"User '{target_user.email}' is a global_curator and therefore has all "
"permissions of a curator for all groups. If you'd like this user to only "
"have curator permissions for a specific group, you must update their role "
"to BASIC then assign them to be CURATOR in the appropriate groups."
)
elif target_user.role not in [UserRole.CURATOR, UserRole.BASIC]:
raise ValueError(
f"This endpoint can only be used to update the curator relationship for "
"users with the CURATOR or BASIC role. \n"
f"Target user: {target_user.email} \n"
f"Target user role: {target_user.role} \n"
)
# check if the target user is in the group they are changing the curator relationship for
requested_user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=set_curator_request.user_id,
user_id=target_user.id,
only_curator_groups=False,
)
group_ids = [group.id for group in requested_user_groups]
if user_group_id not in group_ids:
raise ValueError(f"user is not in group '{user_group_id}'")
raise ValueError(
f"target user {target_user.email} is not in group '{user_group_id}'"
)
def update_user_curator_relationship(
db_session: Session,
user_group_id: int,
set_curator_request: SetCuratorRequest,
user_making_change: User | None = None,
) -> None:
target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
if not target_user:
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
_validate_curator_relationship_update_request(
db_session=db_session,
user_group_id=user_group_id,
target_user=target_user,
)
_validate_curator_relationship_update_requester(
db_session=db_session,
user_group_id=user_group_id,
user_making_change=user_making_change,
)
logger.info(
f"user_making_change={user_making_change.email if user_making_change else 'None'} is "
f"updating the curator relationship for user={target_user.email} "
f"in group={user_group_id} to is_curator={set_curator_request.is_curator}"
)
relationship_to_update = (
db_session.query(User__UserGroup)
@@ -486,7 +562,7 @@ def update_user_curator_relationship(
)
db_session.add(relationship_to_update)
_validate_curator_status__no_commit(db_session, [user])
_validate_curator_status__no_commit(db_session, [target_user])
db_session.commit()

View File

@@ -0,0 +1,4 @@
# This is a group that we use to store all the users that we found in Confluence
# Instead of setting a page to public, we just add this group so that the page
# is only accessible to users who have confluence accounts.
ALL_CONF_EMAILS_GROUP_NAME = "All_Confluence_Users_Found_By_Onyx"

View File

@@ -4,6 +4,8 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
"""
from typing import Any
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.connector import ConfluenceConnector
@@ -31,14 +33,32 @@ def _get_server_space_permissions(
permission_category.get("spacePermissions", [])
)
is_public = False
user_names = set()
group_names = set()
for permission in viewspace_permissions:
if user_name := permission.get("userName"):
user_name = permission.get("userName")
if user_name:
user_names.add(user_name)
if group_name := permission.get("groupName"):
group_name = permission.get("groupName")
if group_name:
group_names.add(group_name)
# It seems that if anonymous access is turned on for the site and space,
# then the space is publicly accessible.
# For confluence server, we make a group that contains all users
# that exist in confluence and then just add that group to the space permissions
# if anonymous access is turned on for the site and space or we set is_public = True
# if they set the env variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC to True so
# that we can support confluence server deployments that want anonymous access
# to be public (we cant test this because its paywalled)
if user_name is None and group_name is None:
# Defaults to False
if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
is_public = True
else:
group_names.add(ALL_CONF_EMAILS_GROUP_NAME)
user_emails = set()
for user_name in user_names:
user_email = get_user_email_from_username__server(confluence_client, user_name)
@@ -47,14 +67,17 @@ def _get_server_space_permissions(
else:
logger.warning(f"Email for user {user_name} not found in Confluence")
if not user_emails and not group_names:
logger.warning(
"No user emails or group names found in Confluence space permissions"
f"\nSpace key: {space_key}"
f"\nSpace permissions: {space_permissions}"
)
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_names,
# TODO: Check if the space is publicly accessible
# Currently, we assume the space is not public
# We need to check if anonymous access is turned on for the site and space
# This information is paywalled so it remains unimplemented
is_public=False,
is_public=is_public,
)
@@ -134,7 +157,7 @@ def _get_space_permissions(
def _extract_read_access_restrictions(
confluence_client: OnyxConfluence, restrictions: dict[str, Any]
) -> ExternalAccess | None:
) -> tuple[set[str], set[str]]:
"""
Converts a page's restrictions dict into an ExternalAccess object.
If there are no restrictions, then return None
@@ -177,21 +200,57 @@ def _extract_read_access_restrictions(
group["name"] for group in read_access_group_jsons if group.get("name")
]
return set(read_access_user_emails), set(read_access_group_names)
def _get_all_page_restrictions(
confluence_client: OnyxConfluence,
perm_sync_data: dict[str, Any],
) -> ExternalAccess | None:
"""
This function gets the restrictions for a page by taking the intersection
of the page's restrictions and the restrictions of all the ancestors
of the page.
If the page/ancestor has no restrictions, then it is ignored (no intersection).
If no restrictions are found anywhere, then return None, indicating that the page
should inherit the space's restrictions.
"""
found_user_emails: set[str] = set()
found_group_names: set[str] = set()
found_user_emails, found_group_names = _extract_read_access_restrictions(
confluence_client=confluence_client,
restrictions=perm_sync_data.get("restrictions", {}),
)
ancestors: list[dict[str, Any]] = perm_sync_data.get("ancestors", [])
for ancestor in ancestors:
ancestor_user_emails, ancestor_group_names = _extract_read_access_restrictions(
confluence_client=confluence_client,
restrictions=ancestor.get("restrictions", {}),
)
if not ancestor_user_emails and not ancestor_group_names:
# This ancestor has no restrictions, so it has no effect on
# the page's restrictions, so we ignore it
continue
found_user_emails.intersection_update(ancestor_user_emails)
found_group_names.intersection_update(ancestor_group_names)
# If there are no restrictions found, then the page
# inherits the space's restrictions so return None
is_space_public = read_access_user_emails == [] and read_access_group_names == []
if is_space_public:
if not found_user_emails and not found_group_names:
return None
return ExternalAccess(
external_user_emails=set(read_access_user_emails),
external_user_group_ids=set(read_access_group_names),
external_user_emails=found_user_emails,
external_user_group_ids=found_group_names,
# there is no way for a page to be individually public if the space isn't public
is_public=False,
)
def _fetch_all_page_restrictions_for_space(
def _fetch_all_page_restrictions(
confluence_client: OnyxConfluence,
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
@@ -208,11 +267,11 @@ def _fetch_all_page_restrictions_for_space(
raise ValueError(
f"No permission sync data found for document {slim_doc.id}"
)
restrictions = _extract_read_access_restrictions(
if restrictions := _get_all_page_restrictions(
confluence_client=confluence_client,
restrictions=slim_doc.perm_sync_data.get("restrictions", {}),
)
if restrictions:
perm_sync_data=slim_doc.perm_sync_data,
):
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
@@ -301,7 +360,7 @@ def confluence_doc_sync(
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
return _fetch_all_page_restrictions_for_space(
return _fetch_all_page_restrictions(
confluence_client=confluence_connector.confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,

View File

@@ -1,11 +1,11 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -30,6 +30,7 @@ def _build_group_member_email_map(
)
if not email:
# If we still don't have an email, skip this user
logger.warning(f"user result missing email field: {user_result}")
continue
for group in confluence_client.paginated_groups_by_user_retrieval(user):
@@ -53,6 +54,7 @@ def confluence_group_sync(
confluence_client=confluence_client,
)
onyx_groups: list[ExternalUserGroup] = []
all_found_emails = set()
for group_id, group_member_emails in group_member_email_map.items():
onyx_groups.append(
ExternalUserGroup(
@@ -60,5 +62,15 @@ def confluence_group_sync(
user_emails=list(group_member_emails),
)
)
all_found_emails.update(group_member_emails)
# This is so that when we find a public confleunce server page, we can
# give access to all users only in if they have an email in Confluence
if cc_pair.connector.connector_specific_config.get("is_cloud", False):
all_found_group = ExternalUserGroup(
id=ALL_CONF_EMAILS_GROUP_NAME,
user_emails=list(all_found_emails),
)
onyx_groups.append(all_found_group)
return onyx_groups

View File

@@ -0,0 +1,84 @@
from collections.abc import Callable
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
)
from onyx.configs.constants import DocumentSource
from onyx.context.search.pipeline import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[
DocumentSource,
# list of chunks to be censored and the user email. returns censored chunks
Callable[[list[InferenceChunk], str], list[InferenceChunk]],
] = {
DocumentSource.SALESFORCE: censor_salesforce_chunks,
}
def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
"""
Returns the set of sources that have censoring enabled.
This is based on if the access_type is set to sync and the connector
source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION.
NOTE: This means if there is a source has a single cc_pair that is sync,
all chunks for that source will be censored, even if the connector that
indexed that chunk is not sync. This was done to avoid getting the cc_pair
for every single chunk.
"""
with get_session_context_manager() as db_session:
enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
return {
cc_pair.connector.source
for cc_pair in enabled_sync_connectors
if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
}
# NOTE: This is only called if ee is enabled.
def _post_query_chunk_censoring(
chunks: list[InferenceChunk],
user: User | None,
) -> list[InferenceChunk]:
"""
This function checks all chunks to see if they need to be sent to a censoring
function. If they do, it sends them to the censoring function and returns the
censored chunks. If they don't, it returns the original chunks.
"""
if user is None:
# if user is None, permissions are not enforced
return chunks
chunks_to_keep = []
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
sources_to_censor = _get_all_censoring_enabled_sources()
for chunk in chunks:
# Separate out chunks that require permission post-processing by source
if chunk.source_type in sources_to_censor:
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
else:
chunks_to_keep.append(chunk)
# For each source, filter out the chunks using the permission
# check function for that source
# TODO: Use a threadpool/multiprocessing to process the sources in parallel
for source, chunks_for_source in chunks_to_process.items():
censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source]
try:
censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
except Exception as e:
logger.exception(
f"Failed to censor chunks for source {source} so throwing out all"
f" chunks for this source and continuing: {e}"
)
continue
chunks_to_keep.extend(censored_chunks)
return chunks_to_keep

View File

@@ -0,0 +1,226 @@
import time
from ee.onyx.db.external_perm import fetch_external_groups_for_user_email_and_group_ids
from ee.onyx.external_permissions.salesforce.utils import (
get_any_salesforce_client_for_doc_id,
)
from ee.onyx.external_permissions.salesforce.utils import get_objects_access_for_user_id
from ee.onyx.external_permissions.salesforce.utils import (
get_salesforce_user_id_from_email,
)
from onyx.configs.app_configs import BLURB_SIZE
from onyx.context.search.models import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Types
ChunkKey = tuple[str, int] # (doc_id, chunk_id)
ContentRange = tuple[int, int | None] # (start_index, end_index) None means to the end
# NOTE: Used for testing timing
def _get_dummy_object_access_map(
object_ids: set[str], user_email: str, chunks: list[InferenceChunk]
) -> dict[str, bool]:
time.sleep(0.15)
# return {object_id: True for object_id in object_ids}
import random
return {object_id: random.choice([True, False]) for object_id in object_ids}
def _get_objects_access_for_user_email_from_salesforce(
object_ids: set[str],
user_email: str,
chunks: list[InferenceChunk],
) -> dict[str, bool] | None:
"""
This function wraps the salesforce call as we may want to change how this
is done in the future. (E.g. replace it with the above function)
"""
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries for this source are essentially instant
first_doc_id = chunks[0].document_id
with get_session_context_manager() as db_session:
salesforce_client = get_any_salesforce_client_for_doc_id(
db_session, first_doc_id
)
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries by the same user are essentially instant
start_time = time.time()
user_id = get_salesforce_user_id_from_email(salesforce_client, user_email)
end_time = time.time()
logger.info(
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
)
if user_id is None:
return None
# This is the only query that is not cached in the function
# so it takes 0.1-0.2 seconds total
object_id_to_access = get_objects_access_for_user_id(
salesforce_client, user_id, list(object_ids)
)
return object_id_to_access
def _extract_salesforce_object_id_from_url(url: str) -> str:
return url.split("/")[-1]
def _get_object_ranges_for_chunk(
chunk: InferenceChunk,
) -> dict[str, list[ContentRange]]:
"""
Given a chunk, return a dictionary of salesforce object ids and the content ranges
for that object id in the current chunk
"""
if chunk.source_links is None:
return {}
object_ranges: dict[str, list[ContentRange]] = {}
end_index = None
descending_source_links = sorted(
chunk.source_links.items(), key=lambda x: x[0], reverse=True
)
for start_index, url in descending_source_links:
object_id = _extract_salesforce_object_id_from_url(url)
if object_id not in object_ranges:
object_ranges[object_id] = []
object_ranges[object_id].append((start_index, end_index))
end_index = start_index
return object_ranges
def _create_empty_censored_chunk(uncensored_chunk: InferenceChunk) -> InferenceChunk:
"""
Create a copy of the unfiltered chunk where potentially sensitive content is removed
to be added later if the user has access to each of the sub-objects
"""
empty_censored_chunk = InferenceChunk(
**uncensored_chunk.model_dump(),
)
empty_censored_chunk.content = ""
empty_censored_chunk.blurb = ""
empty_censored_chunk.source_links = {}
return empty_censored_chunk
def _update_censored_chunk(
censored_chunk: InferenceChunk,
uncensored_chunk: InferenceChunk,
content_range: ContentRange,
) -> InferenceChunk:
"""
Update the filtered chunk with the content and source links from the unfiltered chunk using the content ranges
"""
start_index, end_index = content_range
# Update the content of the filtered chunk
permitted_content = uncensored_chunk.content[start_index:end_index]
permitted_section_start_index = len(censored_chunk.content)
censored_chunk.content = permitted_content + censored_chunk.content
# Update the source links of the filtered chunk
if uncensored_chunk.source_links is not None:
if censored_chunk.source_links is None:
censored_chunk.source_links = {}
link_content = uncensored_chunk.source_links[start_index]
censored_chunk.source_links[permitted_section_start_index] = link_content
# Update the blurb of the filtered chunk
censored_chunk.blurb = censored_chunk.content[:BLURB_SIZE]
return censored_chunk
# TODO: Generalize this to other sources
def censor_salesforce_chunks(
chunks: list[InferenceChunk],
user_email: str,
# This is so we can provide a mock access map for testing
access_map: dict[str, bool] | None = None,
) -> list[InferenceChunk]:
# object_id -> list[((doc_id, chunk_id), (start_index, end_index))]
object_to_content_map: dict[str, list[tuple[ChunkKey, ContentRange]]] = {}
# (doc_id, chunk_id) -> chunk
uncensored_chunks: dict[ChunkKey, InferenceChunk] = {}
# keep track of all object ids that we have seen to make it easier to get
# the access for these object ids
object_ids: set[str] = set()
for chunk in chunks:
chunk_key = (chunk.document_id, chunk.chunk_id)
# create a dictionary to quickly look up the unfiltered chunk
uncensored_chunks[chunk_key] = chunk
# for each chunk, get a dictionary of object ids and the content ranges
# for that object id in the current chunk
object_ranges_for_chunk = _get_object_ranges_for_chunk(chunk)
for object_id, ranges in object_ranges_for_chunk.items():
object_ids.add(object_id)
for start_index, end_index in ranges:
object_to_content_map.setdefault(object_id, []).append(
(chunk_key, (start_index, end_index))
)
# This is so we can provide a mock access map for testing
if access_map is None:
access_map = _get_objects_access_for_user_email_from_salesforce(
object_ids=object_ids,
user_email=user_email,
chunks=chunks,
)
if access_map is None:
# If the user is not found in Salesforce, access_map will be None
# so we should just return an empty list because no chunks will be
# censored
return []
censored_chunks: dict[ChunkKey, InferenceChunk] = {}
for object_id, content_list in object_to_content_map.items():
# if the user does not have access to the object, or the object is not in the
# access_map, do not include its content in the filtered chunks
if not access_map.get(object_id, False):
continue
# if we got this far, the user has access to the object so we can create or update
# the filtered chunk(s) for this object
# NOTE: we only create a censored chunk if the user has access to some
# part of the chunk
for chunk_key, content_range in content_list:
if chunk_key not in censored_chunks:
censored_chunks[chunk_key] = _create_empty_censored_chunk(
uncensored_chunks[chunk_key]
)
uncensored_chunk = uncensored_chunks[chunk_key]
censored_chunk = _update_censored_chunk(
censored_chunk=censored_chunks[chunk_key],
uncensored_chunk=uncensored_chunk,
content_range=content_range,
)
censored_chunks[chunk_key] = censored_chunk
return list(censored_chunks.values())
# NOTE: This is not used anywhere.
def _get_objects_access_for_user_email(
object_ids: set[str], user_email: str
) -> dict[str, bool]:
with get_session_context_manager() as db_session:
external_groups = fetch_external_groups_for_user_email_and_group_ids(
db_session=db_session,
user_email=user_email,
# Maybe make a function that adds a salesforce prefix to the group ids
group_ids=list(object_ids),
)
external_group_ids = {group.external_user_group_id for group in external_groups}
return {group_id: group_id in external_group_ids for group_id in object_ids}

View File

@@ -0,0 +1,174 @@
from simple_salesforce import Salesforce
from sqlalchemy.orm import Session
from onyx.connectors.salesforce.sqlite_functions import get_user_id_by_email
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import NULL_ID_STRING
from onyx.connectors.salesforce.sqlite_functions import update_email_to_id_table
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_cc_pairs_for_document
from onyx.utils.logger import setup_logger
logger = setup_logger()
_ANY_SALESFORCE_CLIENT: Salesforce | None = None
def get_any_salesforce_client_for_doc_id(
db_session: Session, doc_id: str
) -> Salesforce:
"""
We create a salesforce client for the first cc_pair for the first doc_id where
salesforce censoring is enabled. After that we just cache and reuse the same
client for all queries.
We do this to reduce the number of postgres queries we make at query time.
This may be problematic if they are using multiple cc_pairs for salesforce.
E.g. there are 2 different credential sets for 2 different salesforce cc_pairs
but only one has the permissions to access the permissions needed for the query.
"""
global _ANY_SALESFORCE_CLIENT
if _ANY_SALESFORCE_CLIENT is None:
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
first_cc_pair = cc_pairs[0]
credential_json = first_cc_pair.credential.credential_json
_ANY_SALESFORCE_CLIENT = Salesforce(
username=credential_json["sf_username"],
password=credential_json["sf_password"],
security_token=credential_json["sf_security_token"],
)
return _ANY_SALESFORCE_CLIENT
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
result = sf_client.query(query)
if len(result["records"]) == 0:
return None
return result["records"][0]["Id"]
# This contains only the user_ids that we have found in Salesforce.
# If we don't know their user_id, we don't store anything in the cache.
_CACHED_SF_EMAIL_TO_ID_MAP: dict[str, str] = {}
def get_salesforce_user_id_from_email(
sf_client: Salesforce,
user_email: str,
) -> str | None:
"""
We cache this so we don't have to query Salesforce for every query and salesforce
user IDs never change.
Memory usage is fine because we just store 2 small strings per user.
If the email is not in the cache, we check the local salesforce database for the info.
If the user is not found in the local salesforce database, we query Salesforce.
Whatever we get back from Salesforce is added to the database.
If no user_id is found, we add a NULL_ID_STRING to the database for that email so
we don't query Salesforce again (which is slow) but we still check the local salesforce
database every query until a user id is found. This is acceptable because the query time
is quite fast.
If a user_id is created in Salesforce, it will be added to the local salesforce database
next time the connector is run. Then that value will be found in this function and cached.
NOTE: First time this runs, it may be slow if it hasn't already been updated in the local
salesforce database. (Around 0.1-0.3 seconds)
If it's cached or stored in the local salesforce database, it's fast (<0.001 seconds).
"""
global _CACHED_SF_EMAIL_TO_ID_MAP
if user_email in _CACHED_SF_EMAIL_TO_ID_MAP:
if _CACHED_SF_EMAIL_TO_ID_MAP[user_email] is not None:
return _CACHED_SF_EMAIL_TO_ID_MAP[user_email]
db_exists = True
try:
# Check if the user is already in the database
user_id = get_user_id_by_email(user_email)
except Exception:
init_db()
try:
user_id = get_user_id_by_email(user_email)
except Exception as e:
logger.error(f"Error checking if user is in database: {e}")
user_id = None
db_exists = False
# If no entry is found in the database (indicated by user_id being None)...
if user_id is None:
# ...query Salesforce and store the result in the database
user_id = _query_salesforce_user_id(sf_client, user_email)
if db_exists:
update_email_to_id_table(user_email, user_id)
return user_id
elif user_id is None:
return None
elif user_id == NULL_ID_STRING:
return None
# If the found user_id is real, cache it
_CACHED_SF_EMAIL_TO_ID_MAP[user_email] = user_id
return user_id
_MAX_RECORD_IDS_PER_QUERY = 200
def get_objects_access_for_user_id(
salesforce_client: Salesforce,
user_id: str,
record_ids: list[str],
) -> dict[str, bool]:
"""
Salesforce has a limit of 200 record ids per query. So we just truncate
the list of record ids to 200. We only ever retrieve 50 chunks at a time
so this should be fine (unlikely that we retrieve all 50 chunks contain
4 unique objects).
If we decide this isn't acceptable we can use multiple queries but they
should be in parallel so query time doesn't get too long.
"""
truncated_record_ids = record_ids[:_MAX_RECORD_IDS_PER_QUERY]
record_ids_str = "'" + "','".join(truncated_record_ids) + "'"
access_query = f"""
SELECT RecordId, HasReadAccess
FROM UserRecordAccess
WHERE RecordId IN ({record_ids_str})
AND UserId = '{user_id}'
"""
result = salesforce_client.query_all(access_query)
return {record["RecordId"]: record["HasReadAccess"] for record in result["records"]}
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP: dict[int, Salesforce] = {}
_DOC_ID_TO_CC_PAIR_ID_MAP: dict[str, int] = {}
# NOTE: This is not used anywhere.
def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Salesforce:
"""
Uses a document id to get the cc_pair that indexed that document and uses the credentials
for that cc_pair to create a Salesforce client.
Problems:
- There may be multiple cc_pairs for a document, and we don't know which one to use.
- right now we just use the first one
- Building a new Salesforce client for each document is slow.
- Memory usage could be an issue as we build these dictionaries.
"""
if doc_id not in _DOC_ID_TO_CC_PAIR_ID_MAP:
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
first_cc_pair = cc_pairs[0]
_DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] = first_cc_pair.id
cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(f"CC pair {cc_pair_id} not found")
credential_json = cc_pair.credential.credential_json
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
username=credential_json["sf_username"],
password=credential_json["sf_password"],
security_token=credential_json["sf_security_token"],
)
return _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id]

View File

@@ -8,6 +8,9 @@ from ee.onyx.external_permissions.confluence.group_sync import confluence_group_
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.onyx.external_permissions.post_query_censoring import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -71,4 +74,7 @@ EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
return source_type in DOC_PERMISSIONS_FUNC_MAP
return (
source_type in DOC_PERMISSIONS_FUNC_MAP
or source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
)

View File

@@ -40,6 +40,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.main import get_application as get_application_base
from onyx.main import include_auth_router_with_prefix
from onyx.main import include_router_with_global_prefix_prepended
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
@@ -62,7 +63,7 @@ def get_application() -> FastAPI:
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
oauth_client,
@@ -74,19 +75,17 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)
# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
if AUTH_TYPE == AuthType.OIDC:
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
@@ -97,19 +96,20 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
),
prefix="/auth/oidc",
tags=["auth"],
)
# need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
elif AUTH_TYPE == AuthType.SAML:
include_router_with_global_prefix_prepended(application, saml_router)
include_auth_router_with_prefix(
application,
saml_router,
)
# RBAC / group access control
include_router_with_global_prefix_prepended(application, user_group_router)

View File

@@ -1,17 +1,24 @@
import datetime
from collections import defaultdict
from typing import List
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.db.analytics import fetch_assistant_message_analytics
from ee.onyx.db.analytics import fetch_assistant_unique_users
from ee.onyx.db.analytics import fetch_assistant_unique_users_total
from ee.onyx.db.analytics import fetch_onyxbot_analytics
from ee.onyx.db.analytics import fetch_per_user_query_analytics
from ee.onyx.db.analytics import fetch_persona_message_analytics
from ee.onyx.db.analytics import fetch_persona_unique_users
from ee.onyx.db.analytics import fetch_query_analytics
from ee.onyx.db.analytics import user_can_view_assistant_stats
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.db.engine import get_session
from onyx.db.models import User
@@ -191,3 +198,74 @@ def get_persona_unique_users(
)
)
return unique_user_counts
class AssistantDailyUsageResponse(BaseModel):
date: datetime.date
total_messages: int
total_unique_users: int
class AssistantStatsResponse(BaseModel):
daily_stats: List[AssistantDailyUsageResponse]
total_messages: int
total_unique_users: int
@router.get("/assistant/{assistant_id}/stats")
def get_assistant_stats(
assistant_id: int,
start: datetime.datetime | None = None,
end: datetime.datetime | None = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantStatsResponse:
"""
Returns daily message and unique user counts for a user's assistant,
along with the overall total messages and total distinct users.
"""
start = start or (
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
)
end = end or datetime.datetime.utcnow()
if not user_can_view_assistant_stats(db_session, user, assistant_id):
raise HTTPException(
status_code=403, detail="Not allowed to access this assistant's stats."
)
# Pull daily usage from the DB calls
messages_data = fetch_assistant_message_analytics(
db_session, assistant_id, start, end
)
unique_users_data = fetch_assistant_unique_users(
db_session, assistant_id, start, end
)
# Map each day => (messages, unique_users).
daily_messages_map = {date: count for count, date in messages_data}
daily_unique_users_map = {date: count for count, date in unique_users_data}
all_dates = set(daily_messages_map.keys()) | set(daily_unique_users_map.keys())
# Merge both sets of metrics by date
daily_results: list[AssistantDailyUsageResponse] = []
for date in sorted(all_dates):
daily_results.append(
AssistantDailyUsageResponse(
date=date,
total_messages=daily_messages_map.get(date, 0),
total_unique_users=daily_unique_users_map.get(date, 0),
)
)
# Now pull a single total distinct user count across the entire time range
total_msgs = sum(d.total_messages for d in daily_results)
total_users = fetch_assistant_unique_users_total(
db_session, assistant_id, start, end
)
return AssistantStatsResponse(
daily_stats=daily_results,
total_messages=total_msgs,
total_unique_users=total_users,
)

View File

@@ -2,15 +2,16 @@ import logging
from collections.abc import Awaitable
from collections.abc import Callable
import jwt
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from ee.onyx.auth.users import decode_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.db.engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -22,11 +23,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
try:
tenant_id = (
_get_tenant_id_from_request(request, logger)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
if MULTI_TENANT:
tenant_id = await _get_tenant_id_from_request(request, logger)
else:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return await call_next(request)
@@ -35,27 +36,46 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
raise
def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
# First check for API key
async def _get_tenant_id_from_request(
request: Request, logger: logging.LoggerAdapter
) -> str:
"""
Attempt to extract tenant_id from:
1) The API key header
2) The Redis-based token (stored in Cookie: fastapiusersauth)
Fallback: POSTGRES_DEFAULT_SCHEMA
"""
# Check for API key
tenant_id = extract_tenant_from_api_key_header(request)
if tenant_id is not None:
if tenant_id:
return tenant_id
# Check for cookie-based auth
token = request.cookies.get("fastapiusersauth")
if not token:
return POSTGRES_DEFAULT_SCHEMA
# Check for anonymous user cookie
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
if anonymous_user_cookie:
try:
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
except Exception as e:
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
# Continue and attempt to authenticate
try:
payload = jwt.decode(
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
# Since payload.get() can return None, ensure we have a string
if not token_data:
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
)
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
# Since token_data.get() can return None, ensure we have a string
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
@@ -67,9 +87,6 @@ def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter)
return tenant_id
except jwt.InvalidTokenError:
return POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -1,5 +1,7 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
@@ -10,11 +12,29 @@ from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.auth.users import current_user
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
@@ -62,14 +82,7 @@ class SlackOAuth:
@classmethod
def generate_oauth_url(cls, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={cls.REDIRECT_URI}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
@@ -77,10 +90,14 @@ class SlackOAuth:
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={cls.DEV_REDIRECT_URI}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
@@ -102,82 +119,151 @@ class SlackOAuth:
return session
# Work in progress
# class ConfluenceCloudOAuth:
# """work in progress"""
class ConfluenceCloudOAuth:
"""work in progress"""
# # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
# class OAuthSession(BaseModel):
# """Stored in redis to be looked up on callback"""
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
# email: str
# redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
# CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
# CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
# TOKEN_URL = "https://auth.atlassian.com/oauth/token"
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
# CONFLUENCE_OAUTH_SCOPE = (
# "read:confluence-props%20"
# "read:confluence-content.all%20"
# "read:confluence-content.summary%20"
# "read:confluence-content.permission%20"
# "read:confluence-user%20"
# "read:confluence-groups%20"
# "readonly:content.attachment:confluence"
# )
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence"
)
# REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
# DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# # eventually for Confluence Data Center
# # oauth_url = (
# # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# # f"&redirect_uri={redirectme_uri}"
# # )
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
# @classmethod
# def generate_oauth_url(cls, state: str) -> str:
# return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
# @classmethod
# def generate_dev_oauth_url(cls, state: str) -> str:
# """dev mode workaround for localhost testing
# - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
# """
# return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
# @classmethod
# def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# url = (
# "https://auth.atlassian.com/authorize"
# f"?audience=api.atlassian.com"
# f"&client_id={cls.CLIENT_ID}"
# f"&redirect_uri={redirect_uri}"
# f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
# f"&state={state}"
# "&response_type=code"
# "&prompt=consent"
# )
# return url
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
# @classmethod
# def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
# """Temporary state to store in redis. to be looked up on auth response.
# Returns a json string.
# """
# session = ConfluenceCloudOAuth.OAuthSession(
# email=email, redirect_on_success=redirect_on_success
# )
# return session.model_dump_json()
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
# @classmethod
# def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
# session = SlackOAuth.OAuthSession.model_validate_json(session_json)
# return session
@classmethod
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.onyx.app/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/prepare-authorization-request")
@@ -192,8 +278,11 @@ def prepare_authorization_request(
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
@@ -203,6 +292,11 @@ def prepare_authorization_request(
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
# elif connector == DocumentSource.CONFLUENCE:
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
# session = ConfluenceCloudOAuth.session_dump_json(
@@ -210,8 +304,6 @@ def prepare_authorization_request(
# )
# elif connector == DocumentSource.JIRA:
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
# elif connector == DocumentSource.GOOGLE_DRIVE:
# oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
else:
oauth_url = None
@@ -223,6 +315,7 @@ def prepare_authorization_request(
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
@@ -421,3 +514,116 @@ def handle_slack_oauth_callback(
# "redirect_on_success": session.redirect_on_success,
# }
# )
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = GoogleDriveOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"redirect_on_success": session.redirect_on_success,
}
)

View File

@@ -13,9 +13,8 @@ from ee.onyx.db.usage_export import get_all_empty_chat_message_entries
from ee.onyx.db.usage_export import write_usage_report
from ee.onyx.server.reporting.usage_export_models import UsageReportMetadata
from ee.onyx.server.reporting.usage_export_models import UserSkeleton
from onyx.auth.schemas import UserStatus
from onyx.configs.constants import FileOrigin
from onyx.db.users import list_users
from onyx.db.users import get_all_users
from onyx.file_store.constants import MAX_IN_MEMORY_SIZE
from onyx.file_store.file_store import FileStore
from onyx.file_store.file_store import get_default_file_store
@@ -84,15 +83,15 @@ def generate_user_report(
max_size=MAX_IN_MEMORY_SIZE, mode="w+"
) as temp_file:
csvwriter = csv.writer(temp_file, delimiter=",")
csvwriter.writerow(["user_id", "status"])
csvwriter.writerow(["user_id", "is_active"])
users = list_users(db_session)
users = get_all_users(db_session)
for user in users:
user_skeleton = UserSkeleton(
user_id=str(user.id),
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
is_active=user.is_active,
)
csvwriter.writerow([user_skeleton.user_id, user_skeleton.status])
csvwriter.writerow([user_skeleton.user_id, user_skeleton.is_active])
temp_file.seek(0)
file_store.save_file(

View File

@@ -4,8 +4,6 @@ from uuid import UUID
from pydantic import BaseModel
from onyx.auth.schemas import UserStatus
class FlowType(str, Enum):
CHAT = "chat"
@@ -22,7 +20,7 @@ class ChatMessageSkeleton(BaseModel):
class UserSkeleton(BaseModel):
user_id: str
status: UserStatus
is_active: bool
class UsageReportMetadata(BaseModel):

View File

@@ -0,0 +1,59 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import TenantAnonymousUserPath
def get_anonymous_user_path(tenant_id: str, db_session: Session) -> str | None:
result = db_session.execute(
select(TenantAnonymousUserPath).where(
TenantAnonymousUserPath.tenant_id == tenant_id
)
)
result_scalar = result.scalar_one_or_none()
if result_scalar:
return result_scalar.anonymous_user_path
else:
return None
def modify_anonymous_user_path(
tenant_id: str, anonymous_user_path: str, db_session: Session
) -> None:
# Enforce lowercase path at DB operation level
anonymous_user_path = anonymous_user_path.lower()
existing_entry = (
db_session.query(TenantAnonymousUserPath).filter_by(tenant_id=tenant_id).first()
)
if existing_entry:
existing_entry.anonymous_user_path = anonymous_user_path
else:
new_entry = TenantAnonymousUserPath(
tenant_id=tenant_id, anonymous_user_path=anonymous_user_path
)
db_session.add(new_entry)
db_session.commit()
def get_tenant_id_for_anonymous_user_path(
anonymous_user_path: str, db_session: Session
) -> str | None:
result = db_session.execute(
select(TenantAnonymousUserPath).where(
TenantAnonymousUserPath.anonymous_user_path == anonymous_user_path
)
)
result_scalar = result.scalar_one_or_none()
if result_scalar:
return result_scalar.tenant_id
else:
return None
def validate_anonymous_user_path(path: str) -> None:
if not path or "/" in path or not path.replace("-", "").isalnum():
raise ValueError("Invalid path. Use only letters, numbers, and hyphens.")

View File

@@ -3,35 +3,124 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_jwt_strategy
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.notification import create_notification
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
tenant_id: str | None = Depends(get_current_tenant_id),
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_tenant(tenant_id=None) as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
tenant_id: str = Depends(get_current_tenant_id),
_: User | None = Depends(current_admin_user),
) -> None:
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_tenant(tenant_id=None) as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_tenant(tenant_id=None) as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
@@ -103,7 +192,7 @@ async def impersonate_user(
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_jwt_strategy().write_token(user_to_impersonate)
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
@@ -114,3 +203,48 @@ async def impersonate_user(
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -46,6 +46,7 @@ def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscr
"""
Send a request to the control service to register the number of users for a tenant.
"""
if not STRIPE_PRICE_ID:
raise Exception("STRIPE_PRICE_ID is not set")

View File

@@ -39,3 +39,12 @@ class TenantCreationPayload(BaseModel):
tenant_id: str
email: str
referral_source: str | None = None
class TenantDeletionPayload(BaseModel):
tenant_id: str
email: str
class AnonymousUserPath(BaseModel):
anonymous_user_path: str | None

View File

@@ -15,6 +15,7 @@ from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import drop_schema
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
@@ -185,6 +186,7 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
@@ -320,3 +322,26 @@ async def submit_to_hubspot(
if response.status_code != 200:
logger.error(f"Failed to submit to HubSpot: {response.text}")
async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)
async with aiohttp.ClientSession() as session:
async with session.delete(
f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
headers=headers,
json=payload.model_dump(),
) as response:
print(response)
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant creation failed: {error_text}")
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)

View File

@@ -68,3 +68,11 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()
def remove_all_users_from_tenant(tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()

View File

@@ -83,7 +83,7 @@ def patch_user_group(
def set_user_curator(
user_group_id: int,
set_curator_request: SetCuratorRequest,
_: User | None = Depends(current_admin_user),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
@@ -91,6 +91,7 @@ def set_user_curator(
db_session=db_session,
user_group_id=user_group_id,
set_curator_request=set_curator_request,
user_making_change=user,
)
except ValueError as e:
logger.error(f"Error setting user curator: {e}")

View File

@@ -1,3 +1,5 @@
from typing import Any
from posthog import Posthog
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
@@ -6,13 +8,27 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
posthog = Posthog(project_api_key=POSTHOG_API_KEY, host=POSTHOG_HOST)
def posthog_on_error(error: Any, items: Any) -> None:
"""Log any PostHog delivery errors."""
logger.error(f"PostHog error: {error}, items: {items}")
posthog = Posthog(
project_api_key=POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=True,
on_error=posthog_on_error,
)
def event_telemetry(
distinct_id: str,
event: str,
properties: dict | None = None,
distinct_id: str, event: str, properties: dict | None = None
) -> None:
logger.info(f"Capturing Posthog event: {distinct_id} {event} {properties}")
posthog.capture(distinct_id, event, properties)
"""Capture and send an event to PostHog, flushing immediately."""
logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}")
try:
posthog.capture(distinct_id, event, properties)
posthog.flush()
except Exception as e:
logger.error(f"Error capturing PostHog event: {e}")

View File

@@ -1,5 +1,6 @@
import asyncio
import json
import time
from types import TracebackType
from typing import cast
from typing import Optional
@@ -320,8 +321,6 @@ async def embed_text(
api_url: str | None,
api_version: str | None,
) -> list[Embedding]:
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
if not all(texts):
logger.error("Empty strings provided for embedding")
raise ValueError("Empty strings are not allowed for embedding.")
@@ -330,8 +329,17 @@ async def embed_text(
logger.error("No texts provided for embedding")
raise ValueError("No texts provided for embedding.")
start = time.monotonic()
total_chars = 0
for text in texts:
total_chars += len(text)
if provider_type is not None:
logger.debug(f"Using cloud provider {provider_type} for embedding")
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
)
if api_key is None:
logger.error("API key not provided for cloud model")
raise RuntimeError("API key not provided for cloud model")
@@ -363,8 +371,16 @@ async def embed_text(
logger.error(error_message)
raise ValueError(error_message)
elapsed = time.monotonic() - start
logger.info(
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with provider {provider_type} in {elapsed:.2f}"
)
elif model_name is not None:
logger.debug(f"Using local model {model_name} for embedding")
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
)
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
local_model = get_embedding_model(
@@ -382,13 +398,17 @@ async def embed_text(
for embedding in embeddings_vectors
]
elapsed = time.monotonic() - start
logger.info(
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with local model {model_name} in {elapsed:.2f}"
)
else:
logger.error("Neither model name nor provider specified for embedding")
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
logger.info(f"Successfully embedded {len(texts)} texts")
return embeddings
@@ -440,7 +460,8 @@ async def process_embed_request(
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
elif not all(embed_request.texts):
if not all(embed_request.texts):
raise ValueError("Empty strings are not allowed for embedding.")
try:
@@ -471,9 +492,12 @@ async def process_embed_request(
detail=str(e),
)
except Exception as e:
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)
raise HTTPException(status_code=500, detail=exception_detail)
logger.exception(
f"Error during embedding process: provider={embed_request.provider_type} model={embed_request.model_name}"
)
raise HTTPException(
status_code=500, detail=f"Error during embedding process: {e}"
)
@router.post("/cross-encoder-scores")

View File

@@ -44,6 +44,7 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
the files in the existing huggingface cache that don't exist in the temp
huggingface cache.
"""
for item in source.iterdir():
target_path = dest / item.relative_to(source)
if item.is_dir():

View File

@@ -0,0 +1,83 @@
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from textwrap import dedent
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.models import User
def send_email(
user_email: str,
subject: str,
body: str,
mail_from: str = EMAIL_FROM,
) -> None:
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
msg = MIMEMultipart()
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
msg.attach(MIMEText(body))
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
s.starttls()
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)
except Exception as e:
raise e
def send_user_email_invite(user_email: str, current_user: User) -> None:
subject = "Invitation to Join Onyx Organization"
body = dedent(
f"""\
Hello,
You have been invited to join an organization on Onyx.
To join the organization, please visit the following link:
{WEB_DOMAIN}/auth/signup?email={user_email}
You'll be asked to set a password or login with Google to complete your registration.
Best regards,
The Onyx Team
"""
)
send_email(user_email, subject, body, current_user.email)
def send_forgot_password_email(
user_email: str,
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
body = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, body, mail_from)
def send_user_verification_email(
user_email: str,
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
subject = "Onyx Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
body = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, body, mail_from)

View File

@@ -30,13 +30,16 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
def fetch_no_auth_user(
store: KeyValueStore, *, anonymous_user_enabled: bool | None = None
) -> UserInfo:
return UserInfo(
id=NO_AUTH_USER_ID,
email=NO_AUTH_USER_EMAIL,
is_active=True,
is_superuser=False,
is_verified=True,
role=UserRole.ADMIN,
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
preferences=load_no_auth_user_preferences(store),
is_anonymous_user=anonymous_user_enabled,
)

View File

@@ -33,12 +33,6 @@ class UserRole(str, Enum):
]
class UserStatus(str, Enum):
LIVE = "live"
INVITED = "invited"
DEACTIVATED = "deactivated"
class UserRead(schemas.BaseUser[uuid.UUID]):
role: UserRole
@@ -49,4 +43,7 @@ class UserCreate(schemas.BaseUserCreate):
class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
"""
Role updates are not allowed through the user update endpoint for security reasons
Role changes should be handled through a separate, admin-only process
"""

View File

@@ -1,10 +1,9 @@
import smtplib
import json
import secrets
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import cast
from typing import Dict
from typing import List
@@ -32,10 +31,8 @@ from fastapi_users import schemas
from fastapi_users import UUIDIDMixin
from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import JWTStrategy
from fastapi_users.authentication import RedisStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.jwt import decode_jwt
from fastapi_users.jwt import generate_jwt
@@ -49,23 +46,22 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdate
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import EMAIL_FROM
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
@@ -74,22 +70,23 @@ from onyx.configs.constants import AuthType
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.api_key import fetch_user_for_api_key
from onyx.db.auth import get_access_token_db
from onyx.db.auth import get_default_admin_user_emails
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.db.users import get_user_by_email
from onyx.server.utils import BasicAuthenticationError
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
@@ -103,6 +100,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -143,6 +145,17 @@ def user_needs_to_be_verified() -> bool:
return False
def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
redis_client = get_redis_client(tenant_id=tenant_id)
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
if value is None:
return False
assert isinstance(value, bytes)
return int(value.decode("utf-8")) == 1
def verify_email_is_invited(email: str) -> None:
whitelist = get_invited_users()
if not whitelist:
@@ -193,30 +206,6 @@ def verify_email_domain(email: str) -> None:
)
def send_user_verification_email(
user_email: str,
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
msg = MIMEMultipart()
msg["Subject"] = "Onyx Email Verification"
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
body = MIMEText(f"Click the following link to verify your email address: {link}")
msg.attach(body)
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
s.starttls()
# If credentials fails with gmail, check (You need an app password, not just the basic email password)
# https://support.google.com/accounts/answer/185833?sjid=8512343437447396151-NA
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
@@ -281,7 +270,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
if not user.role.is_web_login() and user_create.role.is_web_login():
user_update = UserUpdate(
password=user_create.password,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
@@ -405,11 +393,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
@@ -428,7 +414,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(
@@ -506,7 +491,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
if not EMAIL_CONFIGURED:
logger.error(
"Email is not configured. Please configure email in the admin panel"
)
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
"Your admin has not enbaled this feature.",
)
send_forgot_password_email(user.email, token)
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
@@ -583,51 +576,70 @@ cookie_transport = CookieTransport(
)
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
"""
A custom strategy that fetches the actual async Redis connection inside each method.
We do NOT pass a synchronous or "coroutine" redis object to the constructor.
"""
def __init__(
self,
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
):
self.lifetime_seconds = lifetime_seconds
self.key_prefix = key_prefix
async def write_token(self, user: User) -> str:
redis = await get_async_redis_connection()
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=user.email,
)
)(email=user.email)
data = {
token_data = {
"sub": str(user.id),
"aud": self.token_audience,
"tenant_id": tenant_id,
}
return data
async def write_token(self, user: User) -> str:
data = await self._create_token_data(user)
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
token = secrets.token_urlsafe()
await redis.set(
f"{self.key_prefix}{token}",
json.dumps(token_data),
ex=self.lifetime_seconds,
)
return token
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[User, uuid.UUID]
) -> Optional[User]:
redis = await get_async_redis_connection()
token_data_str = await redis.get(f"{self.key_prefix}{token}")
if not token_data_str:
return None
def get_jwt_strategy() -> TenantAwareJWTStrategy:
return TenantAwareJWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
)
try:
token_data = json.loads(token_data_str)
user_id = token_data["sub"]
parsed_id = user_manager.parse_id(user_id)
return await user_manager.get(parsed_id)
except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
return None
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
async def destroy_token(self, token: str, user: User) -> None:
"""Properly delete the token from async redis."""
redis = await get_async_redis_connection()
await redis.delete(f"{self.key_prefix}{token}")
auth_backend = AuthenticationBackend(
name="jwt" if MULTI_TENANT else "database",
transport=cookie_transport,
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
) # type: ignore
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
)
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
@@ -713,30 +725,36 @@ async def double_check_user(
user: User | None,
optional: bool = DISABLE_AUTH,
include_expired: bool = False,
allow_anonymous_access: bool = False,
) -> User | None:
if optional:
return user
if user is not None:
# If user attempted to authenticate, verify them, do not default
# to anonymous access if it fails.
if user_needs_to_be_verified() and not user.is_verified:
raise BasicAuthenticationError(
detail="Access denied. User is not verified.",
)
if (
user.oidc_expiry
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
)
return user
if allow_anonymous_access:
return None
if user is None:
raise BasicAuthenticationError(
detail="Access denied. User is not authenticated.",
)
if user_needs_to_be_verified() and not user.is_verified:
raise BasicAuthenticationError(
detail="Access denied. User is not verified.",
)
if (
user.oidc_expiry
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
)
return user
raise BasicAuthenticationError(
detail="Access denied. User is not authenticated.",
)
async def current_user_with_expired_token(
@@ -751,6 +769,15 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accesssible_user(
user: User | None = Depends(optional_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> User | None:
return await double_check_user(
user, allow_anonymous_access=anonymous_user_enabled(tenant_id=tenant_id)
)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:

View File

@@ -335,6 +335,10 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not hasattr(sender, "primary_worker_lock"):
# primary_worker_lock will not exist when MULTI_TENANT is True
return
if not sender.primary_worker_lock:
return
@@ -414,11 +418,21 @@ def on_setup_logging(
task_logger.setLevel(loglevel)
task_logger.propagate = False
# Hide celery task received and succeeded/failed messages
# hide celery task received spam
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
strategy.logger.setLevel(logging.WARNING)
# uncomment this to hide celery task succeeded/failed spam
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
trace.logger.setLevel(logging.WARNING)
def set_task_finished_log_level(logLevel: int) -> None:
"""call this to override the setLevel in on_setup_logging. We are interested
in the task timings in the cloud but it can be spammy for self hosted."""
trace.logger.setLevel(logLevel)
class TenantContextFilter(logging.Filter):
"""Logging filter to inject tenant ID into the logger's name."""

View File

@@ -60,7 +60,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
# rkuo: been seeing transient connection exceptions here, so upping the connection count
# from just concurrency/concurrency to concurrency/concurrency*2
SqlEngine.init_engine(
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -1,3 +1,4 @@
import logging
import multiprocessing
from typing import Any
from typing import cast
@@ -88,12 +89,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
logger.info("Running as the primary celery worker.")
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
logger.info("Running as the primary celery worker.")
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
@@ -194,6 +195,10 @@ def on_setup_logging(
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
# this can be spammy, so just enable it in the cloud for now
if MULTI_TENANT:
app_base.set_task_finished_log_level(logging.INFO)
class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker lock outside of the task queue.
@@ -281,5 +286,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
]
)

View File

@@ -3,12 +3,54 @@ import json
from typing import Any
from typing import cast
from celery import Celery
from redis import Redis
from onyx.background.celery.configs.base import CELERY_SEPARATOR
from onyx.configs.constants import OnyxCeleryPriority
def celery_get_unacked_length(r: Redis) -> int:
"""Checking the unacked queue is useful because a non-zero length tells us there
may be prefetched tasks.
There can be other tasks in here besides indexing tasks, so this is mostly useful
just to see if the task count is non zero.
ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html
"""
length = cast(int, r.hlen("unacked"))
return length
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
"""Gets the set of task id's matching the given queue in the unacked hash.
Unacked entries belonging to the indexing queue are "prefetched", so this gives
us crucial visibility as to what tasks are in that state.
"""
tasks: set[str] = set()
for _, v in r.hscan_iter("unacked"):
v_bytes = cast(bytes, v)
v_str = v_bytes.decode("utf-8")
task = json.loads(v_str)
task_description = task[0]
task_queue = task[2]
if task_queue != queue:
continue
task_id = task_description.get("headers", {}).get("id")
if not task_id:
continue
# if the queue matches and we see the task_id, add it
tasks.add(task_id)
return tasks
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists
@@ -47,3 +89,74 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
return True
return False
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
"""Returns a list of current workers containing name_filter, or all workers if
name_filter is None.
We've empirically discovered that the celery inspect API is potentially unstable
and may hang or return empty results when celery is under load. Suggest using this
more to debug and troubleshoot than in production code.
"""
worker_names: list[str] = []
# filter for and create an indexing specific inspect object
inspect = app.control.inspect()
workers: dict[str, Any] = inspect.ping() # type: ignore
if workers:
for worker_name in list(workers.keys()):
# if the name filter not set, return all worker names
if not name_filter:
worker_names.append(worker_name)
continue
# if the name filter is set, return only worker names that contain the name filter
if name_filter not in worker_name:
continue
worker_names.append(worker_name)
return worker_names
def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]:
"""Returns a list of reserved tasks on the specified workers.
We've empirically discovered that the celery inspect API is potentially unstable
and may hang or return empty results when celery is under load. Suggest using this
more to debug and troubleshoot than in production code.
"""
reserved_task_ids: set[str] = set()
inspect = app.control.inspect(destination=worker_names)
# get the list of reserved tasks
reserved_tasks: dict[str, list] | None = inspect.reserved() # type: ignore
if reserved_tasks:
for _, task_list in reserved_tasks.items():
for task in task_list:
reserved_task_ids.add(task["id"])
return reserved_task_ids
def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
"""Returns a list of active tasks on the specified workers.
We've empirically discovered that the celery inspect API is potentially unstable
and may hang or return empty results when celery is under load. Suggest using this
more to debug and troubleshoot than in production code.
"""
active_task_ids: set[str] = set()
inspect = app.control.inspect(destination=worker_names)
# get the list of reserved tasks
active_tasks: dict[str, list] | None = inspect.active() # type: ignore
if active_tasks:
for _, task_list in active_tasks.items():
for task in task_list:
active_task_ids.add(task["id"])
return active_task_ids

View File

@@ -16,6 +16,14 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Indexing worker specific ... this lets us track the transition to STARTED in redis
# We don't currently rely on this but it has the potential to be useful and
# indexing tasks are not high volume
# we don't turn this on yet because celery occasionally runs tasks more than once
# which means a duplicate run might change the task state unexpectedly
# task_track_started = True
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -1,9 +1,16 @@
from datetime import timedelta
from typing import Any
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
# choosing 15 minutes because it roughly gives us enough time to process many tasks
# we might be able to reduce this greatly if we can run a unified
# loop across all tenants rather than tasks per tenant
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# we set expires because it isn't necessary to queue up these tasks
# it's only important that they run relatively regularly
tasks_to_schedule = [
@@ -13,7 +20,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -22,7 +29,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -31,7 +38,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -40,7 +47,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -49,7 +56,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=3600),
"options": {
"priority": OnyxCeleryPriority.LOWEST,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -58,7 +65,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -67,7 +74,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -76,11 +83,25 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
tasks_to_schedule.append(
{
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return tasks_to_schedule

View File

@@ -34,7 +34,9 @@ class TaskDependencyError(RuntimeError):
trail=False,
bind=True,
)
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
@@ -42,11 +44,11 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
@@ -81,6 +83,8 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
if lock_beat.owned():
lock_beat.release()
return True
def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,

View File

@@ -1,6 +1,8 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from time import sleep
from uuid import uuid4
from celery import Celery
@@ -14,10 +16,14 @@ from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from ee.onyx.external_permissions.sync_params import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
@@ -88,19 +94,19 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
# get all cc pairs that need to be synced
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
@@ -128,6 +134,8 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None
if lock_beat.owned():
lock_beat.release()
return True
def try_creating_permissions_sync_task(
app: Celery,
@@ -219,6 +227,43 @@ def connector_permission_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.permissions.fence_key}"
)
if not redis_connector.permissions.fenced: # The fence must exist
raise ValueError(
f"connector_permission_sync_generator_task - fence not found: "
f"fence={redis_connector.permissions.fence_key}"
)
payload = redis_connector.permissions.payload # The payload must exist
if not payload:
raise ValueError(
"connector_permission_sync_generator_task: payload invalid or not found"
)
if payload.celery_task_id is None:
logger.info(
f"connector_permission_sync_generator_task - Waiting for fence: "
f"fence={redis_connector.permissions.fence_key}"
)
sleep(1)
continue
logger.info(
f"connector_permission_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.permissions.fence_key}"
)
break
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
@@ -244,6 +289,8 @@ def connector_permission_sync_generator_task(
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
if source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION:
return None
raise ValueError(
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
)
@@ -254,8 +301,11 @@ def connector_permission_sync_generator_task(
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)
new_payload = RedisConnectorPermissionSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=payload.celery_task_id,
)
redis_connector.permissions.set_fence(new_payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)

View File

@@ -94,19 +94,19 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@@ -149,6 +149,8 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
if lock_beat.owned():
lock_beat.release()
return True
def try_creating_external_group_sync_task(
app: Celery,
@@ -162,7 +164,7 @@ def try_creating_external_group_sync_task(
LOCK_TIMEOUT = 30
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
timeout=LOCK_TIMEOUT,
)

View File

@@ -1,9 +1,10 @@
import os
import sys
import time
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import Any
import redis
import sentry_sdk
@@ -18,10 +19,12 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
@@ -29,6 +32,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -58,6 +62,8 @@ from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -69,14 +75,18 @@ logger = setup_logger()
class IndexingCallback(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
def __init__(
self,
parent_pid: int,
stop_key: str,
generator_progress_key: str,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.parent_pid = parent_pid
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
@@ -86,26 +96,54 @@ class IndexingCallback(IndexingHeartbeatInterface):
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.last_parent_check = time.monotonic()
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
return True
return False
def progress(self, tag: str, amount: int) -> None:
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
# so leave this code in until we're ready to test it.
# if self.parent_pid:
# # check if the parent pid is alive so we aren't running as a zombie
# now = time.monotonic()
# if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
# try:
# # this is unintuitive, but it checks if the parent pid is still running
# os.kill(self.parent_pid, 0)
# except Exception:
# logger.exception("IndexingCallback - parent pid check exceptioned")
# raise
# self.last_parent_check = now
try:
self.redis_lock.reacquire()
current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
self.redis_lock.reacquire()
self.last_lock_reacquire = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.last_tag = tag
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"IndexingCallback - lock.reacquire exceptioned. "
f"IndexingCallback - lock.reacquire exceptioned: "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
redis_lock_dump(self.redis_lock, self.redis_client)
raise
self.redis_client.incrby(self.generator_progress_key, amount)
@@ -167,6 +205,10 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occcasionally does some validation of existing state to clear up error conditions"""
debug_tenants = {
"tenant_i-043470d740845ec56",
"tenant_82b497ce-88aa-4fbd-841a-92cae43529c8",
}
time_start = time.monotonic()
tasks_created = 0
@@ -175,18 +217,18 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
# redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
locked = True
# check for search settings swap
@@ -209,15 +251,25 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)
# gather cc_pair_ids
lock_beat.reacquire()
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# kick off index attempts
for cc_pair_id in cc_pair_ids:
# debugging logic - remove after we're done
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair lock: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
@@ -226,22 +278,58 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
db_session
)
for search_settings_instance in search_settings_list:
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair search settings lock: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
if redis_connector_index.fenced:
continue
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing get_connector_credential_pair_from_id: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id, db_session
)
if not cc_pair:
continue
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing get_last_attempt_for_cc_pair: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair should index: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
@@ -274,6 +362,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
cc_pair.id, None, db_session
)
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair try_creating_indexing_task: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
@@ -294,6 +391,26 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)
tasks_created += 1
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair try_creating_indexing_task finished: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
# debugging logic - remove after we're done
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing unfenced lock: "
f"tenant={tenant_id} "
f"ttl={ttl}"
)
lock_beat.reacquire()
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
@@ -301,6 +418,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
db_session, redis_client
)
for attempt_id in unfenced_attempt_ids:
# debugging logic - remove after we're done
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing unfenced attempt id lock: "
f"tenant={tenant_id} "
f"ttl={ttl}"
)
lock_beat.reacquire()
attempt = get_index_attempt(db_session, attempt_id)
@@ -318,24 +444,29 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
attempt.id, db_session, failure_reason=failure_reason
)
# rkuo: The following code logically appears to work, but the celery inspect code may be unstable
# turning off for the moment to see if it helps cloud stability
# debugging logic - remove after we're done
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing validate fences lock: "
f"tenant={tenant_id} "
f"ttl={ttl}"
)
lock_beat.reacquire()
# we want to run this less frequently than the overall task
# if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
# # clear any indexing fences that don't have associated celery tasks in progress
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# # or be currently executing
# try:
# task_logger.info("Validating indexing fences...")
# validate_indexing_fences(
# tenant_id, self.app, redis_client, redis_client_celery, lock_beat
# )
# except Exception:
# task_logger.exception("Exception while validating indexing fences")
# redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
# clear any indexing fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_indexing_fences(
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
)
except Exception:
task_logger.exception("Exception while validating indexing fences")
redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -351,9 +482,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
redis_lock_dump(lock_beat, redis_client)
time_elapsed = time.monotonic() - time_start
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
task_logger.debug(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
return tasks_created
@@ -364,56 +496,20 @@ def validate_indexing_fences(
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_indexing_tasks: set[str] = set()
active_indexing_tasks: set[str] = set()
indexing_worker_names: list[str] = []
# filter for and create an indexing specific inspect object
inspect = celery_app.control.inspect()
workers: dict[str, Any] = inspect.ping() # type: ignore
if not workers:
raise ValueError("No workers found!")
for worker_name in list(workers.keys()):
if "indexing" in worker_name:
indexing_worker_names.append(worker_name)
if len(indexing_worker_names) == 0:
raise ValueError("No indexing workers found!")
inspect_indexing = celery_app.control.inspect(destination=indexing_worker_names)
# NOTE: each dict entry is a map of worker name to a list of tasks
# we want sets for reserved task and active task id's to optimize
# subsequent validation lookups
# get the list of reserved tasks
reserved_tasks: dict[str, list] | None = inspect_indexing.reserved() # type: ignore
if reserved_tasks is None:
raise ValueError("inspect_indexing.reserved() returned None!")
for _, task_list in reserved_tasks.items():
for task in task_list:
reserved_indexing_tasks.add(task["id"])
# get the list of active tasks
active_tasks: dict[str, list] | None = inspect_indexing.active() # type: ignore
if active_tasks is None:
raise ValueError("inspect_indexing.active() returned None!")
for _, task_list in active_tasks.items():
for task in task_list:
active_indexing_tasks.add(task["id"])
reserved_indexing_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
reserved_indexing_tasks,
active_indexing_tasks,
r_celery,
db_session,
)
@@ -424,7 +520,6 @@ def validate_indexing_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
active_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
@@ -434,11 +529,15 @@ def validate_indexing_fence(
gives the help.
How this works:
1. Active signal is renewed with a 5 minute TTL
1.1 When the fence is created
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved or active list for a worker
2. The TTL allows us to get through the transitions on fence startup
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
@@ -466,6 +565,8 @@ def validate_indexing_fence(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# check to see if the fence/payload exists
if not redis_connector_index.fenced:
return
@@ -501,24 +602,24 @@ def validate_indexing_fence(
redis_connector_index.set_active()
return
if payload.celery_task_id in active_tasks:
# the celery task is active (aka currently executing)
redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# we didn't find any direct indication that associated celery tasks exist, but they still might be there
# due to gaps in our ability to check states during transitions
# Rely on the active signal (which has a duration that allows us to bridge those gaps)
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector_index.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: fence={fence_key}"
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
f"index_attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"fence={fence_key}"
)
if payload.index_attempt_id:
try:
@@ -783,7 +884,6 @@ def connector_indexing_proxy_task(
return
task_logger.info(
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -795,6 +895,58 @@ def connector_indexing_proxy_task(
while True:
sleep(5)
# renew active signal
redis_connector_index.set_active()
# if the job is done, clean up and break
if job.done():
try:
if job.status == "error":
ignore_exitcode = False
exit_code: int | None = None
if job.process:
exit_code = job.process.exitcode
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
# even though logging clearly indicates successful completion
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if not ignore_exitcode:
raise RuntimeError("Spawned task exceptioned.")
task_logger.warning(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code}"
)
except Exception:
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code} "
f"error={job.exception()}"
)
raise
finally:
job.release()
break
# if a termination signal is detected, clean up and break
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing watchdog - termination signal detected: "
@@ -821,75 +973,33 @@ def connector_indexing_proxy_task(
f"search_settings={search_settings_id}"
)
job.cancel()
job.cancel()
break
if not job.done():
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
continue
if job.status == "error":
ignore_exitcode = False
exit_code: int | None = None
if job.process:
exit_code = job.process.exitcode
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
# even though logging clearly indicates that they completed successfully
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if ignore_exitcode:
task_logger.warning(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code}"
)
else:
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code} "
f"error={job.exception()}"
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
job.release()
break
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
continue
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
@@ -918,7 +1028,7 @@ def connector_indexing_task_wrapper(
tenant_id,
is_ee,
)
except:
except Exception:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
@@ -926,7 +1036,14 @@ def connector_indexing_task_wrapper(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise
# There is a cloud related bug outside of our code
# where spawned tasks return with an exit code of 1.
# Unfortunately, exceptions also return with an exit code of 1,
# so just raising an exception isn't informative
# Exiting with 255 makes it possible to distinguish between normal exits
# and exceptions.
sys.exit(255)
return result
@@ -998,7 +1115,17 @@ def connector_indexing_task(
f"fence={redis_connector.stop.fence_key}"
)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_indexing_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.permissions.fence_key}"
)
if not redis_connector_index.fenced: # The fence must exist
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
@@ -1039,7 +1166,9 @@ def connector_indexing_task(
if not acquired:
logger.warning(
f"Indexing task already running, exiting...: "
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return None
@@ -1075,6 +1204,7 @@ def connector_indexing_task(
# define a callback class
callback = IndexingCallback(
os.getppid(),
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,
@@ -1108,8 +1238,19 @@ def connector_indexing_task(
f"search_settings={search_settings_id}"
)
if attempt_found:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(
index_attempt_id, db_session, failure_reason=str(e)
)
except Exception:
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise e
finally:

View File

@@ -0,0 +1,105 @@
from typing import Any
import requests
from celery import shared_task
from celery import Task
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import LLMProvider
def _process_model_list_response(model_list_json: Any) -> list[str]:
# Handle case where response is wrapped in a "data" field
if isinstance(model_list_json, dict) and "data" in model_list_json:
model_list_json = model_list_json["data"]
if not isinstance(model_list_json, list):
raise ValueError(
f"Invalid response from API - expected list, got {type(model_list_json)}"
)
# Handle both string list and object list cases
model_names: list[str] = []
for item in model_list_json:
if isinstance(item, str):
model_names.append(item)
elif isinstance(item, dict) and "model_name" in item:
model_names.append(item["model_name"])
else:
raise ValueError(
f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
)
return model_names
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
)
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
if not LLM_MODEL_UPDATE_API_URL:
raise ValueError("LLM model update API URL not configured")
# First fetch the models from the API
try:
response = requests.get(LLM_MODEL_UPDATE_API_URL)
response.raise_for_status()
available_models = _process_model_list_response(response.json())
task_logger.info(f"Found available models: {available_models}")
except Exception:
task_logger.exception("Failed to fetch models from API.")
return None
# Then update the database with the fetched models
with get_session_with_tenant(tenant_id) as db_session:
# Get the default LLM provider
default_provider = (
db_session.query(LLMProvider)
.filter(LLMProvider.is_default_provider.is_(True))
.first()
)
if not default_provider:
task_logger.warning("No default LLM provider found")
return None
# log change if any
old_models = set(default_provider.model_names or [])
new_models = set(available_models)
added_models = new_models - old_models
removed_models = old_models - new_models
if added_models:
task_logger.info(f"Adding models: {sorted(added_models)}")
if removed_models:
task_logger.info(f"Removing models: {sorted(removed_models)}")
# Update the provider's model list
default_provider.model_names = available_models
# if the default model is no longer available, set it to the first model in the list
if default_provider.default_model_name not in available_models:
task_logger.info(
f"Default model {default_provider.default_model_name} not "
f"available, setting to first model in list: {available_models[0]}"
)
default_provider.default_model_name = available_models[0]
if default_provider.fast_default_model_name not in available_models:
task_logger.info(
f"Fast default model {default_provider.fast_default_model_name} "
f"not available, setting to first model in list: {available_models[0]}"
)
default_provider.fast_default_model_name = available_models[0]
db_session.commit()
if added_models or removed_models:
task_logger.info("Updated model list for default provider.")
return True

View File

@@ -81,19 +81,19 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
@@ -127,6 +127,8 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
if lock_beat.owned():
lock_beat.release()
return True
def try_creating_prune_generator_task(
celery_app: Celery,
@@ -283,6 +285,7 @@ def connector_pruning_generator_task(
)
callback = IndexingCallback(
0,
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,

View File

@@ -28,13 +28,35 @@ class RetryDocumentIndex:
wait=wait_random_exponential(multiplier=1, max=MAX_WAIT),
stop=stop_after_delay(STOP_AFTER),
)
def delete_single(self, doc_id: str) -> int:
return self.index.delete_single(doc_id)
def delete_single(
self,
doc_id: str,
*,
tenant_id: str | None,
chunk_count: int | None,
) -> int:
return self.index.delete_single(
doc_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
@retry(
retry=retry_if_exception_type(httpx.ReadTimeout),
wait=wait_random_exponential(multiplier=1, max=MAX_WAIT),
stop=stop_after_delay(STOP_AFTER),
)
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> int:
return self.index.update_single(doc_id, fields)
def update_single(
self,
doc_id: str,
*,
tenant_id: str | None,
chunk_count: int | None,
fields: VespaDocumentFields,
) -> int:
return self.index.update_single(
doc_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
fields=fields,
)

View File

@@ -12,6 +12,7 @@ from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocument
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document import delete_document_by_connector_credential_pair__no_commit
from onyx.db.document import delete_documents_complete__no_commit
from onyx.db.document import fetch_chunk_count_for_document
from onyx.db.document import get_document
from onyx.db.document import get_document_connector_count
from onyx.db.document import mark_document_as_modified
@@ -80,7 +81,13 @@ def document_by_cc_pair_cleanup_task(
# delete it from vespa and the db
action = "delete"
chunks_affected = retry_index.delete_single(document_id)
chunk_count = fetch_chunk_count_for_document(document_id, db_session)
chunks_affected = retry_index.delete_single(
document_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
@@ -110,7 +117,12 @@ def document_by_cc_pair_cleanup_task(
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(document_id, fields=fields)
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(

View File

@@ -1,8 +1,10 @@
import random
import time
import traceback
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from typing import Any
from typing import cast
import httpx
@@ -20,10 +22,12 @@ from tenacity import RetryError
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -67,6 +71,8 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -75,6 +81,7 @@ from onyx.utils.variable_functionality import (
)
from onyx.utils.variable_functionality import global_version
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -87,7 +94,7 @@ logger = setup_logger()
trail=False,
bind=True,
)
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
time_start = time.monotonic()
@@ -99,17 +106,18 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(
self.app, db_session, r, lock_beat, tenant_id
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
)
# region document set scan
lock_beat.reacquire()
document_set_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
# check if any document sets are not synced
@@ -121,6 +129,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
document_set_ids.append(document_set.id)
for document_set_id in document_set_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
try_generate_document_set_sync_tasks(
self.app, document_set_id, db_session, r, lock_beat, tenant_id
@@ -129,6 +138,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
# check if any user groups are not synced
if global_version.is_ee_version():
lock_beat.reacquire()
try:
fetch_user_groups = fetch_versioned_implementation(
"onyx.db.user_group", "fetch_user_groups"
@@ -148,6 +159,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
usergroup_ids.append(usergroup.id)
for usergroup_id in usergroup_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
try_generate_user_group_sync_tasks(
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
@@ -162,14 +174,21 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_vespa_sync_task - Lock not owned on completion: "
f"tenant={tenant_id}"
)
redis_lock_dump(lock_beat, r)
time_elapsed = time.monotonic() - time_start
task_logger.info(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
return
task_logger.debug(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
return True
def try_generate_stale_document_sync_tasks(
celery_app: Celery,
max_tasks: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
@@ -200,11 +219,16 @@ def try_generate_stale_document_sync_tasks(
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
tasks_remaining = max_tasks
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
lock_beat.reacquire()
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
rc.set_skip_docs(docs_to_skip)
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
result = rc.generate_tasks(
tasks_remaining, celery_app, db_session, r, lock_beat, tenant_id
)
if result is None:
continue
@@ -218,10 +242,19 @@ def try_generate_stale_document_sync_tasks(
)
total_tasks_generated += result[0]
tasks_remaining -= result[0]
if tasks_remaining <= 0:
break
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
)
if tasks_remaining <= 0:
task_logger.info(
f"RedisConnector.generate_tasks reached the task generation limit: "
f"total_tasks_generated={total_tasks_generated} max_tasks={max_tasks}"
)
else:
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
)
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
return total_tasks_generated
@@ -260,7 +293,9 @@ def try_generate_document_set_sync_tasks(
)
# Add all documents that need to be updated into the queue
result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
result = rds.generate_tasks(
VESPA_SYNC_MAX_TASKS, celery_app, db_session, r, lock_beat, tenant_id
)
if result is None:
return None
@@ -315,7 +350,9 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
result = rug.generate_tasks(
VESPA_SYNC_MAX_TASKS, celery_app, db_session, r, lock_beat, tenant_id
)
if result is None:
return None
@@ -636,15 +673,23 @@ def monitor_ccpair_indexing_taskset(
if not payload:
return
elapsed_started_str = None
if payload.started:
elapsed_started = datetime.now(timezone.utc) - payload.started
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair={cc_pair_id} "
f"Connector indexing progress: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
if payload.index_attempt_id is None or payload.celery_task_id is None:
@@ -715,18 +760,21 @@ def monitor_ccpair_indexing_taskset(
status_enum = HTTPStatus(status_int)
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"Connector indexing finished: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
redis_connector_index.reset()
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
@@ -736,7 +784,13 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
Returns True if the task actually did work, False if it exited early to prevent overlap
"""
task_logger.info(f"monitor_vespa_sync starting: tenant={tenant_id}")
time_start = time.monotonic()
timings: dict[str, Any] = {}
timings["start"] = time_start
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
@@ -744,55 +798,94 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return None
try:
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return False
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
n_deletion = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
n_pruning = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
)
n_permissions_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
phase_start = time.monotonic()
# we don't need every tenant polling redis for this info.
if not MULTI_TENANT or random.randint(1, 10) == 10:
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
n_sync = celery_get_queue_length(
OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery
)
n_deletion = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
n_pruning = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
)
n_permissions_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
n_external_group_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
n_permissions_upsert = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
)
prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"indexing_prefetched={len(prefetched)} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
f"external_group_sync={n_external_group_sync} "
f"permissions_upsert={n_permissions_upsert} "
)
timings["queues"] = time.monotonic() - phase_start
timings["queues_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
# scan and monitor activity to completion
phase_start = time.monotonic()
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
timings["connector"] = time.monotonic() - phase_start
timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
lock_beat.reacquire()
for key_bytes in r.scan_iter(
RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
lock_beat.reacquire()
timings["connector_deletion"] = time.monotonic() - phase_start
timings["connector_deletion_ttl"] = r.ttl(
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK
)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r.scan_iter(
RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
lock_beat.reacquire()
timings["documentset"] = time.monotonic() - phase_start
timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r.scan_iter(
RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
"onyx.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
@@ -800,29 +893,45 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
lock_beat.reacquire()
timings["usergroup"] = time.monotonic() - phase_start
timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r.scan_iter(
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
lock_beat.reacquire()
timings["pruning"] = time.monotonic() - phase_start
timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r.scan_iter(
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
lock_beat.reacquire()
timings["indexing"] = time.monotonic() - phase_start
timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
lock_beat.reacquire()
for key_bytes in r.scan_iter(
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["permissions"] = time.monotonic() - phase_start
timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
# task_logger.warning(f"queue={OnyxCeleryQueues.VESPA_METADATA_SYNC} length={length}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -830,6 +939,13 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"monitor_vespa_sync - Lock not owned on completion: "
f"tenant={tenant_id} "
f"timings={timings}"
)
redis_lock_dump(lock_beat, r)
time_elapsed = time.monotonic() - time_start
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
@@ -876,12 +992,26 @@ def vespa_metadata_sync_task(
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(document_id, fields)
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
# this code checks for and removes a per document sync key that is
# used to block out the same doc from continualy resyncing
# a quick hack that is only needed for production issues
# redis_syncing_key = RedisConnectorCredentialPair.make_redis_syncing_key(
# document_id
# )
# r = get_redis_client(tenant_id=tenant_id)
# r.delete(redis_syncing_key)
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")

View File

@@ -14,6 +14,7 @@ from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
@@ -90,6 +91,35 @@ def _get_connector_runner(
)
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
cleaned_batch = []
for doc in doc_batch:
cleaned_doc = doc.model_copy()
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
)
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
"\x00", ""
)
for section in cleaned_doc.sections:
if section.link and "\x00" in section.link:
logger.warning(
f"NUL characters found in document link for document: {cleaned_doc.id}"
)
section.link = section.link.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
return cleaned_batch
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
@@ -238,7 +268,9 @@ def _run_indexing(
)
batch_description = []
for doc in doc_batch:
doc_batch_cleaned = strip_null_characters(doc_batch)
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
doc_size = 0
@@ -258,15 +290,15 @@ def _run_indexing(
# real work happens here!
new_docs, total_batch_chunks = indexing_pipeline(
document_batch=doc_batch,
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
)
batch_num += 1
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
all_connector_doc_ids.update(doc.id for doc in doc_batch)
document_count += len(doc_batch_cleaned)
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
@@ -276,7 +308,7 @@ def _run_indexing(
db_session.commit()
if callback:
callback.progress("_run_indexing", len(doc_batch))
callback.progress("_run_indexing", len(doc_batch_cleaned))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(

View File

@@ -22,7 +22,9 @@ from onyx.chat.stream_processing.answer_response_handler import (
from onyx.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.chat.stream_processing.utils import (
map_document_id_order,
)
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
@@ -206,9 +208,9 @@ class Answer:
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
search_result, displayed_search_results_map = SearchTool.get_search_result(
final_search_results, displayed_search_results = SearchTool.get_search_result(
current_llm_call
) or ([], {})
) or ([], [])
# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
@@ -224,9 +226,9 @@ class Answer:
# else:
# raise ValueError("No answer style config provided")
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
display_doc_order_dict=displayed_search_results_map,
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
response_handler_manager = LLMResponseHandlerManager(

View File

@@ -37,22 +37,22 @@ class CitationResponseHandler(AnswerResponseHandler):
def __init__(
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int],
final_doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_id_to_rank_map: DocumentIdOrderMapping,
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.display_doc_order_dict = display_doc_order_dict
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
self.citation_processor = CitationProcessor(
context_docs=self.context_docs,
doc_id_to_rank_map=self.doc_id_to_rank_map,
display_doc_order_dict=self.display_doc_order_dict,
final_doc_id_to_rank_map=self.final_doc_id_to_rank_map,
display_doc_id_to_rank_map=self.display_doc_id_to_rank_map,
)
self.processed_text = ""
self.citations: list[CitationInfo] = []
# TODO remove this after citation issue is resolved
logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")
logger.debug(f"Document to ranking map {self.final_doc_id_to_rank_map}")
def handle_response_part(
self,

View File

@@ -21,20 +21,19 @@ class CitationProcessor:
def __init__(
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int],
final_doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_id_to_rank_map: DocumentIdOrderMapping,
stop_stream: str | None = STOP_STREAM_PAT,
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
self.stop_stream = stop_stream
self.order_mapping = doc_id_to_rank_map.order_mapping
self.display_doc_order_dict = (
display_doc_order_dict # original order of docs to displayed to user
)
self.final_order_mapping = final_doc_id_to_rank_map.order_mapping
self.display_order_mapping = display_doc_id_to_rank_map.order_mapping
self.llm_out = ""
self.max_citation_num = len(context_docs)
self.citation_order: list[int] = []
self.citation_order: list[int] = [] # order of citations in the LLM output
self.curr_segment = ""
self.cited_inds: set[int] = set()
self.hold = ""
@@ -93,29 +92,31 @@ class CitationProcessor:
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
real_citation_num = self.order_mapping[context_llm_doc.document_id]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]
if real_citation_num not in self.citation_order:
self.citation_order.append(real_citation_num)
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
target_citation_num = (
self.citation_order.index(real_citation_num) + 1
citation_order_idx = (
self.citation_order.index(final_citation_num) + 1
)
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_doc_order_dict:
displayed_citation_num = self.display_doc_order_dict[
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = real_citation_num
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
# Skip consecutive citations of the same work
if target_citation_num in self.current_citations:
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
@@ -134,8 +135,8 @@ class CitationProcessor:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
@@ -151,13 +152,13 @@ class CitationProcessor:
link = context_llm_doc.link
self.past_cite_count = len(self.llm_out)
self.current_citations.append(target_citation_num)
self.current_citations.append(final_citation_num)
if target_citation_num not in self.cited_inds:
self.cited_inds.add(target_citation_num)
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
@@ -167,7 +168,6 @@ class CitationProcessor:
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
# + f"[[{target_citation_num}]]({link})"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
@@ -176,7 +176,6 @@ class CitationProcessor:
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
# + f"[[{target_citation_num}]]()"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length

View File

@@ -17,6 +17,7 @@ APP_PORT = 8080
# prefix from requests directed towards the API server. In these cases, set this to `/api`
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
SKIP_WARM_UP = os.environ.get("SKIP_WARM_UP", "").lower() == "true"
#####
# User Facing Features Configs
@@ -54,10 +55,17 @@ MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
# Default request timeout, mostly used by connectors
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
# set `VALID_EMAIL_DOMAINS` to a comma seperated list of domains in order to
# restrict access to Onyx to only users with emails from those domains.
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
@@ -92,6 +100,7 @@ SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
EMAIL_CONFIGURED = all([SMTP_SERVER, SMTP_USER, SMTP_PASS])
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
# If set, Onyx will listen to the `expires_at` returned by the identity
@@ -145,7 +154,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
POSTGRES_API_SERVER_POOL_SIZE = int(
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
@@ -184,6 +193,27 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
# Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
if _rate_limit_window_seconds_str is not None:
try:
RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
except ValueError:
pass
RATE_LIMIT_MAX_REQUESTS: int | None = None
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
if _rate_limit_max_requests_str is not None:
try:
RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
except ValueError:
pass
AUTH_RATE_LIMITING_ENABLED = RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS
# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
@@ -251,6 +281,11 @@ try:
except ValueError:
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
VESPA_SYNC_MAX_TASKS = 1024
DB_YIELD_PER_DEFAULT = 64
#####
# Connector Configs
#####
@@ -347,12 +382,17 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)
# Typically set to http://localhost:3000 for OAuth connector development
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
# Egnyte specific configs
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
# Linear specific configs
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
@@ -503,6 +543,9 @@ try:
except json.JSONDecodeError:
pass
# LLM Model Update API endpoint
LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
#####
# Enterprise Edition Configs
#####
@@ -542,7 +585,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
#####
# API Key Configs
#####

View File

@@ -36,6 +36,8 @@ DISABLED_GEN_AI_MSG = (
DEFAULT_PERSONA_ID = 0
DEFAULT_CC_PAIR_ID = 1
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
@@ -74,13 +76,19 @@ KV_ENTERPRISE_SETTINGS_KEY = "onyx_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
# NOTE: we use this timeout / 4 in various places to refresh a lock
# might be worth separating this timeout into separate timeouts for each situation
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
@@ -134,9 +142,11 @@ class DocumentSource(str, Enum):
OCI_STORAGE = "oci_storage"
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
DISCORD = "discord"
FRESHDESK = "freshdesk"
FIREFLIES = "fireflies"
EGNYTE = "egnyte"
AIRTABLE = "airtable"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -240,6 +250,7 @@ class OnyxCeleryQueues:
VESPA_METADATA_SYNC = "vespa_metadata_sync"
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
CONNECTOR_DELETION = "connector_deletion"
LLM_MODEL_UPDATE = "llm_model_update"
# Heavy queue
CONNECTOR_PRUNING = "connector_pruning"
@@ -273,6 +284,7 @@ class OnyxRedisLocks:
SLACK_BOT_LOCK = "da_lock:slack_bot"
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
class OnyxRedisSignals:
@@ -294,6 +306,7 @@ class OnyxCeleryTask:
CHECK_FOR_PRUNING = "check_for_pruning"
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (

View File

@@ -0,0 +1,289 @@
from io import BytesIO
from typing import Any
import requests
from pyairtable import Api as AirtableApi
from pyairtable.api.types import RecordDict
from pyairtable.models.schema import TableSchema
from retry import retry
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.utils.logger import setup_logger
logger = setup_logger()
# NOTE: all are made lowercase to avoid case sensitivity issues
# these are the field types that are considered metadata rather
# than sections
_METADATA_FIELD_TYPES = {
"singlecollaborator",
"collaborator",
"createdby",
"singleselect",
"multipleselects",
"checkbox",
"date",
"datetime",
"email",
"phone",
"url",
"number",
"currency",
"duration",
"percent",
"rating",
"createdtime",
"lastmodifiedtime",
"autonumber",
"rollup",
"lookup",
"count",
"formula",
"date",
}
class AirtableClientNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__("Airtable Client is not set up, was load_credentials called?")
class AirtableConnector(LoadConnector):
def __init__(
self,
base_id: str,
table_name_or_id: str,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.base_id = base_id
self.table_name_or_id = table_name_or_id
self.batch_size = batch_size
self.airtable_client: AirtableApi | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
return None
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
"""
Extract value(s) from a field regardless of its type.
Returns either a single string or list of strings for attachments.
"""
if field_info is None:
return []
# skip references to other records for now (would need to do another
# request to get the actual record name/type)
# TODO: support this
if field_type == "multipleRecordLinks":
return []
if field_type == "multipleAttachments":
attachment_texts: list[str] = []
for attachment in field_info:
url = attachment.get("url")
filename = attachment.get("filename", "")
if not url:
continue
@retry(
tries=5,
delay=1,
backoff=2,
max_delay=10,
)
def get_attachment_with_retry(url: str) -> bytes | None:
attachment_response = requests.get(url)
if attachment_response.status_code == 200:
return attachment_response.content
return None
attachment_content = get_attachment_with_retry(url)
if attachment_content:
try:
file_ext = get_file_ext(filename)
attachment_text = extract_file_text(
BytesIO(attachment_content),
filename,
break_on_unprocessable=False,
extension=file_ext,
)
if attachment_text:
attachment_texts.append(f"{filename}:\n{attachment_text}")
except Exception as e:
logger.warning(
f"Failed to process attachment {filename}: {str(e)}"
)
return attachment_texts
if field_type in ["singleCollaborator", "collaborator", "createdBy"]:
combined = []
collab_name = field_info.get("name")
collab_email = field_info.get("email")
if collab_name:
combined.append(collab_name)
if collab_email:
combined.append(f"({collab_email})")
return [" ".join(combined) if combined else str(field_info)]
if isinstance(field_info, list):
return [str(item) for item in field_info]
return [str(field_info)]
def _should_be_metadata(self, field_type: str) -> bool:
"""Determine if a field type should be treated as metadata."""
return field_type.lower() in _METADATA_FIELD_TYPES
def _process_field(
self,
field_name: str,
field_info: Any,
field_type: str,
table_id: str,
record_id: str,
) -> tuple[list[Section], dict[str, Any]]:
"""
Process a single Airtable field and return sections or metadata.
Args:
field_name: Name of the field
field_info: Raw field information from Airtable
field_type: Airtable field type
Returns:
(list of Sections, dict of metadata)
"""
if field_info is None:
return [], {}
# Get the value(s) for the field
field_values = self._get_field_value(field_info, field_type)
if len(field_values) == 0:
return [], {}
# Determine if it should be metadata or a section
if self._should_be_metadata(field_type):
if len(field_values) > 1:
return [], {field_name: field_values}
return [], {field_name: field_values[0]}
# Otherwise, create relevant sections
sections = [
Section(
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
text=(
f"{field_name}:\n"
"------------------------\n"
f"{text}\n"
"------------------------"
),
)
for text in field_values
]
return sections, {}
def _process_record(
self,
record: RecordDict,
table_schema: TableSchema,
primary_field_name: str | None,
) -> Document:
"""Process a single Airtable record into a Document.
Args:
record: The Airtable record to process
table_schema: Schema information for the table
table_name: Name of the table
table_id: ID of the table
primary_field_name: Name of the primary field, if any
Returns:
Document object representing the record
"""
table_id = table_schema.id
table_name = table_schema.name
record_id = record["id"]
fields = record["fields"]
sections: list[Section] = []
metadata: dict[str, Any] = {}
# Get primary field value if it exists
primary_field_value = (
fields.get(primary_field_name) if primary_field_name else None
)
for field_schema in table_schema.fields:
field_name = field_schema.name
field_val = fields.get(field_name)
field_type = field_schema.type
field_sections, field_metadata = self._process_field(
field_name=field_name,
field_info=field_val,
field_type=field_type,
table_id=table_id,
record_id=record_id,
)
sections.extend(field_sections)
metadata.update(field_metadata)
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
else table_name
)
return Document(
id=f"airtable__{record_id}",
sections=sections,
source=DocumentSource.AIRTABLE,
semantic_identifier=semantic_id,
metadata=metadata,
)
def load_from_state(self) -> GenerateDocumentsOutput:
"""
Fetch all records from the table.
NOTE: Airtable does not support filtering by time updated, so
we have to fetch all records every time.
"""
if not self.airtable_client:
raise AirtableClientNotSetUpError()
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
records = table.all()
table_schema = table.schema()
primary_field_name = None
# Find a primary field from the schema
for field in table_schema.fields:
if field.id == table_schema.primary_field_id:
primary_field_name = field.name
break
record_documents: list[Document] = []
for record in records:
document = self._process_record(
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
)
record_documents.append(document)
if len(record_documents) >= self.batch_size:
yield record_documents
record_documents = []
if record_documents:
yield record_documents

View File

@@ -52,10 +52,29 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
"space",
"restrictions.read.restrictions.user",
"restrictions.read.restrictions.group",
"ancestors.restrictions.read.restrictions.user",
"ancestors.restrictions.read.restrictions.group",
]
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"png",
"jpg",
"jpeg",
"gif",
"mp4",
"mov",
"mp3",
"wav",
]
_FULL_EXTENSION_FILTER_STRING = "".join(
[
f" and title!~'*.{extension}'"
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
]
)
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
@@ -64,7 +83,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
is_cloud: bool,
space: str = "",
page_id: str = "",
index_recursively: bool = True,
index_recursively: bool = False,
cql_query: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
@@ -82,23 +101,25 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
# if nothing is provided, we will fetch all pages
cql_page_query = "type=page"
"""
If nothing is provided, we default to fetching all pages
Only one or none of the following options should be specified so
the order shouldn't matter
However, we use elif to ensure that only of the following is enforced
"""
base_cql_page_query = "type=page"
if cql_query:
# if a cql_query is provided, we will use it to fetch the pages
cql_page_query = cql_query
base_cql_page_query = cql_query
elif page_id:
# if a cql_query is not provided, we will use the page_id to fetch the page
if index_recursively:
cql_page_query += f" and ancestor='{page_id}'"
base_cql_page_query += f" and (ancestor='{page_id}' or id='{page_id}')"
else:
cql_page_query += f" and id='{page_id}'"
base_cql_page_query += f" and id='{page_id}'"
elif space:
# if no cql_query or page_id is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"
uri_safe_space = quote(space)
base_cql_page_query += f" and space='{uri_safe_space}'"
self.cql_page_query = cql_page_query
self.cql_time_filter = ""
self.base_cql_page_query = base_cql_page_query
self.cql_label_filter = ""
if labels_to_skip:
@@ -126,6 +147,33 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def _construct_page_query(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> str:
page_query = self.base_cql_page_query + self.cql_label_filter
# Add time filters
if start:
formatted_start_time = datetime.fromtimestamp(
start, tz=self.timezone
).strftime("%Y-%m-%d %H:%M")
page_query += f" and lastmodified >= '{formatted_start_time}'"
if end:
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
page_query += f" and lastmodified <= '{formatted_end_time}'"
return page_query
def _construct_attachment_query(self, confluence_page_id: str) -> str:
attachment_query = f"type=attachment and container='{confluence_page_id}'"
attachment_query += self.cql_label_filter
attachment_query += _FULL_EXTENSION_FILTER_STRING
return attachment_query
def _get_comment_string_for_page_id(self, page_id: str) -> str:
comment_string = ""
@@ -205,11 +253,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
metadata=doc_metadata,
)
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
def _fetch_document_batches(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
page_query = self._construct_page_query(start, end)
logger.debug(f"page_query: {page_query}")
# Fetch pages as Documents
for page in self.confluence_client.paginated_cql_retrieval(
@@ -228,11 +280,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
attachment_cql += self.cql_label_filter
attachment_query = self._construct_attachment_query(confluence_page_id)
# TODO: maybe should add time filter as well?
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_cql,
cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
doc = self._convert_object_to_document(attachment)
@@ -248,17 +299,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_document_batches()
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
# Add time filters
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
return self._fetch_document_batches()
def poll_source(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
return self._fetch_document_batches(start, end)
def retrieve_all_slim_documents(
self,
@@ -269,7 +315,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
page_query = self.cql_page_query + self.cql_label_filter
page_query = self.base_cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
@@ -279,9 +325,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# These will be used by doc_sync.py to sync permissions
page_restrictions = page.get("restrictions")
page_space_key = page.get("space", {}).get("key")
page_ancestors = page.get("ancestors", [])
page_perm_sync_data = {
"restrictions": page_restrictions or {},
"space_key": page_space_key,
"ancestors": page_ancestors or [],
}
doc_metadata_list.append(
@@ -294,10 +342,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
perm_sync_data=page_perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
attachment_query = self._construct_attachment_query(page["id"])
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
cql=attachment_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):

View File

@@ -121,6 +121,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 50
class OnyxConfluence(Confluence):
@@ -186,27 +187,67 @@ class OnyxConfluence(Confluence):
url_suffix += f"{connection_char}limit={limit}"
while url_suffix:
logger.debug(f"Making confluence call to {url_suffix}")
try:
logger.debug(f"Making confluence call to {url_suffix}")
next_response = self.get(url_suffix)
raw_response = self.get(
path=url_suffix,
advanced_mode=True,
)
except Exception as e:
logger.exception(f"Error in confluence call to {url_suffix}")
raise e
try:
raw_response.raise_for_status()
except Exception as e:
logger.warning(f"Error in confluence call to {url_suffix}")
# If the problematic expansion is in the url, replace it
# with the replacement expansion and try again
# If that fails, raise the error
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
logger.exception(f"Error in confluence call to {url_suffix}")
raise e
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
if _PROBLEMATIC_EXPANSIONS in url_suffix:
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
)
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
)
continue
if (
raw_response.status_code == 500
and limit > _MINIMUM_PAGINATION_LIMIT
):
new_limit = limit // 2
logger.warning(
f"Error in confluence call to {url_suffix} \n"
f"Raw Response Text: {raw_response.text} \n"
f"Full Response: {raw_response.__dict__} \n"
f"Error: {e} \n"
f"Reducing limit from {limit} to {new_limit} and trying again."
)
url_suffix = url_suffix.replace(
f"limit={limit}", f"limit={new_limit}"
)
limit = new_limit
continue
logger.exception(
f"Error in confluence call to {url_suffix} \n"
f"Raw Response Text: {raw_response.text} \n"
f"Full Response: {raw_response.__dict__} \n"
f"Error: {e} \n"
)
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
raise e
try:
next_response = raw_response.json()
except Exception as e:
logger.exception(
f"Failed to parse response as JSON. Response: {raw_response.__dict__}"
)
continue
raise e
# yield the results individually
yield from next_response.get("results", [])

View File

@@ -32,11 +32,13 @@ def get_user_email_from_username__server(
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
# For now, we'll just return a string that indicates failure
# We may want to revert to returning None in the future
# email = None
email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
logger.warning(f"failed to get confluence email for {user_name}")
# For now, we'll just return None and log a warning. This means
# we will keep retrying to get the email every group sync.
email = None
# We may want to just return a string that indicates failure so we dont
# keep retrying
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]

View File

@@ -6,6 +6,7 @@ from typing import TypeVar
from dateutil.parser import parse
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
from onyx.configs.constants import IGNORE_FOR_QA
from onyx.connectors.models import BasicExpertInfo
from onyx.utils.text_processing import is_valid_email
@@ -71,3 +72,10 @@ def process_in_batches(
def get_metadata_keys_to_ignore() -> list[str]:
return [IGNORE_FOR_QA]
def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
if CONNECTOR_LOCALHOST_OVERRIDE:
# Used for development
base_domain = CONNECTOR_LOCALHOST_OVERRIDE
return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}"

View File

@@ -0,0 +1,321 @@
import asyncio
from collections.abc import AsyncIterable
from collections.abc import Iterable
from datetime import datetime
from datetime import timezone
from typing import Any
from discord import Client
from discord.channel import TextChannel
from discord.channel import Thread
from discord.enums import MessageType
from discord.flags import Intents
from discord.message import Message as DiscordMessage
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
_DISCORD_DOC_ID_PREFIX = "DISCORD_"
_SNIPPET_LENGTH = 30
def _convert_message_to_document(
message: DiscordMessage,
sections: list[Section],
) -> Document:
"""
Convert a discord message to a document
Sections are collected before calling this function because it relies on async
calls to fetch the thread history if there is one
"""
metadata: dict[str, str | list[str]] = {}
semantic_substring = ""
# Only messages from TextChannels will make it here but we have to check for it anyways
if isinstance(message.channel, TextChannel) and (
channel_name := message.channel.name
):
metadata["Channel"] = channel_name
semantic_substring += f" in Channel: #{channel_name}"
# Single messages dont have a title
title = ""
# If there is a thread, add more detail to the metadata, title, and semantic identifier
if isinstance(message.channel, Thread):
# Threads do have a title
title = message.channel.name
# If its a thread, update the metadata, title, and semantic_substring
metadata["Thread"] = title
# Add more detail to the semantic identifier if available
semantic_substring += f" in Thread: {title}"
snippet: str = (
message.content[:_SNIPPET_LENGTH].rstrip() + "..."
if len(message.content) > _SNIPPET_LENGTH
else message.content
)
semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"
return Document(
id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
source=DocumentSource.DISCORD,
semantic_identifier=semantic_identifier,
doc_updated_at=message.edited_at,
title=title,
sections=sections,
metadata=metadata,
)
async def _fetch_filtered_channels(
discord_client: Client,
server_ids: list[int] | None,
channel_names: list[str] | None,
) -> list[TextChannel]:
filtered_channels: list[TextChannel] = []
for channel in discord_client.get_all_channels():
if not channel.permissions_for(channel.guild.me).read_message_history:
continue
if not isinstance(channel, TextChannel):
continue
if server_ids and len(server_ids) > 0 and channel.guild.id not in server_ids:
continue
if channel_names and channel.name not in channel_names:
continue
filtered_channels.append(channel)
logger.info(f"Found {len(filtered_channels)} channels for the authenticated user")
return filtered_channels
async def _fetch_documents_from_channel(
channel: TextChannel,
start_time: datetime | None,
end_time: datetime | None,
) -> AsyncIterable[Document]:
# Discord's epoch starts at 2015-01-01
discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
if start_time and start_time < discord_epoch:
start_time = discord_epoch
async for channel_message in channel.history(
after=start_time,
before=end_time,
):
# Skip messages that are not the default type
if channel_message.type != MessageType.default:
continue
sections: list[Section] = [
Section(
text=channel_message.content,
link=channel_message.jump_url,
)
]
yield _convert_message_to_document(channel_message, sections)
for active_thread in channel.threads:
async for thread_message in active_thread.history(
after=start_time,
before=end_time,
):
# Skip messages that are not the default type
if thread_message.type != MessageType.default:
continue
sections = [
Section(
text=thread_message.content,
link=thread_message.jump_url,
)
]
yield _convert_message_to_document(thread_message, sections)
async for archived_thread in channel.archived_threads():
async for thread_message in archived_thread.history(
after=start_time,
before=end_time,
):
# Skip messages that are not the default type
if thread_message.type != MessageType.default:
continue
sections = [
Section(
text=thread_message.content,
link=thread_message.jump_url,
)
]
yield _convert_message_to_document(thread_message, sections)
def _manage_async_retrieval(
token: str,
requested_start_date_string: str,
channel_names: list[str],
server_ids: list[int],
start: datetime | None = None,
end: datetime | None = None,
) -> Iterable[Document]:
# parse requested_start_date_string to datetime
pull_date: datetime | None = (
datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
tzinfo=timezone.utc
)
if requested_start_date_string
else None
)
# Set start_time to the later of start and pull_date, or whichever is provided
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
end_time: datetime | None = end
async def _async_fetch() -> AsyncIterable[Document]:
intents = Intents.default()
intents.message_content = True
async with Client(intents=intents) as discord_client:
asyncio.create_task(discord_client.start(token))
await discord_client.wait_until_ready()
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
discord_client=discord_client,
server_ids=server_ids,
channel_names=channel_names,
)
for channel in filtered_channels:
async for doc in _fetch_documents_from_channel(
channel=channel,
start_time=start_time,
end_time=end_time,
):
yield doc
def run_and_yield() -> Iterable[Document]:
loop = asyncio.new_event_loop()
try:
# Get the async generator
async_gen = _async_fetch()
# Convert to AsyncIterator
async_iter = async_gen.__aiter__()
while True:
try:
# Create a coroutine by calling anext with the async iterator
next_coro = anext(async_iter)
# Run the coroutine to get the next document
doc = loop.run_until_complete(next_coro)
yield doc
except StopAsyncIteration:
break
finally:
loop.close()
return run_and_yield()
class DiscordConnector(PollConnector, LoadConnector):
def __init__(
self,
server_ids: list[str] = [],
channel_names: list[str] = [],
# YYYY-MM-DD
start_date: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
):
self.batch_size = batch_size
self.channel_names: list[str] = channel_names if channel_names else []
self.server_ids: list[int] = (
[int(server_id) for server_id in server_ids] if server_ids else []
)
self._discord_bot_token: str | None = None
self.requested_start_date_string: str = start_date or ""
@property
def discord_bot_token(self) -> str:
if self._discord_bot_token is None:
raise ConnectorMissingCredentialError("Discord")
return self._discord_bot_token
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._discord_bot_token = credentials["discord_bot_token"]
return None
def _manage_doc_batching(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateDocumentsOutput:
doc_batch = []
for doc in _manage_async_retrieval(
token=self.discord_bot_token,
requested_start_date_string=self.requested_start_date_string,
channel_names=self.channel_names,
server_ids=self.server_ids,
start=start,
end=end,
):
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
return self._manage_doc_batching(
datetime.fromtimestamp(start, tz=timezone.utc),
datetime.fromtimestamp(end, tz=timezone.utc),
)
def load_from_state(self) -> GenerateDocumentsOutput:
return self._manage_doc_batching(None, None)
if __name__ == "__main__":
import os
import time
end = time.time()
# 1 day
start = end - 24 * 60 * 60 * 1
# "1,2,3"
server_ids: str | None = os.environ.get("server_ids", None)
# "channel1,channel2"
channel_names: str | None = os.environ.get("channel_names", None)
connector = DiscordConnector(
server_ids=server_ids.split(",") if server_ids else [],
channel_names=channel_names.split(",") if channel_names else [],
start_date=os.environ.get("start_date", None),
)
connector.load_credentials(
{"discord_bot_token": os.environ.get("discord_bot_token")}
)
for doc_batch in connector.poll_source(start, end):
for doc in doc_batch:
print(doc)

View File

@@ -190,7 +190,7 @@ class DiscourseConnector(PollConnector):
start: datetime,
end: datetime,
) -> GenerateDocumentsOutput:
page = 1
page = 0
while topic_ids := self._get_latest_topics(start, end, page):
doc_batch: list[Document] = []
for topic_id in topic_ids:

View File

@@ -3,20 +3,19 @@ import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from logging import Logger
from typing import Any
from typing import cast
from typing import IO
from urllib.parse import quote
import requests
from retry import retry
from pydantic import Field
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
from onyx.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_oauth_callback_uri,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import OAuthConnector
@@ -33,53 +32,13 @@ from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import read_text_file
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
logger = setup_logger()
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
_TIMEOUT = 60
def _request_with_retries(
method: str,
url: str,
data: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: int = _TIMEOUT,
stream: bool = False,
tries: int = 8,
delay: float = 1,
backoff: float = 2,
) -> requests.Response:
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
def _make_request() -> requests.Response:
response = requests.request(
method,
url,
data=data,
headers=headers,
params=params,
timeout=timeout,
stream=stream,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
if e.response.status_code != 403:
logger.exception(
f"Failed to call Egnyte API.\n"
f"URL: {url}\n"
f"Headers: {headers}\n"
f"Data: {data}\n"
f"Params: {params}"
)
raise e
return response
return _make_request()
def _parse_last_modified(last_modified: str) -> datetime:
@@ -166,6 +125,15 @@ def _process_egnyte_file(
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
egnyte_domain: str = Field(
title="Egnyte Domain",
description=(
"The domain for the Egnyte instance "
"(e.g. 'company' for company.egnyte.com)"
),
)
def __init__(
self,
folder_path: str | None = None,
@@ -181,18 +149,20 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
return DocumentSource.EGNYTE
@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls,
base_domain: str,
state: str,
additional_kwargs: dict[str, str],
) -> str:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
if EGNYTE_LOCALHOST_OVERRIDE:
base_domain = EGNYTE_LOCALHOST_OVERRIDE
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
return (
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
f"?client_id={EGNYTE_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&scope=Egnyte.filesystem"
@@ -201,17 +171,23 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
)
@classmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls,
base_domain: str,
code: str,
additional_kwargs: dict[str, str],
) -> dict[str, Any]:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_CLIENT_SECRET:
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
# Exchange code for token
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
redirect_uri = f"{EGNYTE_LOCALHOST_OVERRIDE or base_domain}/connector/oauth/callback/egnyte"
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
data = {
"client_id": EGNYTE_CLIENT_ID,
"client_secret": EGNYTE_CLIENT_SECRET,
@@ -222,7 +198,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = _request_with_retries(
response = request_with_retries(
method="POST",
url=url,
data=data,
@@ -236,7 +212,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
token_data = response.json()
return {
"domain": EGNYTE_BASE_DOMAIN,
"domain": oauth_kwargs.egnyte_domain,
"access_token": token_data["access_token"],
}
@@ -248,7 +224,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
def _get_files_list(
self,
path: str,
) -> list[dict[str, Any]]:
) -> Generator[dict[str, Any], None, None]:
if not self.access_token or not self.domain:
raise ConnectorMissingCredentialError("Egnyte")
@@ -260,67 +236,66 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
"list_content": True,
}
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
response = _request_with_retries(
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
url_encoded_path = quote(path or "")
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
response = request_with_retries(
method="GET", url=url, headers=headers, params=params
)
if not response.ok:
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
data = response.json()
all_files: list[dict[str, Any]] = []
# Add files from current directory
all_files.extend(data.get("files", []))
# Yield files from current directory
for file in data.get("files", []):
yield file
# Recursively traverse folders
for item in data.get("folders", []):
all_files.extend(self._get_files_list(item["path"]))
for folder in data.get("folders", []):
yield from self._get_files_list(folder["path"])
return all_files
def _filter_files(
def _should_index_file(
self,
files: list[dict[str, Any]],
file: dict[str, Any],
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> list[dict[str, Any]]:
filtered_files = []
for file in files:
if file["is_folder"]:
continue
) -> bool:
"""Return True if file should be included based on filters."""
if file["is_folder"]:
return False
file_modified = _parse_last_modified(file["last_modified"])
if start_time and file_modified < start_time:
continue
if end_time and file_modified > end_time:
continue
file_modified = _parse_last_modified(file["last_modified"])
if start_time and file_modified < start_time:
return False
if end_time and file_modified > end_time:
return False
filtered_files.append(file)
return filtered_files
return True
def _process_files(
self,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> Generator[list[Document], None, None]:
files = self._get_files_list(self.folder_path)
files = self._filter_files(files, start_time, end_time)
current_batch: list[Document] = []
for file in files:
# Iterate through yielded files and filter them
for file in self._get_files_list(self.folder_path):
if not self._should_index_file(file, start_time, end_time):
logger.debug(f"Skipping file '{file['path']}'.")
continue
try:
# Set up request with streaming enabled
headers = {
"Authorization": f"Bearer {self.access_token}",
}
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
response = _request_with_retries(
url_encoded_path = quote(file["path"])
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
response = request_with_retries(
method="GET",
url=url,
headers=headers,
timeout=_TIMEOUT,
stream=True,
)

View File

@@ -5,12 +5,14 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceRequiringTenantContext
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.bookstack.connector import BookstackConnector
from onyx.connectors.clickup.connector import ClickupConnector
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.discord.connector import DiscordConnector
from onyx.connectors.discourse.connector import DiscourseConnector
from onyx.connectors.document360.connector import Document360Connector
from onyx.connectors.dropbox.connector import DropboxConnector
@@ -100,9 +102,11 @@ def identify_connector_class(
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
DocumentSource.DISCORD: DiscordConnector,
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
DocumentSource.EGNYTE: EgnyteConnector,
DocumentSource.AIRTABLE: AirtableConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -4,6 +4,7 @@ from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@@ -249,17 +250,36 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
return new_creds_dict
def _get_all_user_emails(self) -> list[str]:
admin_service = get_admin_service(self.creds, self.primary_admin_email)
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
"""
List all user emails if we are on a Google Workspace domain.
If the domain is gmail.com, or if we attempt to call the Admin SDK and
get a 404, fall back to using the single user.
"""
try:
admin_service = get_admin_service(self.creds, self.primary_admin_email)
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
except HttpError as e:
if e.resp.status == 404:
logger.warning(
"Received 404 from Admin SDK; this may indicate a personal Gmail account "
"with no Workspace domain. Falling back to single user."
)
return [self.primary_admin_email]
raise
except Exception:
raise
def _fetch_threads(
self,

View File

@@ -8,6 +8,7 @@ from typing import cast
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
@@ -20,6 +21,7 @@ from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.google_auth import get_google_creds
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
@@ -41,6 +43,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
@@ -286,13 +289,30 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
logger.info(f"Impersonating user {user_email}")
drive_service = get_drive_service(self.creds, user_email)
# validate that the user has access to the drive APIs by performing a simple
# request and checking for a 401
try:
retry_builder()(get_root_folder_id)(drive_service)
except HttpError as e:
if e.status_code == 401:
# fail gracefully, let the other impersonations continue
# one user without access shouldn't block the entire connector
logger.exception(
f"User '{user_email}' does not have access to the drive APIs."
)
return
raise
# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - include_my_drives is true
# - the current user's email is in the requested emails
if self.include_my_drives or user_email in self._requested_my_drive_emails:
logger.info(f"Getting all files in my drive as '{user_email}'")
yield from get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
@@ -303,6 +323,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
for drive_id in remaining_drive_ids:
logger.info(f"Getting files in shared drive '{drive_id}' as '{user_email}'")
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
@@ -314,6 +335,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
remaining_folders = filtered_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,
@@ -344,6 +366,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
# checkpoint - we've found all users and drives, now time to actually start
# fetching stuff
logger.info(f"Found {len(all_org_emails)} users to impersonate")
logger.debug(f"Users: {all_org_emails}")
logger.info(f"Found {len(drive_ids_to_retrieve)} drives to retrieve")
logger.debug(f"Drives: {drive_ids_to_retrieve}")
logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve")
logger.debug(f"Folders: {folder_ids_to_retrieve}")
# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
@@ -380,6 +411,13 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
drive_service = get_drive_service(self.creds, self.primary_admin_email)
if self.include_files_shared_with_me or self.include_my_drives:
logger.info(
f"Getting shared files/my drive files for OAuth "
f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
f"include_my_drives={self.include_my_drives}, "
f"include_shared_drives={self.include_shared_drives}."
f"Using '{self.primary_admin_email}' as the account."
)
yield from get_all_files_for_oauth(
service=drive_service,
include_files_shared_with_me=self.include_files_shared_with_me,
@@ -412,6 +450,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
drive_ids_to_retrieve = all_drive_ids
for drive_id in drive_ids_to_retrieve:
logger.info(
f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'"
)
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
@@ -425,6 +466,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
# that could be folders.
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
for folder_id in remaining_folders:
logger.info(
f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'"
)
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,

View File

@@ -2,6 +2,8 @@ import abc
from collections.abc import Iterator
from typing import Any
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
@@ -66,6 +68,10 @@ class SlimConnector(BaseConnector):
class OAuthConnector(BaseConnector):
class AdditionalOauthKwargs(BaseModel):
# if overridden, all fields should be str type
pass
@classmethod
@abc.abstractmethod
def oauth_id(cls) -> DocumentSource:
@@ -73,12 +79,22 @@ class OAuthConnector(BaseConnector):
@classmethod
@abc.abstractmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls,
base_domain: str,
state: str,
additional_kwargs: dict[str, str],
) -> str:
raise NotImplementedError
@classmethod
@abc.abstractmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls,
base_domain: str,
code: str,
additional_kwargs: dict[str, str],
) -> dict[str, Any]:
raise NotImplementedError

View File

@@ -7,16 +7,23 @@ from typing import cast
import requests
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import LINEAR_CLIENT_ID
from onyx.configs.app_configs import LINEAR_CLIENT_SECRET
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_oauth_callback_uri,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import OAuthConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
logger = setup_logger()
@@ -57,7 +64,7 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
)
class LinearConnector(LoadConnector, PollConnector):
class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@@ -65,8 +72,68 @@ class LinearConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self.linear_api_key: str | None = None
@classmethod
def oauth_id(cls) -> DocumentSource:
return DocumentSource.LINEAR
@classmethod
def oauth_authorization_url(
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
) -> str:
if not LINEAR_CLIENT_ID:
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
callback_uri = get_oauth_callback_uri(base_domain, DocumentSource.LINEAR.value)
return (
f"https://linear.app/oauth/authorize"
f"?client_id={LINEAR_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&response_type=code"
f"&scope=read"
f"&state={state}"
)
@classmethod
def oauth_code_to_token(
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
) -> dict[str, Any]:
data = {
"code": code,
"redirect_uri": get_oauth_callback_uri(
base_domain, DocumentSource.LINEAR.value
),
"client_id": LINEAR_CLIENT_ID,
"client_secret": LINEAR_CLIENT_SECRET,
"grant_type": "authorization_code",
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = request_with_retries(
method="POST",
url="https://api.linear.app/oauth/token",
data=data,
headers=headers,
backoff=0,
delay=0.1,
)
if not response.ok:
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
token_data = response.json()
return {
"access_token": token_data["access_token"],
}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.linear_api_key = cast(str, credentials["linear_api_key"])
if "linear_api_key" in credentials:
self.linear_api_key = cast(str, credentials["linear_api_key"])
elif "access_token" in credentials:
self.linear_api_key = "Bearer " + cast(str, credentials["access_token"])
else:
# May need to handle case in the future if the OAuth flow expires
raise ConnectorMissingCredentialError("Linear")
return None
def _process_issues(

View File

@@ -101,8 +101,11 @@ class DocumentBase(BaseModel):
source: DocumentSource | None = None
semantic_identifier: str # displayed in the UI as the main identifier for the doc
metadata: dict[str, str | list[str]]
# UTC time
doc_updated_at: datetime | None = None
chunk_count: int | None = None
# Owner, creator, etc.
primary_owners: list[BasicExpertInfo] | None = None
# Assignee, space owner, etc.

View File

@@ -1,42 +1,34 @@
import os
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.salesforce.utils import extract_dict_text
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.utils.logger import setup_logger
# TODO: this connector does not work well at large scales
# the large query against a large Salesforce instance has been reported to take 1.5 hours.
# Additionally it seems to eat up more memory over time if the connection is long running (again a scale issue).
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
MAX_QUERY_LENGTH = 10000 # max query length is 20,000 characters
ID_PREFIX = "SALESFORCE_"
logger = setup_logger()
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
@@ -44,200 +36,133 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
requested_objects: list[str] = [],
) -> None:
self.batch_size = batch_size
self.sf_client: Salesforce | None = None
self._sf_client: Salesforce | None = None
self.parent_object_list = (
[obj.capitalize() for obj in requested_objects]
if requested_objects
else DEFAULT_PARENT_OBJECT_TYPES
else _DEFAULT_PARENT_OBJECT_TYPES
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.sf_client = Salesforce(
def load_credentials(
self,
credentials: dict[str, Any],
) -> dict[str, Any] | None:
self._sf_client = Salesforce(
username=credentials["sf_username"],
password=credentials["sf_password"],
security_token=credentials["sf_security_token"],
)
return None
def _get_sf_type_object_json(self, type_name: str) -> Any:
if self.sf_client is None:
@property
def sf_client(self) -> Salesforce:
if self._sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
sf_object = SFType(
type_name, self.sf_client.session_id, self.sf_client.sf_instance
)
return sf_object.describe()
def _get_name_from_id(self, id: str) -> str:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
try:
user_object_info = self.sf_client.query(
f"SELECT Name FROM User WHERE Id = '{id}'"
)
name = user_object_info.get("Records", [{}])[0].get("Name", "Null User")
return name
except Exception:
logger.warning(f"Couldnt find name for object id: {id}")
return "Null User"
def _convert_object_instance_to_document(
self, object_dict: dict[str, Any]
) -> Document:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
salesforce_id = object_dict["Id"]
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_object_text = extract_dict_text(object_dict)
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
extracted_primary_owners = [
BasicExpertInfo(
display_name=self._get_name_from_id(object_dict["LastModifiedById"])
)
]
doc = Document(
id=onyx_salesforce_id,
sections=[Section(link=extracted_link, text=extracted_object_text)],
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,
primary_owners=extracted_primary_owners,
metadata={},
)
return doc
def _is_valid_child_object(self, child_relationship: dict) -> bool:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
if not child_relationship["childSObject"]:
return False
if not child_relationship["relationshipName"]:
return False
sf_type = child_relationship["childSObject"]
object_description = self._get_sf_type_object_json(sf_type)
if not object_description["queryable"]:
return False
try:
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
result = self.sf_client.query(query)
if result["totalSize"] == 0:
return False
except Exception as e:
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
return False
if child_relationship["field"]:
if child_relationship["field"] == "RelatedToId":
return False
else:
return False
return True
def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
object_description = self._get_sf_type_object_json(sf_type)
children_objects: list[dict] = []
for child_relationship in object_description["childRelationships"]:
if self._is_valid_child_object(child_relationship):
children_objects.append(
{
"relationship_name": child_relationship["relationshipName"],
"object_type": child_relationship["childSObject"],
}
)
return children_objects
def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
object_description = self._get_sf_type_object_json(sf_type)
fields = [
field.get("name")
for field in object_description["fields"]
if field.get("type", "base64") != "base64"
]
return fields
def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
"""
This function takes in an object_type and generates query(s) designed to grab
information associated to objects of that type.
It does that by getting all the fields of the parent object type.
Then it gets all the child objects of that object type and all the fields of
those children as well.
"""
parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)
query = f"SELECT {', '.join(parent_fields)}"
for child_object_dict in child_sf_types:
fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"
if len(query_addition) + len(query) > MAX_QUERY_LENGTH:
query += f"\n FROM {parent_sf_type}"
yield query
query = "SELECT Id" + query_addition
else:
query += query_addition
query += f"\n FROM {parent_sf_type}"
yield query
return self._sf_client
def _fetch_from_salesforce(
self,
start: datetime | None = None,
end: datetime | None = None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
init_db()
all_object_types: set[str] = set(self.parent_object_list)
doc_batch: list[Document] = []
logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
logger.debug(f"Parent object types: {self.parent_object_list}")
# This takes like 20 seconds
for parent_object_type in self.parent_object_list:
logger.debug(f"Processing: {parent_object_type}")
query_results: dict = {}
for query in self._generate_query_per_parent_type(parent_object_type):
if start is not None and end is not None:
if start and start.tzinfo is None:
start = start.replace(tzinfo=timezone.utc)
if end and end.tzinfo is None:
end = end.replace(tzinfo=timezone.utc)
query += f" WHERE LastModifiedDate > {start.isoformat()} AND LastModifiedDate < {end.isoformat()}"
query_result = self.sf_client.query_all(query)
for record_dict in query_result["records"]:
query_results.setdefault(record_dict["Id"], {}).update(record_dict)
logger.info(
f"Number of {parent_object_type} Objects processed: {len(query_results)}"
child_types = get_all_children_of_sf_type(
self.sf_client, parent_object_type
)
all_object_types.update(child_types)
logger.debug(
f"Found {len(child_types)} child types for {parent_object_type}"
)
for combined_object_dict in query_results.values():
doc_batch.append(
self._convert_object_instance_to_document(combined_object_dict)
# Always want to make sure user is grabbed for permissioning purposes
all_object_types.add("User")
logger.info(f"Found total of {len(all_object_types)} object types to fetch")
logger.debug(f"All object types: {all_object_types}")
# checkpoint - we've found all object types, now time to fetch the data
logger.info("Starting to fetch CSVs for all object types")
# This takes like 30 minutes first time and <2 minutes for updates
object_type_to_csv_path = fetch_all_csvs_in_parallel(
sf_client=self.sf_client,
object_types=all_object_types,
start=start,
end=end,
)
updated_ids: set[str] = set()
# This takes like 10 seconds
# This is for testing the rest of the functionality if data has
# already been fetched and put in sqlite
# from import onyx.connectors.salesforce.sf_db.sqlite_functions find_ids_by_type
# for object_type in self.parent_object_list:
# updated_ids.update(list(find_ids_by_type(object_type)))
# This takes 10-70 minutes first time (idk why the range is so big)
total_types = len(object_type_to_csv_path)
logger.info(f"Starting to process {total_types} object types")
for i, (object_type, csv_paths) in enumerate(
object_type_to_csv_path.items(), 1
):
logger.info(f"Processing object type {object_type} ({i}/{total_types})")
# If path is None, it means it failed to fetch the csv
if csv_paths is None:
continue
# Go through each csv path and use it to update the db
for csv_path in csv_paths:
logger.debug(f"Updating {object_type} with {csv_path}")
new_ids = update_sf_db_with_csv(
object_type=object_type,
csv_download_path=csv_path,
)
updated_ids.update(new_ids)
logger.debug(
f"Added {len(new_ids)} new/updated records for {object_type}"
)
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
yield doc_batch
logger.info(f"Found {len(updated_ids)} total updated records")
logger.info(
f"Starting to process parent objects of types: {self.parent_object_list}"
)
docs_to_yield: list[Document] = []
docs_processed = 0
# Takes 15-20 seconds per batch
for parent_type, parent_id_batch in get_affected_parent_ids_by_type(
updated_ids=list(updated_ids),
parent_types=self.parent_object_list,
):
logger.info(
f"Processing batch of {len(parent_id_batch)} {parent_type} objects"
)
for parent_id in parent_id_batch:
if not (parent_object := get_record(parent_id, parent_type)):
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
)
continue
docs_to_yield.append(
convert_sf_object_to_doc(
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
)
)
docs_processed += 1
if len(docs_to_yield) >= self.batch_size:
yield docs_to_yield
docs_to_yield = []
yield docs_to_yield
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_salesforce()
@@ -245,19 +170,13 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
return self._fetch_from_salesforce(start=start, end=end)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
doc_metadata_list: list[SlimDocument] = []
for parent_object_type in self.parent_object_list:
query = f"SELECT Id FROM {parent_object_type}"
@@ -274,9 +193,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
if __name__ == "__main__":
connector = SalesforceConnector(
requested_objects=os.environ["REQUESTED_OBJECTS"].split(",")
)
import time
connector = SalesforceConnector(requested_objects=["Account"])
connector.load_credentials(
{
@@ -285,5 +204,20 @@ if __name__ == "__main__":
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
}
)
document_batches = connector.load_from_state()
print(next(document_batches))
start_time = time.time()
doc_count = 0
section_count = 0
text_count = 0
for doc_batch in connector.load_from_state():
doc_count += len(doc_batch)
print(f"doc_count: {doc_count}")
for doc in doc_batch:
section_count += len(doc.sections)
for section in doc.sections:
text_count += len(section.text)
end_time = time.time()
print(f"Doc count: {doc_count}")
print(f"Section count: {section_count}")
print(f"Text count: {text_count}")
print(f"Time taken: {end_time - start_time}")

View File

@@ -0,0 +1,184 @@
import re
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.utils.logger import setup_logger
logger = setup_logger()
ID_PREFIX = "SALESFORCE_"
# All of these types of keys are handled by specific fields in the doc
# conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs)
_SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
def _clean_salesforce_dict(data: dict | list) -> dict | list:
"""Clean and transform Salesforce API response data by recursively:
1. Extracting records from the response if present
2. Merging attributes into the main dictionary
3. Filtering out keys matching certain patterns (Id, Date, stamp, url)
4. Removing '__c' suffix from custom field names
5. Removing None values and empty containers
Args:
data: A dictionary or list from Salesforce API response
Returns:
Cleaned dictionary or list with transformed keys and filtered values
"""
if isinstance(data, dict):
if "records" in data.keys():
data = data["records"]
if isinstance(data, dict):
if "attributes" in data.keys():
if isinstance(data["attributes"], dict):
data.update(data.pop("attributes"))
if isinstance(data, dict):
filtered_dict = {}
for key, value in data.items():
if not re.search(_SF_JSON_FILTER, key, re.IGNORECASE):
# remove the custom object indicator for display
if "__c" in key:
key = key[:-3]
if isinstance(value, (dict, list)):
filtered_value = _clean_salesforce_dict(value)
# Only add non-empty dictionaries or lists
if filtered_value:
filtered_dict[key] = filtered_value
elif value is not None:
filtered_dict[key] = value
return filtered_dict
elif isinstance(data, list):
filtered_list = []
for item in data:
if isinstance(item, (dict, list)):
filtered_item = _clean_salesforce_dict(item)
# Only add non-empty dictionaries or lists
if filtered_item:
filtered_list.append(filtered_item)
elif item is not None:
filtered_list.append(filtered_item)
return filtered_list
else:
return data
def _json_to_natural_language(data: dict | list, indent: int = 0) -> str:
"""Convert a nested dictionary or list into a human-readable string format.
Recursively traverses the data structure and formats it with:
- Key-value pairs on separate lines
- Nested structures indented for readability
- Lists and dictionaries handled with appropriate formatting
Args:
data: The dictionary or list to convert
indent: Number of spaces to indent (default: 0)
Returns:
A formatted string representation of the data structure
"""
result = []
indent_str = " " * indent
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, (dict, list)):
result.append(f"{indent_str}{key}:")
result.append(_json_to_natural_language(value, indent + 2))
else:
result.append(f"{indent_str}{key}: {value}")
elif isinstance(data, list):
for item in data:
result.append(_json_to_natural_language(item, indent + 2))
return "\n".join(result)
def _extract_dict_text(raw_dict: dict) -> str:
"""Extract text from a Salesforce API response dictionary by:
1. Cleaning the dictionary
2. Converting the cleaned dictionary to natural language
"""
processed_dict = _clean_salesforce_dict(raw_dict)
natural_language_for_dict = _json_to_natural_language(processed_dict)
return natural_language_for_dict
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
return Section(
text=_extract_dict_text(salesforce_object.data),
link=f"{base_url}/{salesforce_object.id}",
)
def _extract_primary_owners(
sf_object: SalesforceObject,
) -> list[BasicExpertInfo] | None:
object_dict = sf_object.data
if not (last_modified_by_id := object_dict.get("LastModifiedById")):
logger.warning(f"No LastModifiedById found for {sf_object.id}")
return None
if not (last_modified_by := get_record(last_modified_by_id)):
logger.warning(f"No LastModifiedBy found for {last_modified_by_id}")
return None
user_data = last_modified_by.data
expert_info = BasicExpertInfo(
first_name=user_data.get("FirstName"),
last_name=user_data.get("LastName"),
email=user_data.get("Email"),
display_name=user_data.get("Name"),
)
# Check if all fields are None
if all(
value is None
for value in [
expert_info.first_name,
expert_info.last_name,
expert_info.email,
expert_info.display_name,
]
):
logger.warning(f"No identifying information found for user {user_data}")
return None
return [expert_info]
def convert_sf_object_to_doc(
sf_object: SalesforceObject,
sf_instance: str,
) -> Document:
object_dict = sf_object.data
salesforce_id = object_dict["Id"]
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
base_url = f"https://{sf_instance}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
sections = [_extract_section(sf_object, base_url)]
for id in get_child_ids(sf_object.id):
if not (child_object := get_record(id)):
continue
sections.append(_extract_section(child_object, base_url))
doc = Document(
id=onyx_salesforce_id,
sections=sections,
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,
primary_owners=_extract_primary_owners(sf_object),
metadata={},
)
return doc

View File

@@ -0,0 +1,213 @@
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any
from pytz import UTC
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from simple_salesforce.bulk2 import SFBulk2Handler
from simple_salesforce.bulk2 import SFBulk2Type
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.salesforce.sqlite_functions import has_at_least_one_object_of_type
from onyx.connectors.salesforce.utils import get_object_type_path
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _build_time_filter_for_salesforce(
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
if start is None or end is None:
return ""
start_datetime = datetime.fromtimestamp(start, UTC)
end_datetime = datetime.fromtimestamp(end, UTC)
return (
f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
f"AND LastModifiedDate < {end_datetime.isoformat()}"
)
def _get_sf_type_object_json(sf_client: Salesforce, type_name: str) -> Any:
sf_object = SFType(type_name, sf_client.session_id, sf_client.sf_instance)
return sf_object.describe()
def _is_valid_child_object(
sf_client: Salesforce, child_relationship: dict[str, Any]
) -> bool:
if not child_relationship["childSObject"]:
return False
if not child_relationship["relationshipName"]:
return False
sf_type = child_relationship["childSObject"]
object_description = _get_sf_type_object_json(sf_client, sf_type)
if not object_description["queryable"]:
return False
if child_relationship["field"]:
if child_relationship["field"] == "RelatedToId":
return False
else:
return False
return True
def get_all_children_of_sf_type(sf_client: Salesforce, sf_type: str) -> set[str]:
object_description = _get_sf_type_object_json(sf_client, sf_type)
child_object_types = set()
for child_relationship in object_description["childRelationships"]:
if _is_valid_child_object(sf_client, child_relationship):
logger.debug(
f"Found valid child object {child_relationship['childSObject']}"
)
child_object_types.add(child_relationship["childSObject"])
return child_object_types
def _get_all_queryable_fields_of_sf_type(
sf_client: Salesforce,
sf_type: str,
) -> list[str]:
object_description = _get_sf_type_object_json(sf_client, sf_type)
fields: list[dict[str, Any]] = object_description["fields"]
valid_fields: set[str] = set()
field_names_to_remove: set[str] = set()
for field in fields:
if compound_field_name := field.get("compoundFieldName"):
# We do want to get name fields even if they are compound
if not field.get("nameField"):
field_names_to_remove.add(compound_field_name)
if field.get("type", "base64") == "base64":
continue
if field_name := field.get("name"):
valid_fields.add(field_name)
return list(valid_fields - field_names_to_remove)
def _check_if_object_type_is_empty(
sf_client: Salesforce, sf_type: str, time_filter: str
) -> bool:
"""
Use the rest api to check to make sure the query will result in a non-empty response
"""
try:
query = f"SELECT Count() FROM {sf_type}{time_filter} LIMIT 1"
result = sf_client.query(query)
if result["totalSize"] == 0:
return False
except Exception as e:
if "OPERATION_TOO_LARGE" not in str(e):
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
return False
return True
def _check_for_existing_csvs(sf_type: str) -> list[str] | None:
# Check if the csv already exists
if os.path.exists(get_object_type_path(sf_type)):
existing_csvs = [
os.path.join(get_object_type_path(sf_type), f)
for f in os.listdir(get_object_type_path(sf_type))
if f.endswith(".csv")
]
# If the csv already exists, return the path
# This is likely due to a previous run that failed
# after downloading the csv but before the data was
# written to the db
if existing_csvs:
return existing_csvs
return None
def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) -> str:
queryable_fields = _get_all_queryable_fields_of_sf_type(sf_client, sf_type)
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
return query
def _bulk_retrieve_from_salesforce(
sf_client: Salesforce,
sf_type: str,
time_filter: str,
) -> tuple[str, list[str] | None]:
if not _check_if_object_type_is_empty(sf_client, sf_type, time_filter):
return sf_type, None
if existing_csvs := _check_for_existing_csvs(sf_type):
return sf_type, existing_csvs
query = _build_bulk_query(sf_client, sf_type, time_filter)
bulk_2_handler = SFBulk2Handler(
session_id=sf_client.session_id,
bulk2_url=sf_client.bulk2_url,
proxies=sf_client.proxies,
session=sf_client.session,
)
bulk_2_type = SFBulk2Type(
object_name=sf_type,
bulk2_url=bulk_2_handler.bulk2_url,
headers=bulk_2_handler.headers,
session=bulk_2_handler.session,
)
logger.info(f"Downloading {sf_type}")
logger.info(f"Query: {query}")
try:
# This downloads the file to a file in the target path with a random name
results = bulk_2_type.download(
query=query,
path=get_object_type_path(sf_type),
max_records=1000000,
)
all_download_paths = [result["file"] for result in results]
logger.info(f"Downloaded {sf_type} to {all_download_paths}")
return sf_type, all_download_paths
except Exception as e:
logger.info(f"Failed to download salesforce csv for object type {sf_type}: {e}")
return sf_type, None
def fetch_all_csvs_in_parallel(
sf_client: Salesforce,
object_types: set[str],
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
) -> dict[str, list[str] | None]:
"""
Fetches all the csvs in parallel for the given object types
Returns a dict of (sf_type, full_download_path)
"""
time_filter = _build_time_filter_for_salesforce(start, end)
time_filter_for_each_object_type = {}
# We do this outside of the thread pool executor because this requires
# a database connection and we don't want to block the thread pool
# executor from running
for sf_type in object_types:
"""Only add time filter if there is at least one object of the type
in the database. We aren't worried about partially completed object update runs
because this occurs after we check for existing csvs which covers this case"""
if has_at_least_one_object_of_type(sf_type):
time_filter_for_each_object_type[sf_type] = time_filter
else:
time_filter_for_each_object_type[sf_type] = ""
# Run the bulk retrieve in parallel
with ThreadPoolExecutor() as executor:
results = executor.map(
lambda object_type: _bulk_retrieve_from_salesforce(
sf_client=sf_client,
sf_type=object_type,
time_filter=time_filter_for_each_object_type[object_type],
),
object_types,
)
return dict(results)

View File

@@ -0,0 +1,737 @@
import csv
import os
import shutil
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
get_affected_parent_ids_by_type,
)
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
update_sf_db_with_csv,
)
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path
_VALID_SALESFORCE_IDS = [
"001bm00000fd9Z3AAI",
"001bm00000fdYTdAAM",
"001bm00000fdYTeAAM",
"001bm00000fdYTfAAM",
"001bm00000fdYTgAAM",
"001bm00000fdYThAAM",
"001bm00000fdYTiAAM",
"001bm00000fdYTjAAM",
"001bm00000fdYTkAAM",
"001bm00000fdYTlAAM",
"001bm00000fdYTmAAM",
"001bm00000fdYTnAAM",
"001bm00000fdYToAAM",
"500bm00000XoOxtAAF",
"500bm00000XoOxuAAF",
"500bm00000XoOxvAAF",
"500bm00000XoOxwAAF",
"500bm00000XoOxxAAF",
"500bm00000XoOxyAAF",
"500bm00000XoOxzAAF",
"500bm00000XoOy0AAF",
"500bm00000XoOy1AAF",
"500bm00000XoOy2AAF",
"500bm00000XoOy3AAF",
"500bm00000XoOy4AAF",
"500bm00000XoOy5AAF",
"500bm00000XoOy6AAF",
"500bm00000XoOy7AAF",
"500bm00000XoOy8AAF",
"500bm00000XoOy9AAF",
"500bm00000XoOyAAAV",
"500bm00000XoOyBAAV",
"500bm00000XoOyCAAV",
"500bm00000XoOyDAAV",
"500bm00000XoOyEAAV",
"500bm00000XoOyFAAV",
"500bm00000XoOyGAAV",
"500bm00000XoOyHAAV",
"500bm00000XoOyIAAV",
"003bm00000EjHCjAAN",
"003bm00000EjHCkAAN",
"003bm00000EjHClAAN",
"003bm00000EjHCmAAN",
"003bm00000EjHCnAAN",
"003bm00000EjHCoAAN",
"003bm00000EjHCpAAN",
"003bm00000EjHCqAAN",
"003bm00000EjHCrAAN",
"003bm00000EjHCsAAN",
"003bm00000EjHCtAAN",
"003bm00000EjHCuAAN",
"003bm00000EjHCvAAN",
"003bm00000EjHCwAAN",
"003bm00000EjHCxAAN",
"003bm00000EjHCyAAN",
"003bm00000EjHCzAAN",
"003bm00000EjHD0AAN",
"003bm00000EjHD1AAN",
"003bm00000EjHD2AAN",
"550bm00000EXc2tAAD",
"006bm000006kyDpAAI",
"006bm000006kyDqAAI",
"006bm000006kyDrAAI",
"006bm000006kyDsAAI",
"006bm000006kyDtAAI",
"006bm000006kyDuAAI",
"006bm000006kyDvAAI",
"006bm000006kyDwAAI",
"006bm000006kyDxAAI",
"006bm000006kyDyAAI",
"006bm000006kyDzAAI",
"006bm000006kyE0AAI",
"006bm000006kyE1AAI",
"006bm000006kyE2AAI",
"006bm000006kyE3AAI",
"006bm000006kyE4AAI",
"006bm000006kyE5AAI",
"006bm000006kyE6AAI",
"006bm000006kyE7AAI",
"006bm000006kyE8AAI",
"006bm000006kyE9AAI",
"006bm000006kyEAAAY",
"006bm000006kyEBAAY",
"006bm000006kyECAAY",
"006bm000006kyEDAAY",
"006bm000006kyEEAAY",
"006bm000006kyEFAAY",
"006bm000006kyEGAAY",
"006bm000006kyEHAAY",
"006bm000006kyEIAAY",
"006bm000006kyEJAAY",
"005bm000009zy0TAAQ",
"005bm000009zy25AAA",
"005bm000009zy26AAA",
"005bm000009zy28AAA",
"005bm000009zy29AAA",
"005bm000009zy2AAAQ",
"005bm000009zy2BAAQ",
]
def clear_sf_db() -> None:
"""
Clears the SF DB by deleting all files in the data directory.
"""
shutil.rmtree(BASE_DATA_PATH)
def create_csv_file(
object_type: str, records: list[dict], filename: str = "test_data.csv"
) -> None:
"""
Creates a CSV file for the given object type and records.
Args:
object_type: The Salesforce object type (e.g. "Account", "Contact")
records: List of dictionaries containing the record data
filename: Name of the CSV file to create (default: test_data.csv)
"""
if not records:
return
# Get all unique fields from records
fields: set[str] = set()
for record in records:
fields.update(record.keys())
fields = set(sorted(list(fields))) # Sort for consistent order
# Create CSV file
csv_path = os.path.join(get_object_type_path(object_type), filename)
with open(csv_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fields)
writer.writeheader()
for record in records:
writer.writerow(record)
# Update the database with the CSV
update_sf_db_with_csv(object_type, csv_path)
def create_csv_with_example_data() -> None:
"""
Creates CSV files with example data, organized by object type.
"""
example_data: dict[str, list[dict]] = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
{
"Id": _VALID_SALESFORCE_IDS[3],
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
{
"Id": _VALID_SALESFORCE_IDS[4],
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
{
"Id": _VALID_SALESFORCE_IDS[5],
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
{
"Id": _VALID_SALESFORCE_IDS[6],
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
{
"Id": _VALID_SALESFORCE_IDS[7],
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"FirstName": "John",
"LastName": "Doe",
"Email": "john.doe@acme.com",
"Title": "CEO",
},
{
"Id": _VALID_SALESFORCE_IDS[41],
"FirstName": "Jane",
"LastName": "Smith",
"Email": "jane.smith@acme.com",
"Title": "CTO",
},
{
"Id": _VALID_SALESFORCE_IDS[42],
"FirstName": "Bob",
"LastName": "Johnson",
"Email": "bob.j@globex.com",
"Title": "Sales Director",
},
{
"Id": _VALID_SALESFORCE_IDS[43],
"FirstName": "Sarah",
"LastName": "Chen",
"Email": "sarah.chen@techcorp.com",
"Title": "Product Manager",
"Phone": "415-555-0101",
},
{
"Id": _VALID_SALESFORCE_IDS[44],
"FirstName": "Michael",
"LastName": "Rodriguez",
"Email": "m.rodriguez@biomed.com",
"Title": "Research Director",
"Phone": "617-555-0202",
},
{
"Id": _VALID_SALESFORCE_IDS[45],
"FirstName": "Emily",
"LastName": "Green",
"Email": "emily.g@greenenergy.com",
"Title": "Sustainability Lead",
"Phone": "503-555-0303",
},
{
"Id": _VALID_SALESFORCE_IDS[46],
"FirstName": "David",
"LastName": "Kim",
"Email": "david.kim@dataflow.com",
"Title": "Data Scientist",
"Phone": "206-555-0404",
},
{
"Id": _VALID_SALESFORCE_IDS[47],
"FirstName": "Rachel",
"LastName": "Taylor",
"Email": "r.taylor@cloudnine.com",
"Title": "Cloud Architect",
"Phone": "303-555-0505",
},
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"Name": "Acme Server Upgrade",
"Amount": 50000,
"Stage": "Prospecting",
"CloseDate": "2024-06-30",
},
{
"Id": _VALID_SALESFORCE_IDS[63],
"Name": "Globex Manufacturing Line",
"Amount": 150000,
"Stage": "Negotiation",
"CloseDate": "2024-03-15",
},
{
"Id": _VALID_SALESFORCE_IDS[64],
"Name": "Initech Software License",
"Amount": 75000,
"Stage": "Closed Won",
"CloseDate": "2024-01-30",
},
{
"Id": _VALID_SALESFORCE_IDS[65],
"Name": "TechCorp AI Implementation",
"Amount": 250000,
"Stage": "Needs Analysis",
"CloseDate": "2024-08-15",
"Probability": 60,
},
{
"Id": _VALID_SALESFORCE_IDS[66],
"Name": "BioMed Lab Equipment",
"Amount": 500000,
"Stage": "Value Proposition",
"CloseDate": "2024-09-30",
"Probability": 75,
},
{
"Id": _VALID_SALESFORCE_IDS[67],
"Name": "Green Energy Solar Project",
"Amount": 750000,
"Stage": "Proposal",
"CloseDate": "2024-07-15",
"Probability": 80,
},
{
"Id": _VALID_SALESFORCE_IDS[68],
"Name": "DataFlow Analytics Platform",
"Amount": 180000,
"Stage": "Negotiation",
"CloseDate": "2024-05-30",
"Probability": 90,
},
{
"Id": _VALID_SALESFORCE_IDS[69],
"Name": "Cloud Nine Infrastructure",
"Amount": 300000,
"Stage": "Qualification",
"CloseDate": "2024-10-15",
"Probability": 40,
},
],
}
# Create CSV files for each object type
for object_type, records in example_data.items():
create_csv_file(object_type, records)
def test_query() -> None:
"""
Tests querying functionality by verifying:
1. All expected Account IDs are found
2. Each Account's data matches what was inserted
"""
# Expected test data for verification
expected_accounts: dict[str, dict[str, str | int]] = {
_VALID_SALESFORCE_IDS[0]: {
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
_VALID_SALESFORCE_IDS[1]: {
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
_VALID_SALESFORCE_IDS[2]: {
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
_VALID_SALESFORCE_IDS[3]: {
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
_VALID_SALESFORCE_IDS[4]: {
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
_VALID_SALESFORCE_IDS[5]: {
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
_VALID_SALESFORCE_IDS[6]: {
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
_VALID_SALESFORCE_IDS[7]: {
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
}
# Get all Account IDs
account_ids = find_ids_by_type("Account")
# Verify we found all expected accounts
assert len(account_ids) == len(
expected_accounts
), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
assert set(account_ids) == set(
expected_accounts.keys()
), "Found account IDs don't match expected IDs"
# Verify each account's data
for acc_id in account_ids:
combined = get_record(acc_id)
assert combined is not None, f"Could not find account {acc_id}"
expected = expected_accounts[acc_id]
# Verify account data matches
for key, value in expected.items():
value = str(value)
assert (
combined.data[key] == value
), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"
print("All query tests passed successfully!")
def test_upsert() -> None:
"""
Tests upsert functionality by:
1. Updating an existing account
2. Creating a new account
3. Verifying both operations were successful
"""
# Create CSV for updating an existing account and adding a new one
update_data: list[dict[str, str | int]] = [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc. Updated",
"BillingCity": "New York",
"Industry": "Technology",
"Description": "Updated company info",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "New Company Inc.",
"BillingCity": "Miami",
"Industry": "Finance",
"AnnualRevenue": 1000000,
},
]
create_csv_file("Account", update_data, "update_data.csv")
# Verify the update worked
updated_record = get_record(_VALID_SALESFORCE_IDS[0])
assert updated_record is not None, "Updated record not found"
assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
assert (
updated_record.data["Description"] == "Updated company info"
), "Description not added"
# Verify the new record was created
new_record = get_record(_VALID_SALESFORCE_IDS[2])
assert new_record is not None, "New record not found"
assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"
print("All upsert tests passed successfully!")
def test_relationships() -> None:
"""
Tests relationship shelf updates and queries by:
1. Creating test data with relationships
2. Verifying the relationships are correctly stored
3. Testing relationship queries
"""
# Create test data for each object type
test_data: dict[str, list[dict[str, str | int]]] = {
"Case": [
{
"Id": _VALID_SALESFORCE_IDS[13],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 1",
},
{
"Id": _VALID_SALESFORCE_IDS[14],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 2",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[48],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Name": "Test Opportunity",
"Amount": 100000,
}
],
}
# Create and update CSV files for each object type
for object_type, records in test_data.items():
create_csv_file(object_type, records, "relationship_test.csv")
# Test relationship queries
# All these objects should be children of Acme Inc.
child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
assert (
_VALID_SALESFORCE_IDS[62] in child_ids
), "Opportunity not found in relationship"
# Test querying relationships for a different account (should be empty)
other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert (
len(other_account_children) == 0
), "Expected no children for different account"
print("All relationship tests passed successfully!")
def test_account_with_children() -> None:
"""
Tests querying all accounts and retrieving their child objects.
This test verifies that:
1. All accounts can be retrieved
2. Child objects are correctly linked
3. Child object data is complete and accurate
"""
# First get all account IDs
account_ids = find_ids_by_type("Account")
assert len(account_ids) > 0, "No accounts found"
# For each account, get its children and verify the data
for account_id in account_ids:
account = get_record(account_id)
assert account is not None, f"Could not find account {account_id}"
# Get all child objects
child_ids = get_child_ids(account_id)
# For Acme Inc., verify specific relationships
if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc.
assert (
len(child_ids) == 4
), f"Expected 4 children for Acme Inc., found {len(child_ids)}"
# Get all child records
child_records = []
for child_id in child_ids:
child_record = get_record(child_id)
if child_record is not None:
child_records.append(child_record)
# Verify Cases
cases = [r for r in child_records if r.type == "Case"]
assert (
len(cases) == 2
), f"Expected 2 cases for Acme Inc., found {len(cases)}"
case_subjects = {case.data["Subject"] for case in cases}
assert "Test Case 1" in case_subjects, "Test Case 1 not found"
assert "Test Case 2" in case_subjects, "Test Case 2 not found"
# Verify Contacts
contacts = [r for r in child_records if r.type == "Contact"]
assert (
len(contacts) == 1
), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
contact = contacts[0]
assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
assert contact.data["LastName"] == "Contact", "Contact last name mismatch"
# Verify Opportunities
opportunities = [r for r in child_records if r.type == "Opportunity"]
assert (
len(opportunities) == 1
), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
opportunity = opportunities[0]
assert (
opportunity.data["Name"] == "Test Opportunity"
), "Opportunity name mismatch"
assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"
print("All account with children tests passed successfully!")
def test_relationship_updates() -> None:
"""
Tests that relationships are properly updated when a child object's parent reference changes.
This test verifies:
1. Initial relationship is created correctly
2. When parent reference is updated, old relationship is removed
3. New relationship is created correctly
"""
# Create initial test data - Contact linked to Acme Inc.
initial_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
]
create_csv_file("Contact", initial_contact, "initial_contact.csv")
# Verify initial relationship
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] in acme_children
), "Initial relationship not created"
# Update contact to be linked to Globex Corp instead
updated_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[1],
"FirstName": "Test",
"LastName": "Contact",
}
]
create_csv_file("Contact", updated_contact, "updated_contact.csv")
# Verify old relationship is removed
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] not in acme_children
), "Old relationship not removed"
# Verify new relationship is created
globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"
print("All relationship update tests passed successfully!")
def test_get_affected_parent_ids() -> None:
"""
Tests get_affected_parent_ids functionality by verifying:
1. IDs that are directly in the parent_types list are included
2. IDs that have children in the updated_ids list are included
3. IDs that are neither of the above are not included
"""
# Create test data with relationships
test_data = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Parent Account 1",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Parent Account 2",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Not Affected Account",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Child",
"LastName": "Contact",
}
],
}
# Create and update CSV files for test data
for object_type, records in test_data.items():
create_csv_file(object_type, records)
# Test Case 1: Account directly in updated_ids and parent_types
updated_ids = {_VALID_SALESFORCE_IDS[1]} # Parent Account 2
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
# Test Case 2: Account with child in updated_ids
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids
), "Parent of updated child not included"
# Test Case 3: Both direct and indirect affects
updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]} # Both cases
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
assert (
_VALID_SALESFORCE_IDS[2] not in affected_ids
), "Unaffected ID incorrectly included"
# Test Case 4: No matches
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
parent_types = ["Opportunity"] # Wrong type
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert len(affected_ids) == 0, "Should return empty list when no matches"
print("All get_affected_parent_ids tests passed successfully!")
def main_build() -> None:
clear_sf_db()
create_csv_with_example_data()
test_query()
test_upsert()
test_relationships()
test_account_with_children()
test_relationship_updates()
test_get_affected_parent_ids()
if __name__ == "__main__":
main_build()

View File

@@ -0,0 +1,209 @@
import csv
import shelve
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
get_child_to_parent_shelf_path,
)
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_id_type_shelf_path
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_object_shelf_path
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
get_parent_to_child_shelf_path,
)
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _update_relationship_shelves(
child_id: str,
parent_ids: set[str],
) -> None:
"""Update the relationship shelf when a record is updated."""
try:
# Convert child_id to string once
str_child_id = str(child_id)
# First update child to parent mapping
with shelve.open(
get_child_to_parent_shelf_path(),
flag="c",
protocol=None,
writeback=True,
) as child_to_parent_db:
old_parent_ids = set(child_to_parent_db.get(str_child_id, []))
child_to_parent_db[str_child_id] = list(parent_ids)
# Calculate differences outside the next context manager
parent_ids_to_remove = old_parent_ids - parent_ids
parent_ids_to_add = parent_ids - old_parent_ids
# Only sync once at the end
child_to_parent_db.sync()
# Then update parent to child mapping in a single transaction
if not parent_ids_to_remove and not parent_ids_to_add:
return
with shelve.open(
get_parent_to_child_shelf_path(),
flag="c",
protocol=None,
writeback=True,
) as parent_to_child_db:
# Process all removals first
for parent_id in parent_ids_to_remove:
str_parent_id = str(parent_id)
existing_children = set(parent_to_child_db.get(str_parent_id, []))
if str_child_id in existing_children:
existing_children.remove(str_child_id)
parent_to_child_db[str_parent_id] = list(existing_children)
# Then process all additions
for parent_id in parent_ids_to_add:
str_parent_id = str(parent_id)
existing_children = set(parent_to_child_db.get(str_parent_id, []))
existing_children.add(str_child_id)
parent_to_child_db[str_parent_id] = list(existing_children)
# Single sync at the end
parent_to_child_db.sync()
except Exception as e:
logger.error(f"Error updating relationship shelves: {e}")
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
raise
def get_child_ids(parent_id: str) -> set[str]:
"""Get all child IDs for a given parent ID.
Args:
parent_id: The ID of the parent object
Returns:
A set of child object IDs
"""
with shelve.open(get_parent_to_child_shelf_path()) as parent_to_child_db:
return set(parent_to_child_db.get(parent_id, []))
def update_sf_db_with_csv(
object_type: str,
csv_download_path: str,
) -> list[str]:
"""Update the SF DB with a CSV file using shelve storage."""
updated_ids = []
shelf_path = get_object_shelf_path(object_type)
# First read the CSV to get all the data
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
id = row["Id"]
parent_ids = set()
field_to_remove: set[str] = set()
# Update relationship shelves for any parent references
for field, value in row.items():
if validate_salesforce_id(value) and field != "Id":
parent_ids.add(value)
field_to_remove.add(field)
if not value:
field_to_remove.add(field)
_update_relationship_shelves(id, parent_ids)
for field in field_to_remove:
# We use this to extract the Primary Owner later
if field != "LastModifiedById":
del row[field]
# Update the main object shelf
with shelve.open(shelf_path) as object_type_db:
object_type_db[id] = row
# Update the ID-to-type mapping shelf
with shelve.open(get_id_type_shelf_path()) as id_type_db:
id_type_db[id] = object_type
updated_ids.append(id)
# os.remove(csv_download_path)
return updated_ids
def get_type_from_id(object_id: str) -> str | None:
"""Get the type of an object from its ID."""
# Look up the object type from the ID-to-type mapping
with shelve.open(get_id_type_shelf_path()) as id_type_db:
if object_id not in id_type_db:
logger.warning(f"Object ID {object_id} not found in ID-to-type mapping")
return None
return id_type_db[object_id]
def get_record(
object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
"""
Retrieve the record and return it as a SalesforceObject.
The object type will be looked up from the ID-to-type mapping shelf.
"""
if object_type is None:
if not (object_type := get_type_from_id(object_id)):
return None
shelf_path = get_object_shelf_path(object_type)
with shelve.open(shelf_path) as db:
if object_id not in db:
logger.warning(f"Object ID {object_id} not found in {shelf_path}")
return None
data = db[object_id]
return SalesforceObject(
id=object_id,
type=object_type,
data=data,
)
def find_ids_by_type(object_type: str) -> list[str]:
"""
Find all object IDs for rows of the specified type.
"""
shelf_path = get_object_shelf_path(object_type)
try:
with shelve.open(shelf_path) as db:
return list(db.keys())
except FileNotFoundError:
return []
def get_affected_parent_ids_by_type(
updated_ids: set[str], parent_types: list[str]
) -> dict[str, set[str]]:
"""Get IDs of objects that are of the specified parent types and are either in the updated_ids
or have children in the updated_ids.
Args:
updated_ids: List of IDs that were updated
parent_types: List of object types to filter by
Returns:
A dictionary of IDs that match the criteria
"""
affected_ids_by_type: dict[str, set[str]] = {}
# Check each updated ID
for updated_id in updated_ids:
# Add the ID itself if it's of a parent type
updated_type = get_type_from_id(updated_id)
if updated_type in parent_types:
affected_ids_by_type.setdefault(updated_type, set()).add(updated_id)
continue
# Get parents of this ID and add them if they're of a parent type
with shelve.open(get_child_to_parent_shelf_path()) as child_to_parent_db:
parent_ids = child_to_parent_db.get(updated_id, [])
for parent_id in parent_ids:
parent_type = get_type_from_id(parent_id)
if parent_type in parent_types:
affected_ids_by_type.setdefault(parent_type, set()).add(parent_id)
return affected_ids_by_type

View File

@@ -0,0 +1,29 @@
import os
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path
def get_object_shelf_path(object_type: str) -> str:
"""Get the path to the shelf file for a specific object type."""
base_path = get_object_type_path(object_type)
os.makedirs(base_path, exist_ok=True)
return os.path.join(base_path, "data.shelf")
def get_id_type_shelf_path() -> str:
"""Get the path to the ID-to-type mapping shelf."""
os.makedirs(BASE_DATA_PATH, exist_ok=True)
return os.path.join(BASE_DATA_PATH, "id_type_mapping.shelf.4g")
def get_parent_to_child_shelf_path() -> str:
"""Get the path to the parent-to-child mapping shelf."""
os.makedirs(BASE_DATA_PATH, exist_ok=True)
return os.path.join(BASE_DATA_PATH, "parent_to_child_mapping.shelf.4g")
def get_child_to_parent_shelf_path() -> str:
"""Get the path to the child-to-parent mapping shelf."""
os.makedirs(BASE_DATA_PATH, exist_ok=True)
return os.path.join(BASE_DATA_PATH, "child_to_parent_mapping.shelf.4g")

View File

@@ -0,0 +1,737 @@
import csv
import os
import shutil
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
get_affected_parent_ids_by_type,
)
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
update_sf_db_with_csv,
)
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path
_VALID_SALESFORCE_IDS = [
"001bm00000fd9Z3AAI",
"001bm00000fdYTdAAM",
"001bm00000fdYTeAAM",
"001bm00000fdYTfAAM",
"001bm00000fdYTgAAM",
"001bm00000fdYThAAM",
"001bm00000fdYTiAAM",
"001bm00000fdYTjAAM",
"001bm00000fdYTkAAM",
"001bm00000fdYTlAAM",
"001bm00000fdYTmAAM",
"001bm00000fdYTnAAM",
"001bm00000fdYToAAM",
"500bm00000XoOxtAAF",
"500bm00000XoOxuAAF",
"500bm00000XoOxvAAF",
"500bm00000XoOxwAAF",
"500bm00000XoOxxAAF",
"500bm00000XoOxyAAF",
"500bm00000XoOxzAAF",
"500bm00000XoOy0AAF",
"500bm00000XoOy1AAF",
"500bm00000XoOy2AAF",
"500bm00000XoOy3AAF",
"500bm00000XoOy4AAF",
"500bm00000XoOy5AAF",
"500bm00000XoOy6AAF",
"500bm00000XoOy7AAF",
"500bm00000XoOy8AAF",
"500bm00000XoOy9AAF",
"500bm00000XoOyAAAV",
"500bm00000XoOyBAAV",
"500bm00000XoOyCAAV",
"500bm00000XoOyDAAV",
"500bm00000XoOyEAAV",
"500bm00000XoOyFAAV",
"500bm00000XoOyGAAV",
"500bm00000XoOyHAAV",
"500bm00000XoOyIAAV",
"003bm00000EjHCjAAN",
"003bm00000EjHCkAAN",
"003bm00000EjHClAAN",
"003bm00000EjHCmAAN",
"003bm00000EjHCnAAN",
"003bm00000EjHCoAAN",
"003bm00000EjHCpAAN",
"003bm00000EjHCqAAN",
"003bm00000EjHCrAAN",
"003bm00000EjHCsAAN",
"003bm00000EjHCtAAN",
"003bm00000EjHCuAAN",
"003bm00000EjHCvAAN",
"003bm00000EjHCwAAN",
"003bm00000EjHCxAAN",
"003bm00000EjHCyAAN",
"003bm00000EjHCzAAN",
"003bm00000EjHD0AAN",
"003bm00000EjHD1AAN",
"003bm00000EjHD2AAN",
"550bm00000EXc2tAAD",
"006bm000006kyDpAAI",
"006bm000006kyDqAAI",
"006bm000006kyDrAAI",
"006bm000006kyDsAAI",
"006bm000006kyDtAAI",
"006bm000006kyDuAAI",
"006bm000006kyDvAAI",
"006bm000006kyDwAAI",
"006bm000006kyDxAAI",
"006bm000006kyDyAAI",
"006bm000006kyDzAAI",
"006bm000006kyE0AAI",
"006bm000006kyE1AAI",
"006bm000006kyE2AAI",
"006bm000006kyE3AAI",
"006bm000006kyE4AAI",
"006bm000006kyE5AAI",
"006bm000006kyE6AAI",
"006bm000006kyE7AAI",
"006bm000006kyE8AAI",
"006bm000006kyE9AAI",
"006bm000006kyEAAAY",
"006bm000006kyEBAAY",
"006bm000006kyECAAY",
"006bm000006kyEDAAY",
"006bm000006kyEEAAY",
"006bm000006kyEFAAY",
"006bm000006kyEGAAY",
"006bm000006kyEHAAY",
"006bm000006kyEIAAY",
"006bm000006kyEJAAY",
"005bm000009zy0TAAQ",
"005bm000009zy25AAA",
"005bm000009zy26AAA",
"005bm000009zy28AAA",
"005bm000009zy29AAA",
"005bm000009zy2AAAQ",
"005bm000009zy2BAAQ",
]
def clear_sf_db() -> None:
"""
Clears the SF DB by deleting all files in the data directory.
"""
shutil.rmtree(BASE_DATA_PATH)
def create_csv_file(
object_type: str, records: list[dict], filename: str = "test_data.csv"
) -> None:
"""
Creates a CSV file for the given object type and records.
Args:
object_type: The Salesforce object type (e.g. "Account", "Contact")
records: List of dictionaries containing the record data
filename: Name of the CSV file to create (default: test_data.csv)
"""
if not records:
return
# Get all unique fields from records
fields: set[str] = set()
for record in records:
fields.update(record.keys())
fields = set(sorted(list(fields))) # Sort for consistent order
# Create CSV file
csv_path = os.path.join(get_object_type_path(object_type), filename)
with open(csv_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fields)
writer.writeheader()
for record in records:
writer.writerow(record)
# Update the database with the CSV
update_sf_db_with_csv(object_type, csv_path)
def create_csv_with_example_data() -> None:
"""
Creates CSV files with example data, organized by object type.
"""
example_data: dict[str, list[dict]] = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
{
"Id": _VALID_SALESFORCE_IDS[3],
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
{
"Id": _VALID_SALESFORCE_IDS[4],
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
{
"Id": _VALID_SALESFORCE_IDS[5],
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
{
"Id": _VALID_SALESFORCE_IDS[6],
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
{
"Id": _VALID_SALESFORCE_IDS[7],
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"FirstName": "John",
"LastName": "Doe",
"Email": "john.doe@acme.com",
"Title": "CEO",
},
{
"Id": _VALID_SALESFORCE_IDS[41],
"FirstName": "Jane",
"LastName": "Smith",
"Email": "jane.smith@acme.com",
"Title": "CTO",
},
{
"Id": _VALID_SALESFORCE_IDS[42],
"FirstName": "Bob",
"LastName": "Johnson",
"Email": "bob.j@globex.com",
"Title": "Sales Director",
},
{
"Id": _VALID_SALESFORCE_IDS[43],
"FirstName": "Sarah",
"LastName": "Chen",
"Email": "sarah.chen@techcorp.com",
"Title": "Product Manager",
"Phone": "415-555-0101",
},
{
"Id": _VALID_SALESFORCE_IDS[44],
"FirstName": "Michael",
"LastName": "Rodriguez",
"Email": "m.rodriguez@biomed.com",
"Title": "Research Director",
"Phone": "617-555-0202",
},
{
"Id": _VALID_SALESFORCE_IDS[45],
"FirstName": "Emily",
"LastName": "Green",
"Email": "emily.g@greenenergy.com",
"Title": "Sustainability Lead",
"Phone": "503-555-0303",
},
{
"Id": _VALID_SALESFORCE_IDS[46],
"FirstName": "David",
"LastName": "Kim",
"Email": "david.kim@dataflow.com",
"Title": "Data Scientist",
"Phone": "206-555-0404",
},
{
"Id": _VALID_SALESFORCE_IDS[47],
"FirstName": "Rachel",
"LastName": "Taylor",
"Email": "r.taylor@cloudnine.com",
"Title": "Cloud Architect",
"Phone": "303-555-0505",
},
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"Name": "Acme Server Upgrade",
"Amount": 50000,
"Stage": "Prospecting",
"CloseDate": "2024-06-30",
},
{
"Id": _VALID_SALESFORCE_IDS[63],
"Name": "Globex Manufacturing Line",
"Amount": 150000,
"Stage": "Negotiation",
"CloseDate": "2024-03-15",
},
{
"Id": _VALID_SALESFORCE_IDS[64],
"Name": "Initech Software License",
"Amount": 75000,
"Stage": "Closed Won",
"CloseDate": "2024-01-30",
},
{
"Id": _VALID_SALESFORCE_IDS[65],
"Name": "TechCorp AI Implementation",
"Amount": 250000,
"Stage": "Needs Analysis",
"CloseDate": "2024-08-15",
"Probability": 60,
},
{
"Id": _VALID_SALESFORCE_IDS[66],
"Name": "BioMed Lab Equipment",
"Amount": 500000,
"Stage": "Value Proposition",
"CloseDate": "2024-09-30",
"Probability": 75,
},
{
"Id": _VALID_SALESFORCE_IDS[67],
"Name": "Green Energy Solar Project",
"Amount": 750000,
"Stage": "Proposal",
"CloseDate": "2024-07-15",
"Probability": 80,
},
{
"Id": _VALID_SALESFORCE_IDS[68],
"Name": "DataFlow Analytics Platform",
"Amount": 180000,
"Stage": "Negotiation",
"CloseDate": "2024-05-30",
"Probability": 90,
},
{
"Id": _VALID_SALESFORCE_IDS[69],
"Name": "Cloud Nine Infrastructure",
"Amount": 300000,
"Stage": "Qualification",
"CloseDate": "2024-10-15",
"Probability": 40,
},
],
}
# Create CSV files for each object type
for object_type, records in example_data.items():
create_csv_file(object_type, records)
def test_query() -> None:
"""
Tests querying functionality by verifying:
1. All expected Account IDs are found
2. Each Account's data matches what was inserted
"""
# Expected test data for verification
expected_accounts: dict[str, dict[str, str | int]] = {
_VALID_SALESFORCE_IDS[0]: {
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
_VALID_SALESFORCE_IDS[1]: {
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
_VALID_SALESFORCE_IDS[2]: {
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
_VALID_SALESFORCE_IDS[3]: {
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
_VALID_SALESFORCE_IDS[4]: {
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
_VALID_SALESFORCE_IDS[5]: {
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
_VALID_SALESFORCE_IDS[6]: {
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
_VALID_SALESFORCE_IDS[7]: {
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
}
# Get all Account IDs
account_ids = find_ids_by_type("Account")
# Verify we found all expected accounts
assert len(account_ids) == len(
expected_accounts
), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
assert set(account_ids) == set(
expected_accounts.keys()
), "Found account IDs don't match expected IDs"
# Verify each account's data
for acc_id in account_ids:
combined = get_record(acc_id)
assert combined is not None, f"Could not find account {acc_id}"
expected = expected_accounts[acc_id]
# Verify account data matches
for key, value in expected.items():
value = str(value)
assert (
combined.data[key] == value
), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"
print("All query tests passed successfully!")
def test_upsert() -> None:
"""
Tests upsert functionality by:
1. Updating an existing account
2. Creating a new account
3. Verifying both operations were successful
"""
# Create CSV for updating an existing account and adding a new one
update_data: list[dict[str, str | int]] = [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc. Updated",
"BillingCity": "New York",
"Industry": "Technology",
"Description": "Updated company info",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "New Company Inc.",
"BillingCity": "Miami",
"Industry": "Finance",
"AnnualRevenue": 1000000,
},
]
create_csv_file("Account", update_data, "update_data.csv")
# Verify the update worked
updated_record = get_record(_VALID_SALESFORCE_IDS[0])
assert updated_record is not None, "Updated record not found"
assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
assert (
updated_record.data["Description"] == "Updated company info"
), "Description not added"
# Verify the new record was created
new_record = get_record(_VALID_SALESFORCE_IDS[2])
assert new_record is not None, "New record not found"
assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"
print("All upsert tests passed successfully!")
def test_relationships() -> None:
"""
Tests relationship shelf updates and queries by:
1. Creating test data with relationships
2. Verifying the relationships are correctly stored
3. Testing relationship queries
"""
# Create test data for each object type
test_data: dict[str, list[dict[str, str | int]]] = {
"Case": [
{
"Id": _VALID_SALESFORCE_IDS[13],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 1",
},
{
"Id": _VALID_SALESFORCE_IDS[14],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 2",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[48],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Name": "Test Opportunity",
"Amount": 100000,
}
],
}
# Create and update CSV files for each object type
for object_type, records in test_data.items():
create_csv_file(object_type, records, "relationship_test.csv")
# Test relationship queries
# All these objects should be children of Acme Inc.
child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
assert (
_VALID_SALESFORCE_IDS[62] in child_ids
), "Opportunity not found in relationship"
# Test querying relationships for a different account (should be empty)
other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert (
len(other_account_children) == 0
), "Expected no children for different account"
print("All relationship tests passed successfully!")
def test_account_with_children() -> None:
"""
Tests querying all accounts and retrieving their child objects.
This test verifies that:
1. All accounts can be retrieved
2. Child objects are correctly linked
3. Child object data is complete and accurate
"""
# First get all account IDs
account_ids = find_ids_by_type("Account")
assert len(account_ids) > 0, "No accounts found"
# For each account, get its children and verify the data
for account_id in account_ids:
account = get_record(account_id)
assert account is not None, f"Could not find account {account_id}"
# Get all child objects
child_ids = get_child_ids(account_id)
# For Acme Inc., verify specific relationships
if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc.
assert (
len(child_ids) == 4
), f"Expected 4 children for Acme Inc., found {len(child_ids)}"
# Get all child records
child_records = []
for child_id in child_ids:
child_record = get_record(child_id)
if child_record is not None:
child_records.append(child_record)
# Verify Cases
cases = [r for r in child_records if r.type == "Case"]
assert (
len(cases) == 2
), f"Expected 2 cases for Acme Inc., found {len(cases)}"
case_subjects = {case.data["Subject"] for case in cases}
assert "Test Case 1" in case_subjects, "Test Case 1 not found"
assert "Test Case 2" in case_subjects, "Test Case 2 not found"
# Verify Contacts
contacts = [r for r in child_records if r.type == "Contact"]
assert (
len(contacts) == 1
), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
contact = contacts[0]
assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
assert contact.data["LastName"] == "Contact", "Contact last name mismatch"
# Verify Opportunities
opportunities = [r for r in child_records if r.type == "Opportunity"]
assert (
len(opportunities) == 1
), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
opportunity = opportunities[0]
assert (
opportunity.data["Name"] == "Test Opportunity"
), "Opportunity name mismatch"
assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"
print("All account with children tests passed successfully!")
def test_relationship_updates() -> None:
"""
Tests that relationships are properly updated when a child object's parent reference changes.
This test verifies:
1. Initial relationship is created correctly
2. When parent reference is updated, old relationship is removed
3. New relationship is created correctly
"""
# Create initial test data - Contact linked to Acme Inc.
initial_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
]
create_csv_file("Contact", initial_contact, "initial_contact.csv")
# Verify initial relationship
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] in acme_children
), "Initial relationship not created"
# Update contact to be linked to Globex Corp instead
updated_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[1],
"FirstName": "Test",
"LastName": "Contact",
}
]
create_csv_file("Contact", updated_contact, "updated_contact.csv")
# Verify old relationship is removed
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] not in acme_children
), "Old relationship not removed"
# Verify new relationship is created
globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"
print("All relationship update tests passed successfully!")
def test_get_affected_parent_ids() -> None:
"""
Tests get_affected_parent_ids functionality by verifying:
1. IDs that are directly in the parent_types list are included
2. IDs that have children in the updated_ids list are included
3. IDs that are neither of the above are not included
"""
# Create test data with relationships
test_data = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Parent Account 1",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Parent Account 2",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Not Affected Account",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Child",
"LastName": "Contact",
}
],
}
# Create and update CSV files for test data
for object_type, records in test_data.items():
create_csv_file(object_type, records)
# Test Case 1: Account directly in updated_ids and parent_types
updated_ids = {_VALID_SALESFORCE_IDS[1]} # Parent Account 2
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
# Test Case 2: Account with child in updated_ids
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids
), "Parent of updated child not included"
# Test Case 3: Both direct and indirect affects
updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]} # Both cases
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
assert (
_VALID_SALESFORCE_IDS[2] not in affected_ids
), "Unaffected ID incorrectly included"
# Test Case 4: No matches
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
parent_types = ["Opportunity"] # Wrong type
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert len(affected_ids) == 0, "Should return empty list when no matches"
print("All get_affected_parent_ids tests passed successfully!")
def main_build() -> None:
clear_sf_db()
create_csv_with_example_data()
test_query()
test_upsert()
test_relationships()
test_account_with_children()
test_relationship_updates()
test_get_affected_parent_ids()
if __name__ == "__main__":
main_build()

View File

@@ -0,0 +1,475 @@
import csv
import json
import os
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
from shared_configs.utils import batch_list
logger = setup_logger()
@contextmanager
def get_db_connection(
isolation_level: str | None = None,
) -> Iterator[sqlite3.Connection]:
"""Get a database connection with proper isolation level and error handling.
Args:
isolation_level: SQLite isolation level. None = default "DEFERRED",
can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
"""
# 60 second timeout for locks
conn = sqlite3.connect(get_sqlite_db_path(), timeout=60.0)
if isolation_level is not None:
conn.isolation_level = isolation_level
try:
yield conn
except Exception:
conn.rollback()
raise
finally:
conn.close()
def init_db() -> None:
"""Initialize the SQLite database with required tables if they don't exist."""
# Create database directory if it doesn't exist
os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)
with get_db_connection("EXCLUSIVE") as conn:
cursor = conn.cursor()
db_exists = os.path.exists(get_sqlite_db_path())
if not db_exists:
# Enable WAL mode for better concurrent access and write performance
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA temp_store=MEMORY")
cursor.execute("PRAGMA cache_size=-2000000") # Use 2GB memory for cache
# Main table for storing Salesforce objects
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS salesforce_objects (
id TEXT PRIMARY KEY,
object_type TEXT NOT NULL,
data TEXT NOT NULL, -- JSON serialized data
last_modified INTEGER DEFAULT (strftime('%s', 'now')) -- Add timestamp for better cache management
) WITHOUT ROWID -- Optimize for primary key lookups
"""
)
# Table for parent-child relationships with covering index
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS relationships (
child_id TEXT NOT NULL,
parent_id TEXT NOT NULL,
PRIMARY KEY (child_id, parent_id)
) WITHOUT ROWID -- Optimize for primary key lookups
"""
)
# New table for caching parent-child relationships with object types
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS relationship_types (
child_id TEXT NOT NULL,
parent_id TEXT NOT NULL,
parent_type TEXT NOT NULL,
PRIMARY KEY (child_id, parent_id, parent_type)
) WITHOUT ROWID
"""
)
# Create a table for User email to ID mapping if it doesn't exist
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_email_map (
email TEXT PRIMARY KEY,
user_id TEXT, -- Nullable to allow for users without IDs
FOREIGN KEY (user_id) REFERENCES salesforce_objects(id)
) WITHOUT ROWID
"""
)
# Create indexes if they don't exist (SQLite ignores IF NOT EXISTS for indexes)
def create_index_if_not_exists(index_name: str, create_statement: str) -> None:
cursor.execute(
f"SELECT name FROM sqlite_master WHERE type='index' AND name='{index_name}'"
)
if not cursor.fetchone():
cursor.execute(create_statement)
create_index_if_not_exists(
"idx_object_type",
"""
CREATE INDEX idx_object_type
ON salesforce_objects(object_type, id)
WHERE object_type IS NOT NULL
""",
)
create_index_if_not_exists(
"idx_parent_id",
"""
CREATE INDEX idx_parent_id
ON relationships(parent_id, child_id)
""",
)
create_index_if_not_exists(
"idx_child_parent",
"""
CREATE INDEX idx_child_parent
ON relationships(child_id)
WHERE child_id IS NOT NULL
""",
)
create_index_if_not_exists(
"idx_relationship_types_lookup",
"""
CREATE INDEX idx_relationship_types_lookup
ON relationship_types(parent_type, child_id, parent_id)
""",
)
# Analyze tables to help query planner
cursor.execute("ANALYZE relationships")
cursor.execute("ANALYZE salesforce_objects")
cursor.execute("ANALYZE relationship_types")
cursor.execute("ANALYZE user_email_map")
# If database already existed but user_email_map needs to be populated
cursor.execute("SELECT COUNT(*) FROM user_email_map")
if cursor.fetchone()[0] == 0:
_update_user_email_map(conn)
conn.commit()
def _update_relationship_tables(
conn: sqlite3.Connection, child_id: str, parent_ids: set[str]
) -> None:
"""Update the relationship tables when a record is updated.
Args:
conn: The database connection to use (must be in a transaction)
child_id: The ID of the child record
parent_ids: Set of parent IDs to link to
"""
try:
cursor = conn.cursor()
# Get existing parent IDs
cursor.execute(
"SELECT parent_id FROM relationships WHERE child_id = ?", (child_id,)
)
old_parent_ids = {row[0] for row in cursor.fetchall()}
# Calculate differences
parent_ids_to_remove = old_parent_ids - parent_ids
parent_ids_to_add = parent_ids - old_parent_ids
# Remove old relationships
if parent_ids_to_remove:
cursor.executemany(
"DELETE FROM relationships WHERE child_id = ? AND parent_id = ?",
[(child_id, pid) for pid in parent_ids_to_remove],
)
# Also remove from relationship_types
cursor.executemany(
"DELETE FROM relationship_types WHERE child_id = ? AND parent_id = ?",
[(child_id, pid) for pid in parent_ids_to_remove],
)
# Add new relationships
if parent_ids_to_add:
# First add to relationships table
cursor.executemany(
"INSERT INTO relationships (child_id, parent_id) VALUES (?, ?)",
[(child_id, pid) for pid in parent_ids_to_add],
)
# Then get the types of the parent objects and add to relationship_types
for parent_id in parent_ids_to_add:
cursor.execute(
"SELECT object_type FROM salesforce_objects WHERE id = ?",
(parent_id,),
)
result = cursor.fetchone()
if result:
parent_type = result[0]
cursor.execute(
"""
INSERT INTO relationship_types (child_id, parent_id, parent_type)
VALUES (?, ?, ?)
""",
(child_id, parent_id, parent_type),
)
except Exception as e:
logger.error(f"Error updating relationship tables: {e}")
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
raise
def _update_user_email_map(conn: sqlite3.Connection) -> None:
"""Update the user_email_map table with current User objects.
Called internally by update_sf_db_with_csv when User objects are updated.
"""
cursor = conn.cursor()
cursor.execute(
"""
INSERT OR REPLACE INTO user_email_map (email, user_id)
SELECT json_extract(data, '$.Email'), id
FROM salesforce_objects
WHERE object_type = 'User'
AND json_extract(data, '$.Email') IS NOT NULL
"""
)
def update_sf_db_with_csv(
object_type: str,
csv_download_path: str,
delete_csv_after_use: bool = True,
) -> list[str]:
"""Update the SF DB with a CSV file using SQLite storage."""
updated_ids = []
# Use IMMEDIATE to get a write lock at the start of the transaction
with get_db_connection("IMMEDIATE") as conn:
cursor = conn.cursor()
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
if "Id" not in row:
logger.warning(
f"Row {row} does not have an Id field in {csv_download_path}"
)
continue
id = row["Id"]
parent_ids = set()
field_to_remove: set[str] = set()
# Process relationships and clean data
for field, value in row.items():
if validate_salesforce_id(value) and field != "Id":
parent_ids.add(value)
field_to_remove.add(field)
if not value:
field_to_remove.add(field)
# Remove unwanted fields
for field in field_to_remove:
if field != "LastModifiedById":
del row[field]
# Update main object data
cursor.execute(
"""
INSERT OR REPLACE INTO salesforce_objects (id, object_type, data)
VALUES (?, ?, ?)
""",
(id, object_type, json.dumps(row)),
)
# Update relationships using the same connection
_update_relationship_tables(conn, id, parent_ids)
updated_ids.append(id)
# If we're updating User objects, update the email map
if object_type == "User":
_update_user_email_map(conn)
conn.commit()
if delete_csv_after_use:
# Remove the csv file after it has been used
# to successfully update the db
os.remove(csv_download_path)
return updated_ids
def get_child_ids(parent_id: str) -> set[str]:
"""Get all child IDs for a given parent ID."""
with get_db_connection() as conn:
cursor = conn.cursor()
# Force index usage with INDEXED BY
cursor.execute(
"SELECT child_id FROM relationships INDEXED BY idx_parent_id WHERE parent_id = ?",
(parent_id,),
)
child_ids = {row[0] for row in cursor.fetchall()}
return child_ids
def get_type_from_id(object_id: str) -> str | None:
"""Get the type of an object from its ID."""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT object_type FROM salesforce_objects WHERE id = ?", (object_id,)
)
result = cursor.fetchone()
if not result:
logger.warning(f"Object ID {object_id} not found")
return None
return result[0]
def get_record(
object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
"""Retrieve the record and return it as a SalesforceObject."""
if object_type is None:
object_type = get_type_from_id(object_id)
if not object_type:
return None
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT data FROM salesforce_objects WHERE id = ?", (object_id,))
result = cursor.fetchone()
if not result:
logger.warning(f"Object ID {object_id} not found")
return None
data = json.loads(result[0])
return SalesforceObject(id=object_id, type=object_type, data=data)
def find_ids_by_type(object_type: str) -> list[str]:
"""Find all object IDs for rows of the specified type."""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT id FROM salesforce_objects WHERE object_type = ?", (object_type,)
)
return [row[0] for row in cursor.fetchall()]
def get_affected_parent_ids_by_type(
updated_ids: list[str],
parent_types: list[str],
batch_size: int = 500,
) -> Iterator[tuple[str, set[str]]]:
"""Get IDs of objects that are of the specified parent types and are either in the
updated_ids or have children in the updated_ids. Yields tuples of (parent_type, affected_ids).
"""
# SQLite typically has a limit of 999 variables
updated_ids_batches = batch_list(updated_ids, batch_size)
updated_parent_ids: set[str] = set()
with get_db_connection() as conn:
cursor = conn.cursor()
for batch_ids in updated_ids_batches:
batch_ids = list(set(batch_ids) - updated_parent_ids)
if not batch_ids:
continue
id_placeholders = ",".join(["?" for _ in batch_ids])
for parent_type in parent_types:
affected_ids: set[str] = set()
# Get directly updated objects of parent types - using index on object_type
cursor.execute(
f"""
SELECT id FROM salesforce_objects
WHERE id IN ({id_placeholders})
AND object_type = ?
""",
batch_ids + [parent_type],
)
affected_ids.update(row[0] for row in cursor.fetchall())
# Get parent objects of updated objects - using optimized relationship_types table
cursor.execute(
f"""
SELECT DISTINCT parent_id
FROM relationship_types
INDEXED BY idx_relationship_types_lookup
WHERE parent_type = ?
AND child_id IN ({id_placeholders})
""",
[parent_type] + batch_ids,
)
affected_ids.update(row[0] for row in cursor.fetchall())
# Remove any parent IDs that have already been processed
new_affected_ids = affected_ids - updated_parent_ids
# Add the new affected IDs to the set of updated parent IDs
updated_parent_ids.update(new_affected_ids)
if new_affected_ids:
yield parent_type, new_affected_ids
def has_at_least_one_object_of_type(object_type: str) -> bool:
"""Check if there is at least one object of the specified type in the database.
Args:
object_type: The Salesforce object type to check
Returns:
bool: True if at least one object exists, False otherwise
"""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT COUNT(*) FROM salesforce_objects WHERE object_type = ?",
(object_type,),
)
count = cursor.fetchone()[0]
return count > 0
# NULL_ID_STRING is used to indicate that the user ID was queried but not found
# As opposed to None because it has yet to be queried at all
NULL_ID_STRING = "N/A"
def get_user_id_by_email(email: str) -> str | None:
"""Get the Salesforce User ID for a given email address.
Args:
email: The email address to look up
Returns:
A tuple of (was_found, user_id):
- was_found: True if the email exists in the table, False if not found
- user_id: The Salesforce User ID if exists, None otherwise
"""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT user_id FROM user_email_map WHERE email = ?", (email,))
result = cursor.fetchone()
if result is None:
return None
return result[0]
def update_email_to_id_table(email: str, id: str | None) -> None:
"""Update the email to ID map table with a new email and ID."""
id_to_use = id or NULL_ID_STRING
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO user_email_map (email, user_id) VALUES (?, ?)",
(email, id_to_use),
)
conn.commit()

View File

@@ -1,66 +1,72 @@
import re
from typing import Union
SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
import os
from dataclasses import dataclass
from typing import Any
def _clean_salesforce_dict(data: Union[dict, list]) -> Union[dict, list]:
if isinstance(data, dict):
if "records" in data.keys():
data = data["records"]
if isinstance(data, dict):
if "attributes" in data.keys():
if isinstance(data["attributes"], dict):
data.update(data.pop("attributes"))
@dataclass
class SalesforceObject:
id: str
type: str
data: dict[str, Any]
if isinstance(data, dict):
filtered_dict = {}
for key, value in data.items():
if not re.search(SF_JSON_FILTER, key, re.IGNORECASE):
if "__c" in key: # remove the custom object indicator for display
key = key[:-3]
if isinstance(value, (dict, list)):
filtered_value = _clean_salesforce_dict(value)
if filtered_value: # Only add non-empty dictionaries or lists
filtered_dict[key] = filtered_value
elif value is not None:
filtered_dict[key] = value
return filtered_dict
elif isinstance(data, list):
filtered_list = []
for item in data:
if isinstance(item, (dict, list)):
filtered_item = _clean_salesforce_dict(item)
if filtered_item: # Only add non-empty dictionaries or lists
filtered_list.append(filtered_item)
elif item is not None:
filtered_list.append(filtered_item)
return filtered_list
else:
return data
def to_dict(self) -> dict[str, Any]:
return {
"ID": self.id,
"Type": self.type,
"Data": self.data,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "SalesforceObject":
return cls(
id=data["Id"],
type=data["Type"],
data=data,
)
def _json_to_natural_language(data: Union[dict, list], indent: int = 0) -> str:
result = []
indent_str = " " * indent
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, (dict, list)):
result.append(f"{indent_str}{key}:")
result.append(_json_to_natural_language(value, indent + 2))
else:
result.append(f"{indent_str}{key}: {value}")
elif isinstance(data, list):
for item in data:
result.append(_json_to_natural_language(item, indent))
else:
result.append(f"{indent_str}{data}")
return "\n".join(result)
# This defines the base path for all data files relative to this file
# AKA BE CAREFUL WHEN MOVING THIS FILE
BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")
def extract_dict_text(raw_dict: dict) -> str:
processed_dict = _clean_salesforce_dict(raw_dict)
natural_language_dict = _json_to_natural_language(processed_dict)
return natural_language_dict
def get_sqlite_db_path() -> str:
"""Get the path to the sqlite db file."""
return os.path.join(BASE_DATA_PATH, "salesforce_db.sqlite")
def get_object_type_path(object_type: str) -> str:
"""Get the directory path for a specific object type."""
type_dir = os.path.join(BASE_DATA_PATH, object_type)
os.makedirs(type_dir, exist_ok=True)
return type_dir
_CHECKSUM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
_LOOKUP = {format(i, "05b"): _CHECKSUM_CHARS[i] for i in range(32)}
def validate_salesforce_id(salesforce_id: str) -> bool:
"""Validate the checksum portion of an 18-character Salesforce ID.
Args:
salesforce_id: An 18-character Salesforce ID
Returns:
bool: True if the checksum is valid, False otherwise
"""
if len(salesforce_id) != 18:
return False
chunks = [salesforce_id[0:5], salesforce_id[5:10], salesforce_id[10:15]]
checksum = salesforce_id[15:18]
calculated_checksum = ""
for chunk in chunks:
result_string = "".join(
"1" if char.isupper() else "0" for char in reversed(chunk)
)
calculated_checksum += _LOOKUP[result_string]
return checksum == calculated_checksum

View File

@@ -264,24 +264,6 @@ class SlackTextCleaner:
message = message.replace("<!everyone>", "@everyone")
return message
@staticmethod
def replace_links(message: str) -> str:
"""Replaces slack links e.g. `<URL>` -> `URL` and `<URL|DISPLAY>` -> `DISPLAY`"""
# Find user IDs in the message
possible_link_matches = re.findall(r"<(.*?)>", message)
for possible_link in possible_link_matches:
if not possible_link:
continue
# Special slack patterns that aren't for links
if possible_link[0] not in ["#", "@", "!"]:
link_display = (
possible_link
if "|" not in possible_link
else possible_link.split("|")[1]
)
message = message.replace(f"<{possible_link}>", link_display)
return message
@staticmethod
def replace_special_catchall(message: str) -> str:
"""Replaces pattern of <!something|another-thing> with another-thing

View File

@@ -33,6 +33,7 @@ from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.utils.logger import setup_logger
from onyx.utils.sitemap import list_pages_for_site
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -241,6 +242,12 @@ class WebConnector(LoadConnector):
self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url))
elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD:
# Explicitly check if running in multi-tenant mode to prevent potential security risks
if MULTI_TENANT:
raise ValueError(
"Upload input for web connector is not supported in cloud environments"
)
logger.warning(
"This is not a UI supported Web Connector flow, "
"are you sure you want to do this?"

View File

@@ -10,17 +10,21 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
time_str_to_utc,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.retry_wrapper import retry_builder
MAX_PAGE_SIZE = 30 # Zendesk API maximum
_SLIM_BATCH_SIZE = 1000
class ZendeskCredentialsNotSetUpError(PermissionError):
@@ -40,6 +44,13 @@ class ZendeskClient:
response = requests.get(
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
)
if response.status_code == 429:
retry_after = response.headers.get("Retry-After")
if retry_after is not None:
# Sleep for the duration indicated by the Retry-After header
time.sleep(int(retry_after))
response.raise_for_status()
return response.json()
@@ -265,7 +276,7 @@ def _ticket_to_document(
)
class ZendeskConnector(LoadConnector, PollConnector):
class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@@ -390,6 +401,43 @@ class ZendeskConnector(LoadConnector, PollConnector):
if doc_batch:
yield doc_batch
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = []
if self.content_type == "articles":
articles = _get_articles(
self.client, start_time=int(start) if start else None
)
for article in articles:
slim_doc_batch.append(
SlimDocument(
id=f"article:{article['id']}",
)
)
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
elif self.content_type == "tickets":
tickets = _get_tickets(
self.client, start_time=int(start) if start else None
)
for ticket in tickets:
slim_doc_batch.append(
SlimDocument(
id=f"zendesk_ticket_{ticket['id']}",
)
)
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
else:
raise ValueError(f"Unsupported content_type: {self.content_type}")
if slim_doc_batch:
yield slim_doc_batch
if __name__ == "__main__":
import os

View File

@@ -37,6 +37,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.timing import log_function_time
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
@@ -163,6 +164,17 @@ class SearchPipeline:
# These chunks are ordered, deduped, and contain no large chunks
retrieved_chunks = self._get_chunks()
# If ee is enabled, censor the chunk sections based on user access
# Otherwise, return the retrieved chunks
censored_chunks = fetch_ee_implementation_or_noop(
"onyx.external_permissions.post_query_censoring",
"_post_query_chunk_censoring",
retrieved_chunks,
)(
chunks=retrieved_chunks,
user=self.user,
)
above = self.search_query.chunks_above
below = self.search_query.chunks_below
@@ -175,7 +187,7 @@ class SearchPipeline:
seen_document_ids = set()
# This preserves the ordering since the chunks are retrieved in score order
for chunk in retrieved_chunks:
for chunk in censored_chunks:
if chunk.document_id not in seen_document_ids:
seen_document_ids.add(chunk.document_id)
chunk_requests.append(
@@ -225,7 +237,7 @@ class SearchPipeline:
# This maintains the original chunks ordering. Note, we cannot simply sort by score here
# as reranking flow may wipe the scores for a lot of the chunks.
doc_chunk_ranges_map = defaultdict(list)
for chunk in retrieved_chunks:
for chunk in censored_chunks:
# The list of ranges for each document is ordered by score
doc_chunk_ranges_map[chunk.document_id].append(
ChunkRange(
@@ -274,11 +286,11 @@ class SearchPipeline:
# In case of failed parallel calls to Vespa, at least we should have the initial retrieved chunks
doc_chunk_ind_to_chunk.update(
{(chunk.document_id, chunk.chunk_id): chunk for chunk in retrieved_chunks}
{(chunk.document_id, chunk.chunk_id): chunk for chunk in censored_chunks}
)
# Build the surroundings for all of the initial retrieved chunks
for chunk in retrieved_chunks:
for chunk in censored_chunks:
start_ind = max(0, chunk.chunk_id - above)
end_ind = chunk.chunk_id + below

View File

@@ -54,9 +54,11 @@ def get_total_users_count(db_session: Session) -> int:
return user_count + invited_users
async def get_user_count() -> int:
async def get_user_count(only_admin_users: bool = False) -> int:
async with get_async_session_with_tenant() as session:
stmt = select(func.count(User.id))
if only_admin_users:
stmt = stmt.where(User.role == UserRole.ADMIN)
result = await session.execute(stmt)
user_count = result.scalar()
if user_count is None:

View File

@@ -7,6 +7,7 @@ from sqlalchemy import exists
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
@@ -90,15 +91,22 @@ def get_connector_credential_pairs(
user: User | None = None,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
) -> list[ConnectorCredentialPair]:
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = _add_user_filters(stmt, user, get_editable)
if not include_disabled:
stmt = stmt.where(
ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE
) # noqa
)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).all())
@@ -310,6 +318,9 @@ def associate_default_cc_pair(db_session: Session) -> None:
if existing_association is not None:
return
# DefaultCCPair has id 1 since it is the first CC pair created
# It is DEFAULT_CC_PAIR_ID, but can't set it explicitly because it messed with the
# auto-incrementing id
association = ConnectorCredentialPair(
connector_id=0,
credential_id=0,
@@ -350,7 +361,12 @@ def add_credential_to_connector(
last_successful_index_time: datetime | None = None,
) -> StatusResponse:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
credential = fetch_credential_by_id(
credential_id,
user,
db_session,
get_editable=False,
)
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")
@@ -427,7 +443,12 @@ def remove_credential_from_connector(
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
credential = fetch_credential_by_id(
credential_id,
user,
db_session,
get_editable=False,
)
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")

View File

@@ -86,7 +86,7 @@ def _add_user_filters(
"""
Filter Credentials by:
- if the user is in the user_group that owns the Credential
- if the user is not a global_curator, they must also have a curator relationship
- if the user is a curator, they must also have a curator relationship
to the user_group
- if editing is being done, we also filter out Credentials that are owned by groups
that the user isn't a curator for
@@ -97,6 +97,7 @@ def _add_user_filters(
where_clause = User__UserGroup.user_id == user.id
if user.role == UserRole.CURATOR:
where_clause &= User__UserGroup.is_curator == True # noqa: E712
if get_editable:
user_groups = select(User__UserGroup.user_group_id).where(
User__UserGroup.user_id == user.id
@@ -152,10 +153,16 @@ def fetch_credential_by_id(
user: User | None,
db_session: Session,
assume_admin: bool = False,
get_editable: bool = True,
) -> Credential | None:
stmt = select(Credential).distinct()
stmt = stmt.where(Credential.id == credential_id)
stmt = _add_user_filters(stmt, user, assume_admin=assume_admin)
stmt = _add_user_filters(
stmt=stmt,
user=user,
assume_admin=assume_admin,
get_editable=get_editable,
)
result = db_session.execute(stmt)
credential = result.scalar_one_or_none()
return credential

Some files were not shown because too many files have changed in this diff Show More