Compare commits

137 Commits

Author SHA1 Message Date
pablodanswer
da86610022 nit 2025-01-06 18:36:38 -08:00
pablodanswer
0027759dbf nit 2025-01-06 17:55:02 -08:00
pablodanswer
595ef152d2 updated UX 2025-01-05 14:38:52 -08:00
pablodanswer
083d669d1b minor logic update 2025-01-05 13:22:15 -08:00
pablodanswer
3ac31136b2 base functional 2025-01-03 17:11:14 -08:00
pablodanswer
a73a438d95 k 2025-01-03 14:33:24 -08:00
pablodanswer
c0770481e8 finalize 2025-01-03 14:29:52 -08:00
pablodanswer
c27d13c07f rm danswer 2025-01-03 14:27:25 -08:00
pablodanswer
ab34c4e772 add my docs v1 2025-01-03 14:25:56 -08:00
rkuo-danswer
66f9124135 Merge pull request #3584 from onyx-dot-app/bugfix/log_spacing
fix formatting
2025-01-03 00:43:36 -08:00
Richard Kuo
8f0fb70bbf fix formatting 2025-01-02 23:21:54 -08:00
rkuo-danswer
ef5e5c80bb Merge pull request #3577 from onyx-dot-app/bugfix/model_server_exception_logging
fix response logging
2025-01-02 23:08:46 -08:00
rkuo-danswer
03acb6587a Feature/model server logging (#3579)
* improve model server logging

* improve exception logging with provider/model names

* get everything into one log line

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-01-03 01:40:29 +00:00
hagen-danswer
d1ec72b5e5 Reworked salesforce connector to use bulk api (#3581) 2025-01-02 18:09:02 -08:00
Weves
3b214133a8 Airtable improvement 2025-01-02 17:56:05 -08:00
rkuo-danswer
2232702e99 retry the individual delete's (#3580)
* retry the individual delete's

* need to raise inside the retry

* just use retry for now

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-01-02 17:39:37 -08:00
hagen-danswer
8108ff0a4b Added logging for permissions upsert queue length 2025-01-02 17:39:01 -08:00
Richard Kuo
f64e78e986 fix response logging 2025-01-02 13:39:19 -08:00
Chris Weaver
08312a4394 Update Slack link in README.md 2025-01-01 10:03:59 -08:00
Weves
92add655e0 Slack fixes 2024-12-31 18:04:12 -08:00
Chris Weaver
d64464ca7c Add support for OAuth connectors that require user input (#3571)
* Add support for OAuth connectors that require user input

* Cleanup

* Fix linear

* Small re-naming

* Remove console.log
2024-12-31 18:03:33 -08:00
Yuhong Sun
ccd3983802 Linear OAuth Connector (#3570) 2024-12-31 16:11:09 -08:00
pablonyx
240f3e4fff Ensure users cannot modify their roles on main
Ensure users cannot modify their roles
2024-12-31 15:59:27 -05:00
pablonyx
1291b3d930 Add anonymous user to main
Anonymous user
2024-12-31 15:58:52 -05:00
rkuo-danswer
d05f1997b5 Merge pull request #3569 from onyx-dot-app/bugfix/alt_index
we didn't want to rename the alt index suffix, reverting
2024-12-31 12:39:00 -08:00
Chris Weaver
aa2e2a62b9 Small Egnyte tweaks (#3568) 2024-12-31 19:28:38 +00:00
Richard Kuo
174e5968f8 we didn't want to rename the alt index suffix, reverting 2024-12-31 11:28:11 -08:00
pablodanswer
1f27606e17 minor clean up 2024-12-31 13:04:02 -05:00
pablodanswer
60355b84c1 quick nits 2024-12-31 13:04:02 -05:00
pablodanswer
680ab9ea30 updated logic 2024-12-31 13:04:02 -05:00
pablodanswer
c2447dbb1c cosmetic updates 2024-12-31 13:04:02 -05:00
pablodanswer
52bad522f8 update for multi-tenant clarity 2024-12-31 13:04:02 -05:00
pablodanswer
63e5e58313 add anonymous user 2024-12-31 13:04:02 -05:00
rkuo-danswer
2643782e30 Merge pull request #3567 from onyx-dot-app/bugfix/revert_vespa
Revert "More efficient Vespa indexing (#3552)"
2024-12-31 09:47:00 -08:00
Richard Kuo
3eb72e5c1d Revert "More efficient Vespa indexing (#3552)"
This reverts commit 2783216781.
2024-12-31 09:40:23 -08:00
rkuo-danswer
9b65c23a7e Merge pull request #3566 from onyx-dot-app/bugfix/primary_task_timings
re-enable celery task execution logging in primary worker
2024-12-31 01:29:05 -08:00
Richard Kuo (Danswer)
b43a8e48c6 add some return types to distinguish when the task is actually performing work 2024-12-31 00:10:33 -08:00
Richard Kuo (Danswer)
1955c1d67b re-enable celery task execution logging in primary worker 2024-12-30 21:53:00 -08:00
Chris Weaver
3f92ed9d29 Airtable connector (#3564)
* Airtable connector

* Improvements

* improve

* Clean up test

* Add secrets

* Fix mypy + add access token

* Mock unstructured call

* Adjust comments

* Fix ID in test
2024-12-31 03:06:28 +00:00
Weves
618369f4a1 Small fix 2024-12-30 19:20:30 -08:00
pablonyx
2783216781 More efficient Vespa indexing (#3552)
---------

Co-authored-by: Chris Weaver <25087905+Weves@users.noreply.github.com>
2024-12-30 18:51:14 -08:00
rkuo-danswer
bec0f9fb23 permission sync in cloud and beat expiry adjustment (#3544)
* try fixing exception in cloud

* raise beat expiry ... 60 seconds might be starving certain tasks completely

* adjust expiry down to 10 min

* raise concurrency overflow for indexing worker.

* parent pid check

* fix comment

* fix parent pid check, also actually raise an exception from the task if the spawned task exit status is bad

* fix pid check

* some cleanup and task wait fixes

* review fixes

* comment some code so we don't change too many things at once

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-12-31 01:05:57 +00:00
pablodanswer
97a03e7fc8 nit 2024-12-29 21:07:12 -05:00
pablodanswer
8d6e8269b7 k 2024-12-29 21:07:12 -05:00
pablodanswer
9ce2c6c517 minor change 2024-12-29 21:07:12 -05:00
pablodanswer
2ad8bdbc65 k 2024-12-29 21:07:12 -05:00
rkuo-danswer
a83c9b40d5 Bugfix/oauth fix (#3507)
* old oauth file left behind

* fix function change that was lost in merge

* fix some testing vars

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-30 01:49:12 +00:00
Chris Weaver
340fab1375 Additional error handling + logging for google drive connector (#3563)
* Additional error handling + logging for google drive connector

* Fix mypy
2024-12-29 17:48:02 -08:00
hagen-danswer
3ec338307f Fixed indexing issues with Salesforce 2024-12-29 16:45:29 -08:00
pablonyx
27acd3387a Auth specific rate limiting (#3463)
* k

* v1

* fully functional

* finalize

* nit

* nit

* nit

* clean up with wrapper + comments

* k

* update

* minor clean
2024-12-29 23:34:23 +00:00
hagen-danswer
d14ef431a7 Improve Salesforce connector 2024-12-29 14:18:40 -08:00
pablonyx
9bffeb65af Eagerly load CCpair connectors (#3531)
* remove left over vim command

* eager loading

* Revert "remove left over vim command"

This reverts commit 184a134ae0.
2024-12-29 15:58:38 +00:00
Yuhong Sun
f4806da653 Fix Null Value in PG (#3559)
* k

* k

* k

* k

* k
2024-12-29 01:53:16 +00:00
pablonyx
e2700b2bbd Remove left over yaml errors (#3527)
* remove left over vim command

* additional misconfigurations

* ensure all regions updated
2024-12-29 01:45:07 +00:00
Yuhong Sun
fc81a3fb12 Zendesk Retries (#3558)
* k

* k

* k

* k
2024-12-28 23:51:49 +00:00
pablonyx
2203cfabea Prevent SSRF risk
Prevent SSRF risk
2024-12-28 15:25:57 -05:00
pablodanswer
f4050306d6 Prevent SSRF risk 2024-12-28 15:25:12 -05:00
Weves
2d960a477f Fix discourse connector 2024-12-24 12:43:10 -08:00
hagen-danswer
8837b8ea79 Curators can now update the curator relationship (#3536)
* Curators can now update the curator relationship

* mypy

* mypy

* whoops haha
2024-12-24 18:49:58 +00:00
hagen-danswer
3dfb214f73 Slackbot polish (#3547) 2024-12-24 16:19:15 +00:00
pablonyx
18d7262608 fix logo rendering (#3542) 2024-12-22 23:00:33 +00:00
pablonyx
09b879ee73 Ensure gmail works for personal accounts (#3541)
* Ensure gmail works for personal accounts

* nit

* minor update
2024-12-22 23:00:14 +00:00
rkuo-danswer
aaa668c963 Merge pull request #3534 from onyx-dot-app/bugfix/validate_ttl
raise activity timeout to one hour
2024-12-22 13:13:57 -08:00
pablonyx
edb877f4bc fix NUL character (#3540) 2024-12-21 23:30:25 +00:00
rkuo-danswer
eb369caefb log attempt id, log elapsed since task execution start, remove log spam (#3539)
* log attempt id, log elapsed since task execution start, remove log spam

* diagnostic lock logs

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-21 23:03:50 +00:00
Chris Weaver
b9567eabd7 Fix bedrock w/ access keys (#3538)
* Fix bedrock w/ access keys

* cleanup

* Remove extra #
2024-12-21 02:24:11 +00:00
Richard Kuo (Danswer)
13bbf67091 raise activity timeout to one hour 2024-12-20 16:18:35 -08:00
hagen-danswer
457a4c73f0 Made sure confluence connector recursive by page includes top level page (#3532)
* Made sure confluence connector by page includes top level page

* surface level change
2024-12-20 21:53:59 +00:00
rkuo-danswer
ce37688b5b allow limited user to create chat session (#3533)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-20 21:36:41 +00:00
pablonyx
4e2c90f4af Proper user deletion / organization leaving (#3460)
* Proper user deletion / organization leaving

* minor nit

* update

* udpate provisioning

* minor cleanup

* typing

* post rebase
2024-12-20 21:01:03 +00:00
pablonyx
513dd8a319 update toggling states (#3519) 2024-12-20 20:27:22 +00:00
hagen-danswer
71c5043832 Added filter to exclude attachments with unsupported file extensions (#3530)
* Added filter to exclude attachments with unsupported file extensions

* extension
2024-12-20 19:48:15 +00:00
pablonyx
64b6f15e95 AWS extraneous error fix (#3529)
* remove left over vim command

* aws fix

* k

* remove double
2024-12-20 19:31:04 +00:00
hagen-danswer
35022f5f09 Fix group table (#3523) 2024-12-20 17:51:26 +00:00
hagen-danswer
0d44014c16 Cleanup PR template to make it more concise (#3524) 2024-12-20 17:49:31 +00:00
Yuhong Sun
1b9e9f48fa Update README.md 2024-12-20 10:26:57 -08:00
Yuhong Sun
05fb5aa27b Update README.md 2024-12-20 10:25:34 -08:00
Yuhong Sun
3b645b72a3 Crop Logo Closer 2024-12-20 10:23:52 -08:00
Yuhong Sun
fe770b5c3a Fix Logo On DarkMode (#3525) 2024-12-20 10:15:48 -08:00
hagen-danswer
1eaf885f50 associating credentials with connectors is not considered editing (#3522)
* associating credentials with connectors is not considered editing

* formatting

* formatting

* Update credentials.py

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-12-20 17:36:25 +00:00
rkuo-danswer
a187aa508c use redis exclusively with active signal renewal in more places to perform indexing fence validation (#3517)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-20 06:54:00 +00:00
pablonyx
aa4bfa2a78 Forgot password feature (#3437)
* forgot password feature

* improved config

* nit

* nit
2024-12-20 04:53:37 +00:00
pablonyx
9011b8a139 Update citations in shared chat display (#3487)
* update shared chat display

* Change Copy

* fix icon

* remove secret!

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-12-20 01:48:29 +00:00
pablonyx
59c774353a Latex formatting (#3499) 2024-12-19 14:48:06 -08:00
pablonyx
b458d504af Sidebar Default Open (#3488) 2024-12-19 14:04:50 -08:00
Yuhong Sun
f83e7bfcd9 Fix Default CC Pair (#3513) 2024-12-19 09:43:12 -08:00
pablonyx
4d2e26ce4b MT Cloud Tracking Fix (#3514) 2024-12-19 08:47:02 -08:00
pablonyx
817fdc1f36 Ensure metadata overrides file contents (#3512)
* ensure metadata overrides file contents

* update more blocks
2024-12-19 04:44:24 +00:00
rkuo-danswer
e9b10e8b41 temporarily disabling validate indexing fences (#3502)
* temporarily disabling validate indexing fences

* add back a few startup checks in the cloud

* use common vespa client to perform health check

* log vespa url and try using http1 on light worker index methods

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-19 01:32:09 +00:00
pablonyx
a0fa4adb60 Ensure password validation errors propagate (#3509)
* ensure password validation errors propagate

* copy update

* support o1

* improve typing

* Revert "support o1"

This reverts commit 9b7aa6008c.
2024-12-19 00:05:57 +00:00
pablonyx
ca9ba925bd Support o1 (#3510)
* support o1

* nit
2024-12-19 00:05:00 +00:00
rkuo-danswer
833cc5c97c Merge pull request #3497 from emerzon/new_icons
New model icons for LLM Picker
2024-12-18 16:38:31 -08:00
Chris Weaver
23ecf654ed Add support for custom LLM error messages (#3501)
* Add support for custom LLM error messages

* Fix mypy
2024-12-17 22:58:17 -08:00
pablonyx
ddc6a6d2b3 Wrap nits (#3496) 2024-12-17 18:03:38 -08:00
pablonyx
571c8ece32 Slack Workspace Alembic Updates
Old alembic migration + restore workspace
2024-12-17 16:28:59 -08:00
pablodanswer
884bdb4b01 old alembic migration + restore workspace 2024-12-17 16:28:05 -08:00
pablonyx
b3ecf0d59f Migrate user milestone logic (#3493) 2024-12-17 15:59:56 -08:00
Emerson Gomes
f56fda27c9 Add also Microsoft models 2024-12-17 16:37:52 -06:00
Emerson Gomes
b1e4d4ea8d Adds icons for Amazon, Meta and Mistral models (when proxied via LiteLLM) 2024-12-17 16:20:46 -06:00
pablonyx
8db6d49fe5 IAM Auth for RDS (#3479)
* k

* functional iam auth

* k

* k

* improve typing

* add deployment options

* cleanup

* quick clean up

* minor cleanup

* additional clarity for db session operations

* nit

* k

* k

* update configs

* docker compose spacing
2024-12-17 22:02:37 +00:00
pablonyx
28598694b1 Add delete all chats option (#2515)
* Add delete all chats option

* post rebase fixes

* final validation

* minor cleanup

* move up
2024-12-17 02:55:35 +00:00
Emerson Gomes
b5d0df90b9 Remove hardcoded root path for HF models 2024-12-16 19:03:15 -08:00
pablonyx
48be6338ec Update Hubpost tracking form submission (#3261)
* Update Hubpost tracking form submission

* minor cleanup

* validated

* validate

* nit

* k
2024-12-17 02:31:09 +00:00
pablonyx
ed9014f03d Use logotypes where feasible (#3478)
* Use logotypes where feasible

* quick nit

* minor cleanup
2024-12-17 02:13:45 +00:00
rkuo-danswer
2dd51230ed clear indexing fences with no celery tasks queued (#3482)
* allow beat tasks to expire. it isn't important that they all run

* validate fences are in a good state and cancel/fail them if not

* add function timings for important beat tasks

* optimize lookups, add lots of comments

* review changes

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-17 00:55:58 +00:00
pablonyx
8b249cbe63 Proper display priority seeding (#3468)
* proper seeding

* k

* clean up
2024-12-17 00:19:45 +00:00
pablonyx
6b50f86cd2 Improved theming (#3204) 2024-12-16 22:24:32 +00:00
pablonyx
bd2805b6df Update llm override defaults (#3230)
* update llm override defaults

* post rebase fix
2024-12-16 22:18:21 +00:00
pablonyx
2847ab003e Prompting (#3372)
* auto generate start prompts

* post rebase clean up

* update for clarity
2024-12-16 21:34:43 +00:00
pablodanswer
1df6a506ec Revert "update pre-commit black version (#3250)"
This reverts commit d954914a0a.
2024-12-16 13:57:56 -08:00
pablonyx
f1541d1fbe Update default assistant to search for new users (#3317)
* update default assistant to search for new users

* update!
2024-12-16 21:15:33 +00:00
rkuo-danswer
dd0c4b64df errors in the summary row should be counting last_finished_status as reflected in the per connector rows (#3484)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-16 20:53:19 +00:00
pablonyx
788b3015bc fix single quote block in llm answer (#3139) 2024-12-16 20:37:47 +00:00
pablonyx
cbbf10f450 remove tenant id logs (#3063) 2024-12-16 20:24:09 +00:00
pablonyx
d954914a0a update pre-commit black version (#3250) 2024-12-16 20:04:42 +00:00
pablodanswer
bee74ac360 mark slack perm sync as flaky 2024-12-16 11:50:03 -08:00
pablonyx
29ef64272a Update chat provider values
Update chat provider values
2024-12-16 11:46:53 -08:00
pablodanswer
01bf6ee4b7 quick clean up 2024-12-16 11:43:34 -08:00
pablodanswer
0502417cbe update chat provider values 2024-12-16 11:39:25 -08:00
pablodanswer
d0483dd269 temporary vespa bump for tests 2024-12-15 21:41:21 -08:00
pablodanswer
eefa872d60 fix no space left on device for chromatic model server 2024-12-15 18:40:25 -08:00
pablonyx
3f3d4da611 do not include slackbot sessions when fetching chat sessions
do not include slackbot sessions when fetching `chat sessions`
2024-12-15 16:35:19 -08:00
pablodanswer
469068052e don't include slackbot sessions 2024-12-15 16:34:39 -08:00
pablonyx
9032b05606 Increase password requirements
Increase password requirements
2024-12-15 16:29:11 -08:00
pablodanswer
334bc6be8c Increase password requirements 2024-12-15 16:28:45 -08:00
Yuhong Sun
814f97c2c7 MT Cloud Monitoring (#3465) 2024-12-15 16:05:03 -08:00
pablodanswer
4f5a2b47c4 ensure integration tests build 2024-12-15 10:43:55 -08:00
pablodanswer
f545508268 Updated model server run-on config 2024-12-15 10:35:57 -08:00
pablonyx
590986ec65 Merge pull request #3476 from onyx-dot-app/fix_model_server_building
Update model server
2024-12-14 20:52:13 -08:00
pablodanswer
531bab5409 update model server 2024-12-14 20:51:03 -08:00
pablodanswer
29c44007c4 update model server 2024-12-14 20:49:05 -08:00
pablonyx
d388643a04 Cloud settings -> billing (#3469) 2024-12-14 18:10:50 -08:00
pablonyx
8a422683e3 Update folder logic (#3472) 2024-12-14 17:59:30 -08:00
pablonyx
ddc0230d68 align user dropdown in top right (#3473) 2024-12-14 17:25:11 -08:00
Yuhong Sun
6711e91dbf Seed Spacing (#3474) 2024-12-14 17:23:00 -08:00
pablodanswer
cff2346db5 Scale up model server 2024-12-14 17:19:28 -08:00
Yuhong Sun
8d3fad1f12 Change Default Assistant Description (#3470) 2024-12-14 17:00:08 -08:00
292 changed files with 12498 additions and 3832 deletions

View File

@@ -6,24 +6,6 @@
[Describe the tests you ran to verify your changes]
## Accepted Risk (provide if relevant)
N/A
## Related Issue(s) (provide if relevant)
N/A
## Mental Checklist:
- All of the automated tests pass
- All PR comments are addressed and marked resolved
- If there are migrations, they have been rebased to latest main
- If there are new dependencies, they are added to the requirements
- If there are new environment variables, they are added to all of the deployment methods
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- Docker images build and basic functionalities work
- Author has done a final read through of the PR right before merge
## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)

View File

@@ -66,6 +66,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -8,18 +8,29 @@ on:
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
DOCKER_BUILDKIT: 1
BUILDKIT_PROGRESS: plain
jobs:
build-and-push:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
build-amd64:
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: System Info
run: |
df -h
free -h
docker system prune -af --volumes
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver-opts: |
image=moby/buildkit:latest
network=host
- name: Login to Docker Hub
uses: docker/login-action@v3
@@ -27,24 +38,80 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Model Server Image Docker Build and Push
- name: Build and Push AMD64
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
platforms: linux/amd64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64
build-args: |
ONYX_VERSION=${{ github.ref_name }}
DANSWER_VERSION=${{ github.ref_name }}
outputs: type=registry
provenance: false
build-arm64:
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: System Info
run: |
df -h
free -h
docker system prune -af --volumes
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver-opts: |
image=moby/buildkit:latest
network=host
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and Push ARM64
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/arm64
push: true
tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
outputs: type=registry
provenance: false
merge-and-scan:
needs: [build-amd64, build-arm64]
runs-on: ubuntu-latest
steps:
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create and Push Multi-arch Manifest
run: |
docker buildx create --use
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
fi
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
@@ -53,3 +120,4 @@ jobs:
with:
image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
timeout: "10m"
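
The workflow above builds each architecture in its own job, pushes the result under a `-amd64` or `-arm64` suffix, and then joins the two into a single multi-arch tag with `docker buildx imagetools create`. The following is a minimal Python sketch of how those commands are assembled. The function name `manifest_commands` and the example ref are illustrative assumptions and do not appear in the repository.

```python
def manifest_commands(image: str, ref: str, push_latest: bool) -> list[list[str]]:
    """Return the imagetools commands that merge per-arch tags into one tag."""
    # Per-architecture tags pushed by the two build jobs.
    sources = [f"{image}:{ref}-{arch}" for arch in ("amd64", "arm64")]
    # The versioned tag is always created; ":latest" only when requested.
    targets = [f"{image}:{ref}"]
    if push_latest:
        targets.append(f"{image}:latest")
    return [
        ["docker", "buildx", "imagetools", "create", "-t", target, *sources]
        for target in targets
    ]


if __name__ == "__main__":
    # "v0.0.0-example" is a made-up ref used only for illustration.
    for cmd in manifest_commands("onyxdotapp/onyx-model-server", "v0.0.0-example", True):
        print(" ".join(cmd))
```

Splitting the build this way means neither job has to hold both architectures' layers at once, which may be what the "no space left on device" commits in this range were working around.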

View File

@@ -15,7 +15,12 @@ jobs:
# See https://runs-on.com/runners/linux/
runs-on:
[runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -196,7 +201,12 @@ jobs:
needs: playwright-tests
runs-on:
[runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
steps:
- name: Checkout code
uses: actions/checkout@v4

View File

@@ -20,8 +20,7 @@ env:
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on:
[runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4

View File

@@ -26,7 +26,15 @@ env:
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
# Salesforce
SF_USERNAME: ${{ secrets.SF_USERNAME }}
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
# Airtable
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/

View File

@@ -3,7 +3,7 @@
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
</h2>
<p align="center">
@@ -13,7 +13,7 @@
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -24,7 +24,7 @@
</a>
</p>
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
@@ -133,15 +133,3 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date)
## ✨Contributors
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
</a>
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>

backend/.gitignore (1 line changed)

@@ -9,3 +9,4 @@ api_keys.py
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*
onyx/connectors/salesforce/data/

View File

@@ -1,39 +1,49 @@
from typing import Any, Literal
from onyx.db.engine import get_iam_auth_token
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
from sqlalchemy.engine.base import Connection
from typing import Literal
import os
import ssl
import asyncio
from logging.config import fileConfig
import logging
from logging.config import fileConfig
from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
from sqlalchemy.sql.schema import SchemaItem
from shared_configs.configs import MULTI_TENANT
from onyx.db.engine import build_connection_string
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from onyx.db.engine import get_all_tenant_ids
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# Alembic Config object
config = context.config
# Interpret the config file for Python logging.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
fileConfig(config.config_file_name)
# Add your model's MetaData object here for 'autogenerate' support
target_metadata = [Base.metadata, ResultModelBase.metadata]
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
# Set up logging
logger = logging.getLogger(__name__)
ssl_context: ssl.SSLContext | None = None
if USE_IAM_AUTH:
if not os.path.exists(SSL_CERT_FILE):
raise FileNotFoundError(f"Expected {SSL_CERT_FILE} when USE_IAM_AUTH is true.")
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
def include_object(
object: SchemaItem,
@@ -49,20 +59,12 @@ def include_object(
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
"""
Determines whether a database object should be included in migrations.
Excludes specified tables from migrations.
"""
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def get_schema_options() -> tuple[str, bool, bool]:
"""
Parses command-line options passed via '-x' in Alembic commands.
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
"""
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
@@ -90,16 +92,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
"""
Executes migrations in the specified schema.
"""
logger.info(f"About to migrate schema: {schema_name}")
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
# Set search_path to the target schema
connection.execute(text(f'SET search_path TO "{schema_name}"'))
context.configure(
@@ -117,11 +115,25 @@ def do_run_migrations(
context.run_migrations()
def provide_iam_token_for_alembic(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
if USE_IAM_AUTH:
# Database connection settings
region = AWS_REGION_NAME
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
# Get IAM authentication token
token = get_iam_auth_token(host, port, user, region)
# For Alembic / SQLAlchemy in this context, set SSL and password
cparams["password"] = token
cparams["ssl"] = ssl_context
async def run_async_migrations() -> None:
"""
Determines whether to run migrations for a single schema or all schemas,
and executes migrations accordingly.
"""
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
engine = create_async_engine(
@@ -129,10 +141,16 @@ async def run_async_migrations() -> None:
poolclass=pool.NullPool,
)
if upgrade_all_tenants:
# Run migrations for all tenant schemas sequentially
tenant_schemas = get_all_tenant_ids()
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
def event_provide_iam_token_for_alembic(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
for schema in tenant_schemas:
try:
logger.info(f"Migrating schema: {schema}")
@@ -162,15 +180,20 @@ async def run_async_migrations() -> None:
def run_migrations_offline() -> None:
"""
Run migrations in 'offline' mode.
"""
schema_name, _, upgrade_all_tenants = get_schema_options()
url = build_connection_string()
if upgrade_all_tenants:
# Run offline migrations for all tenant schemas
engine = create_async_engine(url)
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
def event_provide_iam_token_for_alembic_offline(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()
@@ -207,9 +230,6 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
"""
Runs migrations in 'online' mode using an asynchronous engine.
"""
asyncio.run(run_async_migrations())
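
The IAM changes above rely on SQLAlchemy's `do_connect` event, whose listener receives the mutable `cparams` dict just before each connection is opened, so a freshly issued token can be written into it. Below is a stdlib-only sketch of that hook body. `make_iam_hook` is a hypothetical helper and `fake_token` stands in for `get_iam_auth_token`; neither is taken from the repository.

```python
from __future__ import annotations

import ssl
from typing import Any, Callable


def make_iam_hook(
    fetch_token: Callable[[], str],
    ssl_context: ssl.SSLContext | None,
    enabled: bool,
) -> Callable[[Any, Any, Any, dict[str, Any]], None]:
    """Build a listener with the (dialect, conn_rec, cargs, cparams) signature."""

    def hook(dialect: Any, conn_rec: Any, cargs: Any, cparams: dict[str, Any]) -> None:
        if not enabled:
            return  # leave the static credentials untouched
        # Fetch a token for every new connection, since IAM tokens are short-lived.
        cparams["password"] = fetch_token()
        cparams["ssl"] = ssl_context

    return hook


def fake_token() -> str:
    # Placeholder (assumption): a real deployment would request a token from AWS.
    return "example-token"


if __name__ == "__main__":
    params: dict[str, Any] = {"host": "db.example.internal", "user": "app"}
    make_iam_hook(fake_token, None, enabled=True)(None, None, (), params)
    print(sorted(params))
```

In the diff, the equivalent function is attached with `@event.listens_for(engine.sync_engine, "do_connect")`, so it runs once per physical connection rather than once at engine creation.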

View File

@@ -0,0 +1,129 @@
from alembic import op
import sqlalchemy as sa
import datetime
# revision identifiers, used by Alembic.
revision = "25d86cbfce78"
down_revision = "c0aab6edb6dd"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create user_folder table with additional 'display_priority' field
op.create_table(
"user_folder",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column(
"parent_id", sa.Integer(), sa.ForeignKey("user_folder.id"), nullable=True
),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
sa.Column("created_at", sa.DateTime(), default=datetime.datetime.utcnow),
)
# Migrate data from chat_folder to user_folder
op.execute(
"""
INSERT INTO user_folder (id, user_id, name, display_priority, created_at)
SELECT id, user_id, name, display_priority, CURRENT_TIMESTAMP FROM chat_folder
"""
)
# Update chat_session table to reference user_folder instead of chat_folder
op.drop_constraint(
"chat_session_chat_folder_fk", "chat_session", type_="foreignkey"
)
op.alter_column(
"chat_session",
"folder_id",
existing_type=sa.Integer(),
nullable=True,
existing_nullable=True,
existing_server_default=None,
)
op.create_foreign_key(
"fk_chat_session_folder_id_user_folder",
"chat_session",
"user_folder",
["folder_id"],
["id"],
ondelete="SET NULL",
)
# Drop the chat_folder table
op.drop_table("chat_folder")
# Create user_file table
op.create_table(
"user_file",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column(
"parent_folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
nullable=True,
),
sa.Column("file_type", sa.String(), nullable=True),
sa.Column("file_id", sa.String(length=255), nullable=False),
sa.Column("document_id", sa.String(length=255), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column(
"created_at",
sa.DateTime(),
default=datetime.datetime.utcnow,
),
)
def downgrade() -> None:
# Recreate chat_folder table
op.create_table(
"chat_folder",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column(
"user_id",
sa.UUID(),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=True,
),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
)
# Migrate data back from user_folder to chat_folder
op.execute(
"""
INSERT INTO chat_folder (id, user_id, name, display_priority)
SELECT id, user_id, name, display_priority FROM user_folder
WHERE id IN (SELECT DISTINCT folder_id FROM chat_session WHERE folder_id IS NOT NULL)
"""
)
# Update chat_session table to reference chat_folder again
op.drop_constraint(
"fk_chat_session_folder_id_user_folder", "chat_session", type_="foreignkey"
)
op.alter_column(
"chat_session",
"folder_id",
existing_type=sa.Integer(),
nullable=True,
existing_nullable=True,
existing_server_default=None,
)
op.create_foreign_key(
"chat_session_chat_folder_fk",
"chat_session",
"chat_folder",
["folder_id"],
["id"],
ondelete="SET NULL",
)
# Drop the user_file table
op.drop_table("user_file")
# Drop the user_folder table
op.drop_table("user_folder")

@@ -0,0 +1,121 @@
"""properly_cascade
Revision ID: 35e518e0ddf4
Revises: 91a0a4d62b14
Create Date: 2024-09-20 21:24:04.891018
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "35e518e0ddf4"
down_revision = "91a0a4d62b14"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Update chat_message foreign key constraint
op.drop_constraint(
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
)
op.create_foreign_key(
"chat_message_chat_session_id_fkey",
"chat_message",
"chat_session",
["chat_session_id"],
["id"],
ondelete="CASCADE",
)
# Update chat_message__search_doc foreign key constraints
op.drop_constraint(
"chat_message__search_doc_chat_message_id_fkey",
"chat_message__search_doc",
type_="foreignkey",
)
op.drop_constraint(
"chat_message__search_doc_search_doc_id_fkey",
"chat_message__search_doc",
type_="foreignkey",
)
op.create_foreign_key(
"chat_message__search_doc_chat_message_id_fkey",
"chat_message__search_doc",
"chat_message",
["chat_message_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"chat_message__search_doc_search_doc_id_fkey",
"chat_message__search_doc",
"search_doc",
["search_doc_id"],
["id"],
ondelete="CASCADE",
)
# Add CASCADE delete for tool_call foreign key
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
op.create_foreign_key(
"tool_call_message_id_fkey",
"tool_call",
"chat_message",
["message_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Revert chat_message foreign key constraint
op.drop_constraint(
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
)
op.create_foreign_key(
"chat_message_chat_session_id_fkey",
"chat_message",
"chat_session",
["chat_session_id"],
["id"],
)
# Revert chat_message__search_doc foreign key constraints
op.drop_constraint(
"chat_message__search_doc_chat_message_id_fkey",
"chat_message__search_doc",
type_="foreignkey",
)
op.drop_constraint(
"chat_message__search_doc_search_doc_id_fkey",
"chat_message__search_doc",
type_="foreignkey",
)
op.create_foreign_key(
"chat_message__search_doc_chat_message_id_fkey",
"chat_message__search_doc",
"chat_message",
["chat_message_id"],
["id"],
)
op.create_foreign_key(
"chat_message__search_doc_search_doc_id_fkey",
"chat_message__search_doc",
"search_doc",
["search_doc_id"],
["id"],
)
# Revert tool_call foreign key constraint
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
op.create_foreign_key(
"tool_call_message_id_fkey",
"tool_call",
"chat_message",
["message_id"],
["id"],
)
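
Reviewer sketch, not part of the commit: the `ondelete="CASCADE"` constraints in this migration mean that deleting a parent row also removes its dependent rows. The snippet below illustrates that behavior, assuming in-memory SQLite as a stand-in for Postgres and using trimmed-down, hypothetical columns for `chat_session` and `chat_message`.

```python
import sqlite3

# In-memory SQLite stands in for Postgres; the table names mirror the migration,
# but the columns are reduced to the minimum needed for the illustration.
conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON")  # SQLite enforces FKs only when enabled
conn.executescript(
    """
    CREATE TABLE chat_session (id INTEGER PRIMARY KEY);
    CREATE TABLE chat_message (
        id INTEGER PRIMARY KEY,
        chat_session_id INTEGER NOT NULL
            REFERENCES chat_session (id) ON DELETE CASCADE
    );
    INSERT INTO chat_session (id) VALUES (1), (2);
    INSERT INTO chat_message (id, chat_session_id) VALUES (10, 1), (11, 1), (12, 2);
    """
)

# Deleting a session removes its messages without a separate DELETE statement.
conn.execute("DELETE FROM chat_session WHERE id = 1")
remaining = [row[0] for row in conn.execute("SELECT id FROM chat_message ORDER BY id")]
print(remaining)
```

Without the `ON DELETE CASCADE` clause, which is the state the downgrade restores, the same `DELETE` would be rejected while dependent rows still exist.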

@@ -0,0 +1,45 @@
"""Milestone
Revision ID: 91a0a4d62b14
Revises: dab04867cd88
Create Date: 2024-12-13 19:03:30.947551
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "91a0a4d62b14"
down_revision = "dab04867cd88"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"milestone",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("tenant_id", sa.String(), nullable=True),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column("event_type", sa.String(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
)
def downgrade() -> None:
op.drop_table("milestone")

@@ -0,0 +1,87 @@
"""delete workspace
Revision ID: c0aab6edb6dd
Revises: 35e518e0ddf4
Create Date: 2024-12-17 14:37:07.660631
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "c0aab6edb6dd"
down_revision = "35e518e0ddf4"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE connector
SET connector_specific_config = connector_specific_config - 'workspace'
WHERE source = 'SLACK'
"""
)
def downgrade() -> None:
import json
from sqlalchemy import text
from slack_sdk import WebClient
conn = op.get_bind()
# Fetch all Slack credentials
creds_result = conn.execute(
text("SELECT id, credential_json FROM credential WHERE source = 'SLACK'")
)
all_slack_creds = creds_result.fetchall()
if not all_slack_creds:
return
for cred_row in all_slack_creds:
credential_id, credential_json = cred_row
credential_json = (
credential_json.tobytes().decode("utf-8")
if isinstance(credential_json, memoryview)
else credential_json.decode("utf-8")
)
credential_data = json.loads(credential_json)
slack_bot_token = credential_data.get("slack_bot_token")
if not slack_bot_token:
print(
f"No slack_bot_token found for credential {credential_id}. "
"Your Slack connector will not function until you upgrade and provide a valid token."
)
continue
client = WebClient(token=slack_bot_token)
try:
auth_response = client.auth_test()
workspace = auth_response["url"].split("//")[1].split(".")[0]
# Update only the connectors linked to this credential
# (and which are Slack connectors).
op.execute(
f"""
UPDATE connector AS c
SET connector_specific_config = jsonb_set(
connector_specific_config,
'{{workspace}}',
to_jsonb('{workspace}'::text)
)
FROM connector_credential_pair AS ccp
WHERE ccp.connector_id = c.id
AND c.source = 'SLACK'
AND ccp.credential_id = {credential_id}
"""
)
except Exception:
print(
f"We were unable to get the workspace url for the Slack connector(s) using credential {credential_id}."
)
print("This connector will no longer work until you upgrade.")
continue
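
Reviewer sketch, not part of the commit: the downgrade derives the workspace name by splitting the `url` field of the `auth_test()` response on `//` and `.`. A slightly more defensive variant could parse the URL first. The helper name `workspace_from_auth_url` and the example URL are hypothetical.

```python
from urllib.parse import urlparse


def workspace_from_auth_url(url: str) -> str:
    """Return the leading host label, e.g. 'acme' for 'https://acme.slack.com/'."""
    host = urlparse(url).hostname
    if not host:
        # Raise a clear error instead of an IndexError from chained split() calls.
        raise ValueError(f"No hostname found in Slack auth URL: {url!r}")
    return host.split(".")[0]


example = workspace_from_auth_url("https://acme.slack.com/")
print(example)
```

The extracted value could then be passed to the `UPDATE` as a bound parameter via `sqlalchemy.text` rather than interpolated into the SQL string. That is a suggestion, not something the commit does.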

@@ -47,3 +47,11 @@ OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", ""
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
# The PostHog client does not accept empty API keys or hosts; however, it fails silently
# when capture is called. These defaults prevent PostHog issues from breaking the Onyx app.
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
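
Reviewer sketch, not part of the commit: the PostHog settings use `os.environ.get(...) or default` rather than `os.environ.get(..., default)`. The difference matters when the variable is set but empty. The variable name `DEMO_POSTHOG_API_KEY` below is hypothetical and used only for the demonstration.

```python
import os

# Simulate a deployment that defines the variable but leaves it empty.
os.environ["DEMO_POSTHOG_API_KEY"] = ""

# get() with a default only falls back when the variable is missing entirely.
with_get_default = os.environ.get("DEMO_POSTHOG_API_KEY", "FooBar")

# `or` also falls back when the variable is present but empty.
with_or_fallback = os.environ.get("DEMO_POSTHOG_API_KEY") or "FooBar"

print(repr(with_get_default), repr(with_or_fallback))
```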

@@ -122,7 +122,7 @@ def _cleanup_document_set__user_group_relationships__no_commit(
)
def validate_user_creation_permissions(
def validate_object_creation_for_user(
db_session: Session,
user: User | None,
target_group_ids: list[int] | None = None,
@@ -440,32 +440,108 @@ def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
_validate_curator_status__no_commit(db_session, [user])
def update_user_curator_relationship(
def _validate_curator_relationship_update_requester(
db_session: Session,
user_group_id: int,
set_curator_request: SetCuratorRequest,
user_making_change: User | None = None,
) -> None:
user = fetch_user_by_id(db_session, set_curator_request.user_id)
if not user:
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
"""
This function validates that the user making the change has the necessary permissions
to update the curator relationship for the target user in the given user group.
"""
if user.role == UserRole.ADMIN:
if user_making_change is None or user_making_change.role == UserRole.ADMIN:
return
# check if the user making the change is a curator in the group they are changing the curator relationship for
user_making_change_curator_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user_making_change.id,
# restrict the lookup to curator groups only when the requester has the CURATOR role;
# otherwise, they are a global_curator and can update the curator relationship
# for any group they are a member of
only_curator_groups=user_making_change.role == UserRole.CURATOR,
)
requestor_curator_group_ids = [
group.id for group in user_making_change_curator_groups
]
if user_group_id not in requestor_curator_group_ids:
raise ValueError(
f"User '{user.email}' is an admin and therefore has all permissions "
f"user making change {user_making_change.email} is not a curator,"
f" admin, or global_curator for group '{user_group_id}'"
)
def _validate_curator_relationship_update_request(
db_session: Session,
user_group_id: int,
target_user: User,
) -> None:
"""
This function validates that the curator_relationship_update request itself is valid.
"""
if target_user.role == UserRole.ADMIN:
raise ValueError(
f"User '{target_user.email}' is an admin and therefore has all permissions "
"of a curator. If you'd like this user to only have curator permissions, "
"you must update their role to BASIC then assign them to be CURATOR in the "
"appropriate groups."
)
elif target_user.role == UserRole.GLOBAL_CURATOR:
raise ValueError(
f"User '{target_user.email}' is a global_curator and therefore has all "
"permissions of a curator for all groups. If you'd like this user to only "
"have curator permissions for a specific group, you must update their role "
"to BASIC then assign them to be CURATOR in the appropriate groups."
)
elif target_user.role not in [UserRole.CURATOR, UserRole.BASIC]:
raise ValueError(
f"This endpoint can only be used to update the curator relationship for "
"users with the CURATOR or BASIC role. \n"
f"Target user: {target_user.email} \n"
f"Target user role: {target_user.role} \n"
)
# check if the target user is in the group they are changing the curator relationship for
requested_user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=set_curator_request.user_id,
user_id=target_user.id,
only_curator_groups=False,
)
group_ids = [group.id for group in requested_user_groups]
if user_group_id not in group_ids:
raise ValueError(f"user is not in group '{user_group_id}'")
raise ValueError(
f"target user {target_user.email} is not in group '{user_group_id}'"
)
def update_user_curator_relationship(
db_session: Session,
user_group_id: int,
set_curator_request: SetCuratorRequest,
user_making_change: User | None = None,
) -> None:
target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
if not target_user:
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
_validate_curator_relationship_update_request(
db_session=db_session,
user_group_id=user_group_id,
target_user=target_user,
)
_validate_curator_relationship_update_requester(
db_session=db_session,
user_group_id=user_group_id,
user_making_change=user_making_change,
)
logger.info(
f"user_making_change={user_making_change.email if user_making_change else 'None'} is "
f"updating the curator relationship for user={target_user.email} "
f"in group={user_group_id} to is_curator={set_curator_request.is_curator}"
)
relationship_to_update = (
db_session.query(User__UserGroup)
@@ -486,7 +562,7 @@ def update_user_curator_relationship(
)
db_session.add(relationship_to_update)
_validate_curator_status__no_commit(db_session, [user])
_validate_curator_status__no_commit(db_session, [target_user])
db_session.commit()
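
Reviewer sketch, not part of the commit: the requester check in `_validate_curator_relationship_update_requester` can be read as a small decision function. The version below replaces the database lookups with in-memory sets. The `Role` enum, the `Requester` dataclass, and `can_update_curator` are hypothetical names, and the sketch assumes the caller has already been limited to admins, curators, and global curators.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional


class Role(Enum):
    ADMIN = "admin"
    GLOBAL_CURATOR = "global_curator"
    CURATOR = "curator"
    BASIC = "basic"


@dataclass
class Requester:
    email: str
    role: Role
    member_group_ids: set = field(default_factory=set)
    curator_group_ids: set = field(default_factory=set)


def can_update_curator(requester: Optional[Requester], group_id: int) -> bool:
    # No requester (e.g. an internal call) or an admin: always allowed.
    if requester is None or requester.role == Role.ADMIN:
        return True
    # A curator must curate the specific group being changed.
    if requester.role == Role.CURATOR:
        return group_id in requester.curator_group_ids
    # Otherwise (global curator in the diff's comment): membership is enough.
    return group_id in requester.member_group_ids


curator = Requester("c@example.com", Role.CURATOR, {1, 2}, {2})
print(can_update_curator(curator, 2), can_update_curator(curator, 1))
```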

@@ -40,6 +40,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.main import get_application as get_application_base
from onyx.main import include_auth_router_with_prefix
from onyx.main import include_router_with_global_prefix_prepended
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
@@ -62,7 +63,7 @@ def get_application() -> FastAPI:
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
oauth_client,
@@ -74,19 +75,17 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)
# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
if AUTH_TYPE == AuthType.OIDC:
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
@@ -97,19 +96,21 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
),
prefix="/auth/oidc",
tags=["auth"],
)
# need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
elif AUTH_TYPE == AuthType.SAML:
include_router_with_global_prefix_prepended(application, saml_router)
include_auth_router_with_prefix(
application,
saml_router,
prefix="/auth/saml",
)
# RBAC / group access control
include_router_with_global_prefix_prepended(application, user_group_router)

@@ -1,5 +1,7 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
@@ -10,11 +12,29 @@ from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.auth.users import current_user
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
@@ -62,14 +82,7 @@ class SlackOAuth:
@classmethod
def generate_oauth_url(cls, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={cls.REDIRECT_URI}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
@@ -77,10 +90,14 @@ class SlackOAuth:
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={cls.DEV_REDIRECT_URI}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
@@ -102,82 +119,151 @@ class SlackOAuth:
return session
# Work in progress
# class ConfluenceCloudOAuth:
# """work in progress"""
class ConfluenceCloudOAuth:
"""work in progress"""
# # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
# class OAuthSession(BaseModel):
# """Stored in redis to be looked up on callback"""
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
# email: str
# redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
# CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
# CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
# TOKEN_URL = "https://auth.atlassian.com/oauth/token"
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
# CONFLUENCE_OAUTH_SCOPE = (
# "read:confluence-props%20"
# "read:confluence-content.all%20"
# "read:confluence-content.summary%20"
# "read:confluence-content.permission%20"
# "read:confluence-user%20"
# "read:confluence-groups%20"
# "readonly:content.attachment:confluence"
# )
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence"
)
# REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
# DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# # eventually for Confluence Data Center
# # oauth_url = (
# # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# # f"&redirect_uri={redirectme_uri}"
# # )
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
# @classmethod
# def generate_oauth_url(cls, state: str) -> str:
# return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
# @classmethod
# def generate_dev_oauth_url(cls, state: str) -> str:
# """dev mode workaround for localhost testing
# - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
# """
# return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
# @classmethod
# def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# url = (
# "https://auth.atlassian.com/authorize"
# f"?audience=api.atlassian.com"
# f"&client_id={cls.CLIENT_ID}"
# f"&redirect_uri={redirect_uri}"
# f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
# f"&state={state}"
# "&response_type=code"
# "&prompt=consent"
# )
# return url
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
# @classmethod
# def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
# """Temporary state to store in redis. to be looked up on auth response.
# Returns a json string.
# """
# session = ConfluenceCloudOAuth.OAuthSession(
# email=email, redirect_on_success=redirect_on_success
# )
# return session.model_dump_json()
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in Redis, to be looked up on the auth response.
Returns a JSON string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
# @classmethod
# def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
# session = SlackOAuth.OAuthSession.model_validate_json(session_json)
# return session
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
return session
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.onyx.app/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in Redis, to be looked up on the auth response.
Returns a JSON string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/prepare-authorization-request")
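
Reviewer sketch, not part of the commit: the `_generate_oauth_url_helper` methods build the query string by hand and rely on pre-encoded `%20` separators in the scope constant. An alternative is to let `urllib.parse.urlencode` do the escaping, which also protects a `redirect_uri` or `state` containing reserved characters. The function name `build_oauth_url` and the example client ID and redirect URI are hypothetical; the endpoint and the two scopes are taken from the diff.

```python
from urllib.parse import parse_qs, quote, urlencode, urlsplit

AUTH_ENDPOINT = "https://accounts.google.com/o/oauth2/v2/auth"
SCOPES = [
    "https://www.googleapis.com/auth/drive.readonly",
    "https://www.googleapis.com/auth/drive.metadata.readonly",
]


def build_oauth_url(client_id: str, redirect_uri: str, state: str) -> str:
    params = {
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "response_type": "code",
        "scope": " ".join(SCOPES),  # plain spaces; encoded below
        "access_type": "offline",
        "state": state,
        "prompt": "consent",
    }
    # quote_via=quote encodes spaces as %20 instead of '+'.
    return f"{AUTH_ENDPOINT}?{urlencode(params, quote_via=quote)}"


url = build_oauth_url("example-client-id", "https://example.com/cb?x=1", "abc123")
print(url)
```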
@@ -192,8 +278,11 @@ def prepare_authorization_request(
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
@@ -203,6 +292,11 @@ def prepare_authorization_request(
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
# elif connector == DocumentSource.CONFLUENCE:
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
# session = ConfluenceCloudOAuth.session_dump_json(
@@ -210,8 +304,6 @@ def prepare_authorization_request(
# )
# elif connector == DocumentSource.JIRA:
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
# elif connector == DocumentSource.GOOGLE_DRIVE:
# oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
else:
oauth_url = None
@@ -223,6 +315,7 @@ def prepare_authorization_request(
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
@@ -421,3 +514,116 @@ def handle_slack_oauth_callback(
# "redirect_on_success": session.redirect_on_success,
# }
# )
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
# Add padding back (Base64 decoding requires padding)
padded_state = state + "=" * (-len(state) % 4)
# Decode the Base64 string back to bytes
uuid_bytes = base64.urlsafe_b64decode(padded_state)
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = GoogleDriveOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"redirect_on_success": session.redirect_on_success,
}
)
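
Reviewer sketch, not part of the commit: the OAuth `state` value is a UUID encoded as URL-safe Base64 with the padding stripped, and the callback restores the padding before decoding. The round trip can be checked in isolation. The helper names `encode_state` and `decode_state` are hypothetical; the expressions inside them are the ones used in the diff.

```python
import base64
import uuid


def encode_state(oauth_uuid: uuid.UUID) -> str:
    # 16 bytes encode to 24 Base64 characters, 2 of them padding; strip the padding.
    return base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")


def decode_state(state: str) -> uuid.UUID:
    # Restore the padding so the length is a multiple of 4, then decode.
    padded = state + "=" * (-len(state) % 4)
    return uuid.UUID(bytes=base64.urlsafe_b64decode(padded))


original = uuid.uuid4()
state = encode_state(original)
print(state, decode_state(state) == original)
```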

@@ -3,6 +3,7 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
@@ -12,15 +13,23 @@ from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_jwt_strategy
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.notification import create_notification
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
@@ -114,3 +123,48 @@ async def impersonate_user(
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)


@@ -39,3 +39,8 @@ class TenantCreationPayload(BaseModel):
tenant_id: str
email: str
referral_source: str | None = None
class TenantDeletionPayload(BaseModel):
tenant_id: str
email: str


@@ -3,15 +3,19 @@ import logging
import uuid
import aiohttp # Async HTTP client
import httpx
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import drop_schema
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
@@ -20,6 +24,7 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.llm import update_default_provider
@@ -35,22 +40,27 @@ from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider
logger = logging.getLogger(__name__)
async def get_or_create_tenant_id(
email: str, referral_source: str | None = None
async def get_or_provision_tenant(
email: str, referral_source: str | None = None, request: Request | None = None
) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
if referral_source and request:
await submit_to_hubspot(email, referral_source, request)
try:
tenant_id = get_tenant_id_for_email(email)
except exceptions.UserNotExists:
@@ -122,6 +132,17 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
@@ -165,6 +186,7 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
@@ -267,3 +289,59 @@ def configure_default_api_keys(db_session: Session) -> None:
logger.info(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
)
async def submit_to_hubspot(
email: str, referral_source: str | None, request: Request
) -> None:
if not HUBSPOT_TRACKING_URL:
logger.info("HUBSPOT_TRACKING_URL not set, skipping HubSpot submission")
return
# HubSpot tracking cookie
hubspot_cookie = request.cookies.get("hubspotutk")
# IP address
ip_address = request.client.host if request.client else None
data = {
"fields": [
{"name": "email", "value": email},
{"name": "referral_source", "value": referral_source or ""},
],
"context": {
"hutk": hubspot_cookie,
"ipAddress": ip_address,
"pageUri": str(request.url),
"pageName": "User Registration",
},
}
async with httpx.AsyncClient() as client:
response = await client.post(HUBSPOT_TRACKING_URL, json=data)
if response.status_code != 200:
logger.error(f"Failed to submit to HubSpot: {response.text}")
async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)
async with aiohttp.ClientSession() as session:
async with session.delete(
f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
headers=headers,
json=payload.model_dump(),
) as response:
logger.debug(f"Control plane delete response status: {response.status}")
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant deletion failed: {error_text}")
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)


@@ -68,3 +68,11 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()
def remove_all_users_from_tenant(tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()


@@ -83,7 +83,7 @@ def patch_user_group(
def set_user_curator(
user_group_id: int,
set_curator_request: SetCuratorRequest,
_: User | None = Depends(current_admin_user),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
@@ -91,6 +91,7 @@ def set_user_curator(
db_session=db_session,
user_group_id=user_group_id,
set_curator_request=set_curator_request,
user_making_change=user,
)
except ValueError as e:
logger.error(f"Error setting user curator: {e}")


@@ -0,0 +1,34 @@
from typing import Any
from posthog import Posthog
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_HOST
from onyx.utils.logger import setup_logger
logger = setup_logger()
def posthog_on_error(error: Any, items: Any) -> None:
"""Log any PostHog delivery errors."""
logger.error(f"PostHog error: {error}, items: {items}")
posthog = Posthog(
project_api_key=POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=True,
on_error=posthog_on_error,
)
def event_telemetry(
distinct_id: str, event: str, properties: dict | None = None
) -> None:
"""Capture and send an event to PostHog, flushing immediately."""
logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}")
try:
posthog.capture(distinct_id, event, properties)
posthog.flush()
except Exception as e:
logger.error(f"Error capturing PostHog event: {e}")


@@ -1,5 +1,6 @@
import asyncio
import json
import time
from types import TracebackType
from typing import cast
from typing import Optional
@@ -320,8 +321,6 @@ async def embed_text(
api_url: str | None,
api_version: str | None,
) -> list[Embedding]:
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
if not all(texts):
logger.error("Empty strings provided for embedding")
raise ValueError("Empty strings are not allowed for embedding.")
@@ -330,8 +329,17 @@ async def embed_text(
logger.error("No texts provided for embedding")
raise ValueError("No texts provided for embedding.")
start = time.monotonic()
total_chars = 0
for text in texts:
total_chars += len(text)
if provider_type is not None:
logger.debug(f"Using cloud provider {provider_type} for embedding")
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
)
if api_key is None:
logger.error("API key not provided for cloud model")
raise RuntimeError("API key not provided for cloud model")
@@ -363,8 +371,16 @@ async def embed_text(
logger.error(error_message)
raise ValueError(error_message)
elapsed = time.monotonic() - start
logger.info(
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with provider {provider_type} in {elapsed:.2f}s"
)
elif model_name is not None:
logger.debug(f"Using local model {model_name} for embedding")
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
)
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
local_model = get_embedding_model(
@@ -382,13 +398,17 @@ async def embed_text(
for embedding in embeddings_vectors
]
elapsed = time.monotonic() - start
logger.info(
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with local model {model_name} in {elapsed:.2f}s"
)
else:
logger.error("Neither model name nor provider specified for embedding")
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
logger.info(f"Successfully embedded {len(texts)} texts")
return embeddings
@@ -440,7 +460,8 @@ async def process_embed_request(
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
elif not all(embed_request.texts):
if not all(embed_request.texts):
raise ValueError("Empty strings are not allowed for embedding.")
try:
@@ -471,9 +492,12 @@ async def process_embed_request(
detail=str(e),
)
except Exception as e:
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)
raise HTTPException(status_code=500, detail=exception_detail)
logger.exception(
f"Error during embedding process: provider={embed_request.provider_type} model={embed_request.model_name}"
)
raise HTTPException(
status_code=500, detail=f"Error during embedding process: {e}"
)
@router.post("/cross-encoder-scores")
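The timing pattern added to `embed_text` in the hunks above (a `time.monotonic()` start, a running character count, and one summary log line) can be tried in isolation. A minimal sketch, assuming a stand-in `fake_embed` in place of a real provider or local model; `fake_embed` and `timed_embed` are hypothetical names, not part of the model server.

```python
import logging
import time

logger = logging.getLogger("embed_sketch")


def fake_embed(texts: list[str]) -> list[list[float]]:
    # Stand-in for a real embedding call: one 1-dimensional vector per text
    return [[float(len(text))] for text in texts]


def timed_embed(texts: list[str]) -> tuple[list[list[float]], str]:
    """Embed texts and return the vectors plus the summary log line."""
    if not texts:
        raise ValueError("No texts provided for embedding.")
    if not all(texts):
        raise ValueError("Empty strings are not allowed for embedding.")

    # A monotonic clock is unaffected by wall-clock adjustments
    start = time.monotonic()
    total_chars = sum(len(text) for text in texts)
    embeddings = fake_embed(texts)
    elapsed = time.monotonic() - start

    summary = (
        f"Successfully embedded {len(texts)} texts with {total_chars} "
        f"total characters in {elapsed:.2f}s"
    )
    logger.info(summary)
    return embeddings, summary
```

Returning the summary string alongside the vectors is only a convenience for testing the sketch; the source logs the line and returns the embeddings alone.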


@@ -27,8 +27,8 @@ from shared_configs.configs import SENTRY_DSN
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
HF_CACHE_PATH = Path("/root/.cache/huggingface/")
TEMP_HF_CACHE_PATH = Path("/root/.cache/temp_huggingface/")
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"
transformer_logging.set_verbosity_error()
@@ -44,6 +44,7 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
the files in the existing huggingface cache that don't exist in the temp
huggingface cache.
"""
for item in source.iterdir():
target_path = dest / item.relative_to(source)
if item.is_dir():


@@ -0,0 +1,80 @@
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from textwrap import dedent
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.models import User
def send_email(
user_email: str,
subject: str,
body: str,
mail_from: str = EMAIL_FROM,
) -> None:
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
msg = MIMEMultipart()
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
msg.attach(MIMEText(body))
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
s.starttls()
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)
except Exception as e:
raise e
def send_user_email_invite(user_email: str, current_user: User) -> None:
subject = "Invitation to Join Onyx Workspace"
body = dedent(
f"""\
Hello,
You have been invited to join a workspace on Onyx.
To join the workspace, please visit the following link:
{WEB_DOMAIN}/auth/login
Best regards,
The Onyx Team
"""
)
send_email(user_email, subject, body, current_user.email)
def send_forgot_password_email(
user_email: str,
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
body = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, body, mail_from)
def send_user_verification_email(
user_email: str,
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
subject = "Onyx Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
body = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, body, mail_from)


@@ -4,6 +4,8 @@ from typing import cast
from onyx.auth.schemas import UserRole
from onyx.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
from onyx.configs.constants import NO_AUTH_USER_EMAIL
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.key_value_store.store import KeyValueStore
from onyx.key_value_store.store import KvKeyNotFoundError
from onyx.server.manage.models import UserInfo
@@ -28,13 +30,16 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
def fetch_no_auth_user(
store: KeyValueStore, *, anonymous_user_enabled: bool | None = None
) -> UserInfo:
return UserInfo(
id="__no_auth_user__",
email="anonymous@onyx.app",
id=NO_AUTH_USER_ID,
email=NO_AUTH_USER_EMAIL,
is_active=True,
is_superuser=False,
is_verified=True,
role=UserRole.ADMIN,
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
preferences=load_no_auth_user_preferences(store),
is_anonymous_user=anonymous_user_enabled,
)


@@ -49,4 +49,7 @@ class UserCreate(schemas.BaseUserCreate):
class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
"""
Role updates are not allowed through the user update endpoint for security reasons.
Role changes should be handled through a separate, admin-only process.
"""


@@ -1,10 +1,8 @@
import smtplib
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
@@ -52,19 +50,17 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdate
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import EMAIL_FROM
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
@@ -72,6 +68,9 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.api_key import fetch_user_for_api_key
from onyx.db.auth import get_access_token_db
@@ -86,8 +85,9 @@ from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.db.users import get_user_by_email
from onyx.server.utils import BasicAuthenticationError
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -99,6 +99,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -139,6 +144,20 @@ def user_needs_to_be_verified() -> bool:
return False
def anonymous_user_enabled() -> bool:
if MULTI_TENANT:
return False
redis_client = get_redis_client(tenant_id=None)
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
if value is None:
return False
assert isinstance(value, bytes)
return int(value.decode("utf-8")) == 1
def verify_email_is_invited(email: str) -> None:
whitelist = get_invited_users()
if not whitelist:
@@ -189,30 +208,6 @@ def verify_email_domain(email: str) -> None:
)
def send_user_verification_email(
user_email: str,
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
msg = MIMEMultipart()
msg["Subject"] = "Onyx Email Verification"
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
body = MIMEText(f"Click the following link to verify your email address: {link}")
msg.attach(body)
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
s.starttls()
# If credentials fails with gmail, check (You need an app password, not just the basic email password)
# https://support.google.com/accounts/answer/185833?sjid=8512343437447396151-NA
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
@@ -225,17 +220,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
referral_source = None
if request is not None:
referral_source = request.cookies.get("referral_source", None)
# We verify the password here to make sure it's valid before we proceed
await self.validate_password(
user_create.password, cast(schemas.UC, user_create)
)
user_count: int | None = None
referral_source = (
request.cookies.get("referral_source", None)
if request is not None
else None
)
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_create_tenant_id",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=user_create.email,
referral_source=referral_source,
request=request,
)
async with get_async_session_with_tenant(tenant_id) as db_session:
@@ -268,7 +272,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
if not user.role.is_web_login() and user_create.role.is_web_login():
user_update = UserUpdate(
password=user_create.password,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
@@ -278,7 +281,37 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
return user
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
# Validate password according to basic security guidelines
if len(password) < 12:
raise exceptions.InvalidPasswordException(
reason="Password must be at least 12 characters long."
)
if len(password) > 64:
raise exceptions.InvalidPasswordException(
reason="Password must not exceed 64 characters."
)
if not any(char.isupper() for char in password):
raise exceptions.InvalidPasswordException(
reason="Password must contain at least one uppercase letter."
)
if not any(char.islower() for char in password):
raise exceptions.InvalidPasswordException(
reason="Password must contain at least one lowercase letter."
)
if not any(char.isdigit() for char in password):
raise exceptions.InvalidPasswordException(
reason="Password must contain at least one number."
)
if not any(char in PASSWORD_SPECIAL_CHARS for char in password):
raise exceptions.InvalidPasswordException(
reason="Password must contain at least one special character from the following set: "
f"{PASSWORD_SPECIAL_CHARS}."
)
return
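The rules in `validate_password` above can also be expressed as a table of checks. The sketch below is an illustration, not the method in the diff: it collects every violated rule instead of raising on the first one, `password_problems` is a hypothetical name, and `SPECIAL_CHARS` is a placeholder because the real `PASSWORD_SPECIAL_CHARS` value is not shown in this diff.

```python
# Assumption: placeholder set; the real PASSWORD_SPECIAL_CHARS constant
# is defined elsewhere and is not visible in this diff.
SPECIAL_CHARS = "!@#$%^&*()"


def password_problems(password: str, special_chars: str = SPECIAL_CHARS) -> list[str]:
    """Return the reason for every rule the password violates.

    An empty list means the password passes all checks.
    """
    checks = [
        (len(password) >= 12, "Password must be at least 12 characters long."),
        (len(password) <= 64, "Password must not exceed 64 characters."),
        (
            any(c.isupper() for c in password),
            "Password must contain at least one uppercase letter.",
        ),
        (
            any(c.islower() for c in password),
            "Password must contain at least one lowercase letter.",
        ),
        (
            any(c.isdigit() for c in password),
            "Password must contain at least one number.",
        ),
        (
            any(c in special_chars for c in password),
            "Password must contain at least one special character.",
        ),
    ]
    return [reason for ok, reason in checks if not ok]
```

Reporting all failures at once is a design choice of this sketch; the source raises `InvalidPasswordException` for the first failed rule.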
async def oauth_callback(
self,
@@ -293,17 +326,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> User:
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)
referral_source = (
getattr(request.state, "referral_source", None) if request else None
)
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_create_tenant_id",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=account_email,
referral_source=referral_source,
request=request,
)
if not tenant_id:
@@ -365,6 +399,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
@@ -418,6 +453,39 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=user.email,
request=request,
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
user_count = await get_user_count()
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if user_count == 1:
create_milestone_and_report(
user=user,
distinct_id=user.email,
event_type=MilestoneRecordType.USER_SIGNED_UP,
properties=None,
db_session=db_session,
)
else:
create_milestone_and_report(
user=user,
distinct_id=user.email,
event_type=MilestoneRecordType.MULTIPLE_USERS,
properties=None,
db_session=db_session,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
logger.notice(f"User {user.id} has registered.")
optional_telemetry(
record_type=RecordType.SIGN_UP,
@@ -428,7 +496,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
if not EMAIL_CONFIGURED:
logger.error(
"Email is not configured. Please configure email in the admin panel"
)
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
"Your admin has not enabled this feature.",
)
send_forgot_password_email(user.email, token)
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
@@ -449,7 +525,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_create_tenant_id",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=email,
@@ -510,7 +586,7 @@ class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_create_tenant_id",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=user.email,
@@ -546,9 +622,7 @@ def get_database_strategy(
auth_backend = AuthenticationBackend(
name="jwt" if MULTI_TENANT else "database",
transport=cookie_transport,
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy
) # type: ignore
@@ -635,30 +709,36 @@ async def double_check_user(
user: User | None,
optional: bool = DISABLE_AUTH,
include_expired: bool = False,
allow_anonymous_access: bool = False,
) -> User | None:
if optional:
return user
if user is not None:
# If the user attempted to authenticate, verify them; do not fall back
# to anonymous access if verification fails.
if user_needs_to_be_verified() and not user.is_verified:
raise BasicAuthenticationError(
detail="Access denied. User is not verified.",
)
if (
user.oidc_expiry
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
)
return user
if allow_anonymous_access:
return None
if user is None:
raise BasicAuthenticationError(
detail="Access denied. User is not authenticated.",
)
if user_needs_to_be_verified() and not user.is_verified:
raise BasicAuthenticationError(
detail="Access denied. User is not verified.",
)
if (
user.oidc_expiry
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
)
return user
raise BasicAuthenticationError(
detail="Access denied. User is not authenticated.",
)
async def current_user_with_expired_token(
@@ -673,6 +753,14 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accesssible_user(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(
user, allow_anonymous_access=anonymous_user_enabled()
)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:


@@ -3,11 +3,12 @@ import multiprocessing
import time
from typing import Any
import requests
import sentry_sdk
from celery import Task
from celery.app import trace
from celery.exceptions import WorkerShutdown
from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
@@ -21,6 +22,7 @@ from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
@@ -34,8 +36,11 @@ from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import PlainFormatter
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SENTRY_DSN
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -56,8 +61,8 @@ def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**kwds: Any,
) -> None:
pass
@@ -257,7 +262,8 @@ def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
logger.info("Vespa: Readiness probe starting.")
while True:
try:
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
client = get_vespa_http_client()
response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
@@ -346,26 +352,36 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
loglevel: int,
logfile: str | None,
format: str,
colorize: bool,
**kwargs: Any,
) -> None:
# TODO: could unhardcode format and colorize and accept these as options from
# celery's config
# reformats the root logger
root_logger = logging.getLogger()
root_logger.handlers = []
root_handler = logging.StreamHandler() # Set up a handler for the root logger
# Define the log format
log_format = (
"%(levelname)-8s %(asctime)s %(filename)15s:%(lineno)-4d: %(name)s %(message)s"
)
# Set up the root handler
root_handler = logging.StreamHandler()
root_formatter = ColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
log_format,
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_handler.setFormatter(root_formatter)
root_logger.addHandler(root_handler) # Apply the handler to the root logger
root_logger.addHandler(root_handler)
if logfile:
root_file_handler = logging.FileHandler(logfile)
root_file_formatter = PlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
log_format,
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_file_handler.setFormatter(root_file_formatter)
@@ -373,19 +389,23 @@ def on_setup_logging(
root_logger.setLevel(loglevel)
# reformats celery's task logger
# Configure the task logger
task_logger.handlers = []
task_handler = logging.StreamHandler()
task_handler.addFilter(TenantContextFilter())
task_formatter = CeleryTaskColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
log_format,
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_handler = logging.StreamHandler() # Set up a handler for the task logger
task_handler.setFormatter(task_formatter)
task_logger.addHandler(task_handler) # Apply the handler to the task logger
task_logger.addHandler(task_handler)
if logfile:
task_file_handler = logging.FileHandler(logfile)
task_file_handler.addFilter(TenantContextFilter())
task_file_formatter = CeleryTaskPlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
log_format,
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_file_handler.setFormatter(task_file_formatter)
@@ -398,6 +418,61 @@ def on_setup_logging(
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
strategy.logger.setLevel(logging.WARNING)
# hide celery task succeeded/failed spam
# uncomment this to hide celery task succeeded/failed spam
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
trace.logger.setLevel(logging.WARNING)
def set_task_finished_log_level(logLevel: int) -> None:
"""call this to override the setLevel in on_setup_logging. We are interested
in the task timings in the cloud but it can be spammy for self hosted."""
trace.logger.setLevel(logLevel)
class TenantContextFilter(logging.Filter):
"""Logging filter to inject tenant ID into the logger's name."""
def filter(self, record: logging.LogRecord) -> bool:
if not MULTI_TENANT:
record.name = ""
return True
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id:
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:5]
record.name = f"[t:{tenant_id}]"
else:
record.name = ""
return True
@task_prerun.connect
def set_tenant_id(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
) -> None:
"""Signal handler to set tenant ID in context var before task starts."""
tenant_id = (
kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if kwargs
else POSTGRES_DEFAULT_SCHEMA
)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@task_postrun.connect
def reset_tenant_id(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
) -> None:
"""Signal handler to reset tenant ID in context var after task ends."""
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
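
The tenant-tagging filter added in this file can be tried in isolation. The following is a minimal sketch, not the project's code: `CURRENT_TENANT`, `TENANT_ID_PREFIX`, `TenantTagFilter`, and `build_logger` are hypothetical stand-ins for the real context variable, prefix constant, and logging setup. It only shows how a `logging.Filter` attached to a handler can rewrite `record.name` from a context variable before the formatter runs.

```python
import contextvars
import io
import logging

# Hypothetical stand-ins for the project's prefix constant and context variable.
TENANT_ID_PREFIX = "tenant_"
CURRENT_TENANT = contextvars.ContextVar("current_tenant", default=None)


class TenantTagFilter(logging.Filter):
    """Overwrite record.name with a short tenant tag, or "" when no tenant is set."""

    def filter(self, record: logging.LogRecord) -> bool:
        tenant_id = CURRENT_TENANT.get()
        if tenant_id:
            # Drop the prefix and keep only the first five characters.
            short_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:5]
            record.name = f"[t:{short_id}]"
        else:
            record.name = ""
        return True  # never drop the record, only annotate it


def build_logger(stream: io.StringIO) -> logging.Logger:
    """Create an isolated logger that writes to the given in-memory stream."""
    logger = logging.getLogger("tenant_filter_demo")
    logger.handlers = []
    logger.propagate = False
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(stream)
    handler.addFilter(TenantTagFilter())
    handler.setFormatter(logging.Formatter("%(levelname)s %(name)s %(message)s"))
    logger.addHandler(handler)
    return logger


stream = io.StringIO()
demo_logger = build_logger(stream)

CURRENT_TENANT.set("tenant_abcdef123")
demo_logger.info("indexing started")

CURRENT_TENANT.set(None)
demo_logger.info("no tenant")

output_lines = stream.getvalue().splitlines()
```

Because the filter sits on the handler rather than the logger, every record passing through that handler gets the tag, regardless of which logger emitted it.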


@@ -13,7 +13,6 @@ from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
@@ -44,18 +43,18 @@ class DynamicTenantScheduler(PersistentScheduler):
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating tenant task update")
logger.info("Reload interval reached, initiating task update")
self._update_tenant_tasks()
self._last_reload = now
logger.info("Tenant task update completed, reset reload timer")
logger.info("Task update completed, reset reload timer")
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Starting tenant task update process")
logger.info("Starting task update process")
try:
logger.info("Fetching all tenant IDs")
logger.info("Fetching all IDs")
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} tenants")
logger.info(f"Found {len(tenant_ids)} IDs")
logger.info("Fetching tasks to schedule")
tasks_to_schedule = fetch_versioned_implementation(
@@ -70,7 +69,7 @@ class DynamicTenantScheduler(PersistentScheduler):
for task_name, _ in current_schedule:
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
logger.info(f"Found {len(existing_tenants)} existing items in schedule")
for tenant_id in tenant_ids:
if (
@@ -83,7 +82,7 @@ class DynamicTenantScheduler(PersistentScheduler):
continue
if tenant_id not in existing_tenants:
logger.info(f"Processing new tenant: {tenant_id}")
logger.info(f"Processing new item: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
@@ -129,11 +128,10 @@ class DynamicTenantScheduler(PersistentScheduler):
logger.info("Schedule update completed successfully")
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tasks: {str(e)}")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
@@ -155,10 +153,6 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)


@@ -61,13 +61,14 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=4, max_overflow=12)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Fewer startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)


@@ -60,15 +60,21 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
# rkuo: been seeing transient connection exceptions here, so upping the connection count
# from just concurrency/concurrency to concurrency/concurrency*2
SqlEngine.init_engine(
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Fewer startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)


@@ -60,13 +60,15 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Fewer startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)


@@ -1,3 +1,4 @@
import logging
import multiprocessing
from typing import Any
from typing import cast
@@ -84,14 +85,14 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Fewer startup checks in multi-tenant case
if MULTI_TENANT:
return
logger.info("Running as the primary celery worker.")
# This is singleton work that should be done on startup exactly once
@@ -194,6 +195,10 @@ def on_setup_logging(
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
# this can be spammy, so just enable it in the cloud for now
if MULTI_TENANT:
app_base.set_task_finished_log_level(logging.INFO)
class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker lock outside of the task queue.


@@ -1,12 +1,56 @@
# These are helper objects for tracking the keys we need to write in redis
import json
from typing import Any
from typing import cast
from celery import Celery
from redis import Redis
from onyx.background.celery.configs.base import CELERY_SEPARATOR
from onyx.configs.constants import OnyxCeleryPriority
def celery_get_unacked_length(r: Redis) -> int:
"""Checking the unacked queue is useful because a non-zero length tells us there
may be prefetched tasks.
There can be other tasks in here besides indexing tasks, so this is mostly useful
just to see if the task count is non zero.
ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html
"""
length = cast(int, r.hlen("unacked"))
return length
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
"""Gets the set of task id's matching the given queue in the unacked hash.
Unacked entries belonging to the indexing queue are "prefetched", so this gives
us crucial visibility as to what tasks are in that state.
"""
tasks: set[str] = set()
for _, v in r.hscan_iter("unacked"):
v_bytes = cast(bytes, v)
v_str = v_bytes.decode("utf-8")
task = json.loads(v_str)
task_description = task[0]
task_queue = task[2]
if task_queue != queue:
continue
task_id = task_description.get("headers", {}).get("id")
if not task_id:
continue
# if the queue matches and we see the task_id, add it
tasks.add(task_id)
return tasks
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists
@@ -23,3 +67,96 @@ def celery_get_queue_length(queue: str, r: Redis) -> int:
total_length += cast(int, length)
return total_length
def celery_find_task(task_id: str, queue: str, r: Redis) -> bool:
"""This is a redis specific way to find a task for a particular queue in redis.
It is priority aware and knows how to look through the multiple redis lists
used to implement task prioritization.
This operation is not atomic.
This is a linear search O(n) ... so be careful using it when the task queues can be large.
Returns True if the id is in the queue, False if not.
"""
for priority in range(len(OnyxCeleryPriority)):
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
for task in tasks:
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
if task_dict.get("headers", {}).get("id") == task_id:
return True
return False
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
"""Returns a list of current workers containing name_filter, or all workers if
name_filter is None.
We've empirically discovered that the celery inspect API is potentially unstable
and may hang or return empty results when celery is under load. Suggest using this
more to debug and troubleshoot than in production code.
"""
worker_names: list[str] = []
# filter for and create an indexing specific inspect object
inspect = app.control.inspect()
workers: dict[str, Any] = inspect.ping() # type: ignore
if workers:
for worker_name in list(workers.keys()):
# if the name filter not set, return all worker names
if not name_filter:
worker_names.append(worker_name)
continue
# if the name filter is set, return only worker names that contain the name filter
if name_filter not in worker_name:
continue
worker_names.append(worker_name)
return worker_names
def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]:
"""Returns the set of reserved task ids on the specified workers.
We've empirically discovered that the celery inspect API is potentially unstable
and may hang or return empty results when celery is under load. Suggest using this
more to debug and troubleshoot than in production code.
"""
reserved_task_ids: set[str] = set()
inspect = app.control.inspect(destination=worker_names)
# get the list of reserved tasks
reserved_tasks: dict[str, list] | None = inspect.reserved() # type: ignore
if reserved_tasks:
for _, task_list in reserved_tasks.items():
for task in task_list:
reserved_task_ids.add(task["id"])
return reserved_task_ids
def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
"""Returns the set of active task ids on the specified workers.
We've empirically discovered that the celery inspect API is potentially unstable
and may hang or return empty results when celery is under load. Suggest using this
more to debug and troubleshoot than in production code.
"""
active_task_ids: set[str] = set()
inspect = app.control.inspect(destination=worker_names)
# get the list of active tasks
active_tasks: dict[str, list] | None = inspect.active() # type: ignore
if active_tasks:
for _, task_list in active_tasks.items():
for task in task_list:
active_task_ids.add(task["id"])
return active_task_ids
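
The priority-aware lookup in `celery_find_task` can be exercised without a live Redis. This is a sketch under stated assumptions: `FakeRedis` is a hypothetical in-memory stand-in that implements only `lrange`, and `SEPARATOR` and `PRIORITY_LEVELS` are illustrative values standing in for `CELERY_SEPARATOR` and `len(OnyxCeleryPriority)`, whose real values are not shown in this diff.

```python
import json

# Assumed values for illustration only; the real constants live in the project.
SEPARATOR = ":"
PRIORITY_LEVELS = 5


class FakeRedis:
    """Tiny in-memory stand-in exposing only the lrange call used below."""

    def __init__(self, lists: dict) -> None:
        self._lists = lists

    def lrange(self, name: str, start: int, end: int) -> list:
        items = self._lists.get(name, [])
        # Redis treats end as inclusive; -1 means "through the last element".
        return items[start:] if end == -1 else items[start : end + 1]


def priority_queue_names(queue: str) -> list:
    """Priority 0 uses the bare queue name; higher priorities get a suffix."""
    return [
        f"{queue}{SEPARATOR}{p}" if p > 0 else queue for p in range(PRIORITY_LEVELS)
    ]


def find_task(task_id: str, queue: str, r: FakeRedis) -> bool:
    """Linear scan across every priority list for a message with this task id."""
    for name in priority_queue_names(queue):
        for raw in r.lrange(name, 0, -1):
            message = json.loads(raw.decode("utf-8"))
            if message.get("headers", {}).get("id") == task_id:
                return True
    return False


def make_message(task_id: str) -> bytes:
    """Build a message carrying only the header field the search inspects."""
    return json.dumps({"headers": {"id": task_id}}).encode("utf-8")


fake = FakeRedis(
    {
        "indexing": [make_message("a")],
        "indexing:3": [make_message("b")],
    }
)
found_a = find_task("a", "indexing", fake)
found_b = find_task("b", "indexing", fake)
found_c = find_task("c", "indexing", fake)
```

The scan reads each list in full, so the cost grows with the total number of queued messages across all priority levels, which is why the docstring above warns against using it on large queues.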


@@ -16,6 +16,14 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Indexing worker specific ... this lets us track the transition to STARTED in redis
# We don't currently rely on this but it has the potential to be useful and
# indexing tasks are not high volume
# we don't turn this on yet because celery occasionally runs tasks more than once
# which means a duplicate run might change the task state unexpectedly
# task_track_started = True
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1


@@ -4,55 +4,86 @@ from typing import Any
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
# choosing 15 minutes because it roughly gives us enough time to process many tasks
# we might be able to reduce this greatly if we can run a unified
# loop across all tenants rather than tasks per tenant
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# we set expires because it isn't necessary to queue up these tasks
# it's only important that they run relatively regularly
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-prune",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "kombu-message-cleanup",
"task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
"schedule": timedelta(seconds=3600),
"options": {"priority": OnyxCeleryPriority.LOWEST},
"options": {
"priority": OnyxCeleryPriority.LOWEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
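
The change above repeats the same `expires` entry in every task's options. As an illustration only, the same result could be produced by a small helper; `with_default_expires` and the task entries below are hypothetical and are not part of this diff, which writes the options inline.

```python
from datetime import timedelta

BEAT_EXPIRES_DEFAULT = 15 * 60  # seconds, matching the constant in the diff


def with_default_expires(tasks: list, expires: int) -> list:
    """Return copies of the task entries with "expires" filled in when absent."""
    result = []
    for task in tasks:
        # An explicit per-task "expires" wins over the default.
        options = {"expires": expires, **task.get("options", {})}
        result.append({**task, "options": options})
    return result


# Hypothetical task names and priorities, used only for illustration.
raw_tasks = [
    {
        "name": "check-a",
        "task": "check_a",
        "schedule": timedelta(seconds=20),
        "options": {"priority": 1},
    },
    {
        "name": "check-b",
        "task": "check_b",
        "schedule": timedelta(seconds=5),
        "options": {"priority": 1, "expires": 30},
    },
]
scheduled = with_default_expires(raw_tasks, BEAT_EXPIRES_DEFAULT)
```

With an expiry set, a message that sits in the queue longer than the limit is discarded by the worker instead of being executed late, which fits periodic checks that only need to run "relatively regularly".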


@@ -34,7 +34,9 @@ class TaskDependencyError(RuntimeError):
trail=False,
bind=True,
)
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
@@ -45,7 +47,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
return None
# collect cc_pair_ids
cc_pair_ids: list[int] = []
@@ -76,11 +78,13 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
task_logger.exception("Unexpected exception during connector deletion check")
finally:
if lock_beat.owned():
lock_beat.release()
return True
def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
@@ -131,14 +135,14 @@ def try_generate_document_cc_pair_cleanup_tasks(
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
raise TaskDependencyError(
f"Connector deletion - Delayed (indexing in progress): "
"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings.id}"
)
if redis_connector.prune.fenced:
raise TaskDependencyError(
f"Connector deletion - Delayed (pruning in progress): "
"Connector deletion - Delayed (pruning in progress): "
f"cc_pair={cc_pair_id}"
)
@@ -175,7 +179,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
# return 0
task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)


@@ -1,6 +1,8 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from time import sleep
from uuid import uuid4
from celery import Celery
@@ -18,6 +20,7 @@ from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
@@ -88,10 +91,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -99,7 +102,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
return None
# get all cc pairs that need to be synced
cc_pair_ids_to_sync: list[int] = []
@@ -128,6 +131,8 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None
if lock_beat.owned():
lock_beat.release()
return True
def try_creating_permissions_sync_task(
app: Celery,
@@ -219,6 +224,43 @@ def connector_permission_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.permissions.fence_key}"
)
if not redis_connector.permissions.fenced: # The fence must exist
raise ValueError(
f"connector_permission_sync_generator_task - fence not found: "
f"fence={redis_connector.permissions.fence_key}"
)
payload = redis_connector.permissions.payload # The payload must exist
if not payload:
raise ValueError(
"connector_permission_sync_generator_task: payload invalid or not found"
)
if payload.celery_task_id is None:
logger.info(
f"connector_permission_sync_generator_task - Waiting for fence: "
f"fence={redis_connector.permissions.fence_key}"
)
sleep(1)
continue
logger.info(
f"connector_permission_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.permissions.fence_key}"
)
break
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
@@ -254,8 +296,11 @@ def connector_permission_sync_generator_task(
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)
new_payload = RedisConnectorPermissionSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=payload.celery_task_id,
)
redis_connector.permissions.set_fence(new_payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
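
The wait loop added to `connector_permission_sync_generator_task` follows a common poll-with-deadline pattern. A minimal sketch of that pattern, assuming a hypothetical `read_task_id` callable in place of reading the fence payload from Redis, and raising `TimeoutError` where the diff raises `ValueError`:

```python
import time
from typing import Callable, Optional


def wait_for_task_id(
    read_task_id: Callable[[], Optional[str]],
    timeout: float,
    poll_interval: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll until read_task_id returns a value, or raise after timeout seconds."""
    start = time.monotonic()
    while True:
        # Check the deadline first so a stuck fence cannot block forever.
        if time.monotonic() - start > timeout:
            raise TimeoutError("timed out waiting for the fence to be finalized")
        task_id = read_task_id()
        if task_id is not None:
            return task_id
        sleep(poll_interval)


# Simulated fence payload: the task id shows up on the third read.
reads = iter([None, None, "task-123"])
sleeps = []
result = wait_for_task_id(lambda: next(reads), timeout=5.0, sleep=sleeps.append)
```

Using `time.monotonic()` for the deadline keeps the timeout immune to wall-clock adjustments, and injecting `sleep` makes the loop testable without real delays.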


@@ -94,10 +94,10 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -105,7 +105,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
return None
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
@@ -149,6 +149,8 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
if lock_beat.owned():
lock_beat.release()
return True
def try_creating_external_group_sync_task(
app: Celery,
@@ -162,7 +164,7 @@ def try_creating_external_group_sync_task(
LOCK_TIMEOUT = 30
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
timeout=LOCK_TIMEOUT,
)


@@ -1,7 +1,11 @@
import os
import sys
import time
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import cast
import redis
import sentry_sdk
@@ -15,10 +19,13 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
@@ -26,6 +33,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -66,14 +74,18 @@ logger = setup_logger()
class IndexingCallback(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
def __init__(
self,
parent_pid: int,
stop_key: str,
generator_progress_key: str,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.parent_pid = parent_pid
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
@@ -84,25 +96,68 @@ class IndexingCallback(IndexingHeartbeatInterface):
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_parent_check = time.monotonic()
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
return True
return False
def progress(self, tag: str, amount: int) -> None:
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
# so leave this code in until we're ready to test it.
# if self.parent_pid:
# # check if the parent pid is alive so we aren't running as a zombie
# now = time.monotonic()
# if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
# try:
# # this is unintuitive, but it checks if the parent pid is still running
# os.kill(self.parent_pid, 0)
# except Exception:
# logger.exception("IndexingCallback - parent pid check exceptioned")
# raise
# self.last_parent_check = now
try:
self.redis_lock.reacquire()
self.last_tag = tag
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"IndexingCallback - lock.reacquire exceptioned. "
f"IndexingCallback - lock.reacquire exceptioned: "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
# diagnostic logging for lock errors
name = self.redis_lock.name
ttl = self.redis_client.ttl(name)
locked = self.redis_lock.locked()
owned = self.redis_lock.owned()
local_token: str | None = self.redis_lock.local.token # type: ignore
remote_token_raw = self.redis_client.get(self.redis_lock.name)
if remote_token_raw:
remote_token_bytes = cast(bytes, remote_token_raw)
remote_token = remote_token_bytes.decode("utf-8")
else:
remote_token = None
logger.warning(
f"IndexingCallback - lock diagnostics: "
f"name={name} "
f"locked={locked} "
f"owned={owned} "
f"local_token={local_token} "
f"remote_token={remote_token} "
f"ttl={ttl}"
)
raise
self.redis_client.incrby(self.generator_progress_key, amount)
@@ -162,11 +217,19 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occasionally does some validation of existing state to clear up error conditions"""
time_start = time.monotonic()
tasks_created = 0
locked = False
r = get_redis_client(tenant_id=tenant_id)
redis_client = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -271,7 +334,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
search_settings_instance,
reindex,
db_session,
r,
redis_client,
tenant_id,
)
if attempt_id:
@@ -286,7 +349,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
db_session, redis_client
)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()
@@ -304,12 +369,27 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
# we want to run this less frequently than the overall task
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
# clear any indexing fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_indexing_fences(
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
)
except Exception:
task_logger.exception("Exception while validating indexing fences")
redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
task_logger.exception("Unexpected exception during indexing check")
finally:
if locked:
if lock_beat.owned():
@@ -320,9 +400,157 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"tenant={tenant_id}"
)
time_elapsed = time.monotonic() - time_start
task_logger.debug(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
return tasks_created
def validate_indexing_fences(
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_indexing_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
reserved_indexing_tasks,
r_celery,
db_session,
)
return
def validate_indexing_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.1. When the task is seen in the redis queue
1.2. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# check to see if the fence/payload exists
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
if redis_connector_index.active():
return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
if found:
# the celery task exists in the redis queue
redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions.
# Checking the active signal, whose TTL is long enough to bridge those gaps,
# safeguards us against these transition periods.
if redis_connector_index.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
f"index_attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"fence={fence_key}"
)
if payload.index_attempt_id:
try:
mark_attempt_failed(
payload.index_attempt_id,
db_session,
"validate_indexing_fence - Canceling index attempt due to missing celery tasks",
)
except Exception:
logger.exception(
"validate_indexing_fence - Exception while marking index attempt as failed."
)
redis_connector_index.reset()
return
def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
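
The decision order in validate_indexing_fence can be read as a small pure function. The following is a minimal sketch, assuming a hypothetical `FenceState` snapshot and `FenceAction` enum that do not exist in the codebase; it only mirrors the branches shown in the hunk above and leaves out the Redis and database calls.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class FenceAction(Enum):
    NOOP = "noop"    # nothing to do on this pass
    RENEW = "renew"  # task was seen, so renew the active signal
    RESET = "reset"  # no task and no active signal, so clear the fence


@dataclass
class FenceState:
    """Hypothetical snapshot of what the validator can observe."""

    fenced: bool
    has_payload: bool
    celery_task_id: Optional[str]
    in_queue: bool             # task id found in the redis queue
    in_reserved: bool          # task id found in the reserved / prefetched list
    active_signal_alive: bool  # the TTL-backed active signal has not expired


def decide(state: FenceState) -> FenceAction:
    # no fence or no payload: nothing to validate
    if not state.fenced or not state.has_payload:
        return FenceAction.NOOP

    # fence is barely set up: only reset once the active signal has expired
    if state.celery_task_id is None:
        return FenceAction.NOOP if state.active_signal_alive else FenceAction.RESET

    # direct evidence that the task exists
    if state.in_queue or state.in_reserved:
        return FenceAction.RENEW

    # no direct evidence: the active signal bridges the transition gaps
    if state.active_signal_alive:
        return FenceAction.NOOP

    return FenceAction.RESET
```

In this sketch a reset only happens when every source of evidence is missing at once, which is the property the TTL is meant to guarantee.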
@@ -469,6 +697,7 @@ def try_creating_indexing_task(
celery_task_id=None,
)
redis_connector_index.set_active()
redis_connector_index.set_fence(payload)
# create the index attempt for tracking purposes
@@ -502,13 +731,14 @@ def try_creating_indexing_task(
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
redis_connector_index.set_active()
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
@@ -540,7 +770,6 @@ def connector_indexing_proxy_task(
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
@@ -563,7 +792,6 @@ def connector_indexing_proxy_task(
if not job:
task_logger.info(
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
@@ -571,7 +799,6 @@ def connector_indexing_proxy_task(
task_logger.info(
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
@@ -582,11 +809,62 @@ def connector_indexing_proxy_task(
while True:
sleep(5)
# renew active signal
redis_connector_index.set_active()
# if the job is done, clean up and break
if job.done():
try:
if job.status == "error":
ignore_exitcode = False
exit_code: int | None = None
if job.process:
exit_code = job.process.exitcode
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
# even though logging clearly indicates successful completion
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if not ignore_exitcode:
raise RuntimeError("Spawned task exceptioned.")
task_logger.warning(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code}"
)
except Exception:
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code} "
f"error={job.exception()}"
)
raise
finally:
job.release()
break
# if a termination signal is detected, clean up and break
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing watchdog - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
@@ -609,79 +887,36 @@ def connector_indexing_proxy_task(
f"search_settings={search_settings_id}"
)
job.cancel()
job.cancel()
break
if not job.done():
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
continue
if job.status == "error":
ignore_exitcode = False
exit_code: int | None = None
if job.process:
exit_code = job.process.exitcode
# seeing non-deterministic behavior where spawned tasks occasionally return exit code 1
# even though logging clearly indicates that they completed successfully
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if ignore_exitcode:
task_logger.warning(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code}"
)
else:
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code} "
f"error={job.exception()}"
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
job.release()
break
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
continue
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
@@ -707,7 +942,7 @@ def connector_indexing_task_wrapper(
tenant_id,
is_ee,
)
except:
except Exception:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
@@ -715,13 +950,20 @@ def connector_indexing_task_wrapper(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise
# There is a cloud related bug outside of our code
# where spawned tasks return with an exit code of 1.
# Unfortunately, exceptions also return with an exit code of 1,
# so just raising an exception isn't informative
# Exiting with 255 makes it possible to distinguish between normal exits
# and exceptions.
sys.exit(255)
return result
def connector_indexing_task(
index_attempt_id: int,
index_attempt_id: int | None,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
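
The exit-code workaround can be summarized as a classification step. This is a rough sketch only: `completion_is_ok` and `classify_exit` are hypothetical helper names, and the string labels are illustrative. The 255 value and the "completion signal is OK" check are the only parts taken from the hunks above.

```python
from http import HTTPStatus
from typing import Optional

EXCEPTION_EXIT_CODE = 255  # the wrapper exits with this code on an exception


def completion_is_ok(status_int: Optional[int]) -> bool:
    """True when a completion signal exists and equals HTTP 200."""
    if not status_int:
        return False
    return HTTPStatus(status_int) == HTTPStatus.OK


def classify_exit(exit_code: Optional[int], status_int: Optional[int]) -> str:
    """Label a finished spawned task (labels are assumptions for illustration)."""
    if exit_code == 0:
        return "clean"
    if completion_is_ok(status_int):
        # non-zero exit code, but the task reported successful completion
        return "ignored"
    if exit_code == EXCEPTION_EXIT_CODE:
        return "exception"
    return "unknown"
```

Reserving 255 for exceptions is what lets a watchdog tell a real failure apart from the spurious exit code 1 described in the comment.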
@@ -787,7 +1029,17 @@ def connector_indexing_task(
f"fence={redis_connector.stop.fence_key}"
)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_indexing_task - timed out waiting for fence to be ready: "
f"fence={redis_connector_index.fence_key}"
)
if not redis_connector_index.fenced: # The fence must exist
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
@@ -828,7 +1080,9 @@ def connector_indexing_task(
if not acquired:
logger.warning(
f"Indexing task already running, exiting...: "
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return None
@@ -864,6 +1118,7 @@ def connector_indexing_task(
# define a callback class
callback = IndexingCallback(
os.getppid(),
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,
@@ -877,6 +1132,7 @@ def connector_indexing_task(
f"search_settings={search_settings_id}"
)
# This is where the heavy/real work happens
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
@@ -896,8 +1152,19 @@ def connector_indexing_task(
f"search_settings={search_settings_id}"
)
if attempt_found:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(
index_attempt_id, db_session, failure_reason=str(e)
)
except Exception:
logger.exception(
"connector_indexing_task - exception while marking index attempt as failed: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise e
finally:
@@ -906,7 +1173,6 @@ def connector_indexing_task(
logger.info(
f"Indexing spawned task finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)

View File

@@ -81,10 +81,10 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -92,7 +92,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
return None
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
@@ -122,11 +122,13 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
task_logger.exception("Unexpected exception during pruning check")
finally:
if lock_beat.owned():
lock_beat.release()
return True
def try_creating_prune_generator_task(
celery_app: Celery,
@@ -283,6 +285,7 @@ def connector_pruning_generator_task(
)
callback = IndexingCallback(
0,
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,
@@ -308,7 +311,7 @@ def connector_pruning_generator_task(
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
task_logger.info(
f"Pruning set collected: "
"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
@@ -324,7 +327,7 @@ def connector_pruning_generator_task(
return None
task_logger.info(
f"RedisConnector.prune.generate_tasks finished. "
"RedisConnector.prune.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)

View File

@@ -60,7 +60,7 @@ def document_by_cc_pair_cleanup_task(
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")
task_logger.debug(f"Task start: doc={document_id}")
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -129,16 +129,13 @@ def document_by_cc_pair_cleanup_task(
db_session.commit()
task_logger.info(
f"tenant={tenant_id} "
f"doc={document_id} "
f"action={action} "
f"refcount={count} "
f"chunks={chunks_affected}"
)
except SoftTimeLimitExceeded:
task_logger.info(
f"SoftTimeLimitExceeded exception. tenant={tenant_id} doc={document_id}"
)
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
if isinstance(ex, RetryError):
@@ -157,15 +154,12 @@ def document_by_cc_pair_cleanup_task(
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"tenant={tenant_id} "
f"doc={document_id} "
f"status={e.response.status_code}"
)
return False
task_logger.exception(
f"Unexpected exception: tenant={tenant_id} doc={document_id}"
)
task_logger.exception(f"Unexpected exception: doc={document_id}")
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
@@ -176,7 +170,7 @@ def document_by_cc_pair_cleanup_task(
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"tenant={tenant_id} doc={document_id}"
f"doc={document_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
# delete the cc pair relationship now and let reconciliation clean it up

View File

@@ -1,3 +1,4 @@
import time
import traceback
from datetime import datetime
from datetime import timezone
@@ -19,6 +20,7 @@ from tenacity import RetryError
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
@@ -86,13 +88,14 @@ logger = setup_logger()
trail=False,
bind=True,
)
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -100,7 +103,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
return None
with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(
@@ -156,11 +159,15 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
task_logger.exception("Unexpected exception during vespa metadata sync")
finally:
if lock_beat.owned():
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.debug(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
return True
def try_generate_stale_document_sync_tasks(
celery_app: Celery,
@@ -630,15 +637,23 @@ def monitor_ccpair_indexing_taskset(
if not payload:
return
elapsed_started_str = None
if payload.started:
elapsed_started = datetime.now(timezone.utc) - payload.started
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair={cc_pair_id} "
f"Connector indexing progress: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
if payload.index_attempt_id is None or payload.celery_task_id is None:
@@ -709,11 +724,14 @@ def monitor_ccpair_indexing_taskset(
status_enum = HTTPStatus(status_int)
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"Connector indexing finished: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
redis_connector_index.reset()
@@ -730,6 +748,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
Returns True if the task actually did work, False if it exited early to prevent overlap
"""
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
@@ -758,32 +777,43 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
n_permissions_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
n_external_group_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
n_permissions_upsert = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"indexing_prefetched={len(prefetched)} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
f"external_group_sync={n_external_group_sync} "
f"permissions_upsert={n_permissions_upsert} "
)
# scan and monitor activity to completion
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
lock_beat.reacquire()
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
lock_beat.reacquire()
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
@@ -794,28 +824,21 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
# task_logger.warning(f"queue={OnyxCeleryQueues.VESPA_METADATA_SYNC} length={length}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -824,6 +847,8 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
if lock_beat.owned():
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.debug(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
return True
@@ -873,13 +898,9 @@ def vespa_metadata_sync_task(
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
task_logger.info(
f"tenant={tenant_id} doc={document_id} action=sync chunks={chunks_affected}"
)
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
except SoftTimeLimitExceeded:
task_logger.info(
f"SoftTimeLimitExceeded exception. tenant={tenant_id} doc={document_id}"
)
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(
@@ -897,14 +918,13 @@ def vespa_metadata_sync_task(
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"tenant={tenant_id} "
f"doc={document_id} "
f"status={e.response.status_code}"
)
return False
task_logger.exception(
f"Unexpected exception: tenant={tenant_id} doc={document_id}"
f"Unexpected exception during vespa metadata sync: doc={document_id}"
)
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
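
The backoff schedule named in the comment (16, 32, 64 seconds) can be reproduced with a one-line formula. This is a small sketch, assuming a hypothetical `retry_countdown` helper and that the retry counter starts at 0; the diff does not show how the task computes the value.

```python
def retry_countdown(retries: int, base_exponent: int = 4) -> int:
    """Seconds to wait before the next retry: 2^4, 2^5, 2^6, ..."""
    return 2 ** (retries + base_exponent)


# schedule for the first three retries
schedule = [retry_countdown(attempt) for attempt in range(3)]
```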

View File

@@ -11,8 +11,10 @@ from onyx.background.indexing.tracer import OnyxTracer
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
@@ -34,6 +36,7 @@ from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.variable_functionality import global_version
logger = setup_logger()
@@ -88,6 +91,35 @@ def _get_connector_runner(
)
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
cleaned_batch = []
for doc in doc_batch:
cleaned_doc = doc.model_copy()
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
)
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
"\x00", ""
)
for section in cleaned_doc.sections:
if section.link and "\x00" in section.link:
logger.warning(
f"NUL characters found in document link for document: {cleaned_doc.id}"
)
section.link = section.link.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
return cleaned_batch
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
@@ -236,7 +268,9 @@ def _run_indexing(
)
batch_description = []
for doc in doc_batch:
doc_batch_cleaned = strip_null_characters(doc_batch)
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
doc_size = 0
@@ -256,15 +290,15 @@ def _run_indexing(
# real work happens here!
new_docs, total_batch_chunks = indexing_pipeline(
document_batch=doc_batch,
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
)
batch_num += 1
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
all_connector_doc_ids.update(doc.id for doc in doc_batch)
document_count += len(doc_batch_cleaned)
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
@@ -274,7 +308,7 @@ def _run_indexing(
db_session.commit()
if callback:
callback.progress("_run_indexing", len(doc_batch))
callback.progress("_run_indexing", len(doc_batch_cleaned))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
@@ -396,6 +430,15 @@ def _run_indexing(
if index_attempt_md.num_exceptions == 0:
mark_attempt_succeeded(index_attempt, db_session)
create_milestone_and_report(
user=None,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
properties=None,
db_session=db_session,
)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
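
To illustrate what strip_null_characters does to a batch, here is a self-contained sketch that uses plain dataclasses in place of the real pydantic `Document` model. `Section`, `Doc`, and `strip_nul` are hypothetical stand-ins. Unlike the shallow `model_copy()` in the diff, this version rebuilds the sections so the input document is left untouched, and it omits the warning logs.

```python
from dataclasses import dataclass, field
from typing import List, Optional

NUL = "\x00"


@dataclass
class Section:
    text: str
    link: Optional[str] = None


@dataclass
class Doc:
    id: str
    semantic_identifier: str
    sections: List[Section] = field(default_factory=list)


def strip_nul(doc: Doc) -> Doc:
    """Return a copy of `doc` with NUL characters removed from id, title and links."""
    return Doc(
        id=doc.id.replace(NUL, ""),
        semantic_identifier=doc.semantic_identifier.replace(NUL, ""),
        sections=[
            Section(
                text=section.text,
                link=section.link.replace(NUL, "") if section.link else section.link,
            )
            for section in doc.sections
        ],
    )


dirty = Doc("a\x00b", "title\x00", [Section("x", "http://e\x00x"), Section("y")])
clean = strip_nul(dirty)
```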

View File

@@ -31,6 +31,8 @@ from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
@@ -53,6 +55,9 @@ from onyx.db.chat import reserve_message_id
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.engine import get_session_context_manager
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.milestone import update_user_assistant_milestone
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
@@ -117,6 +122,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -356,6 +362,31 @@ def stream_chat_message_objects(
if not persona:
raise RuntimeError("No persona specified or found for chat session")
multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
user=user,
event_type=MilestoneRecordType.MULTIPLE_ASSISTANTS,
db_session=db_session,
)
update_user_assistant_milestone(
milestone=multi_assistant_milestone,
user_id=str(user.id) if user else NO_AUTH_USER_ID,
assistant_id=persona.id,
db_session=db_session,
)
_, just_hit_multi_assistant_milestone = check_multi_assistant_milestone(
milestone=multi_assistant_milestone,
db_session=db_session,
)
if just_hit_multi_assistant_milestone:
mt_cloud_telemetry(
distinct_id=tenant_id,
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
properties=None,
)
# If a prompt override is specified via the API, use that with highest priority
# but for saving it, we are just mapping it to an existing prompt
prompt_id = new_msg_req.prompt_id

View File

@@ -65,7 +65,7 @@ class CitationProcessor:
# Handle code blocks without language tags
if "`" in self.curr_segment:
if self.curr_segment.endswith("`"):
return
pass
elif "```" in self.curr_segment:
piece_that_comes_after = self.curr_segment.split("```")[1][0]
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):

View File

@@ -1,6 +1,7 @@
import json
import os
import urllib.parse
from typing import cast
from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
@@ -57,6 +58,9 @@ SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
# Default request timeout, mostly used by connectors
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
# set `VALID_EMAIL_DOMAINS` to a comma separated list of domains in order to
# restrict access to Onyx to only users with emails from those domains.
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
@@ -91,6 +95,7 @@ SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
EMAIL_CONFIGURED = all([SMTP_SERVER, SMTP_USER, SMTP_PASS])
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
# If set, Onyx will listen to the `expires_at` returned by the identity
@@ -144,6 +149,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
POSTGRES_API_SERVER_POOL_SIZE = int(
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
@@ -174,11 +180,33 @@ try:
except ValueError:
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
# Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
if _rate_limit_window_seconds_str is not None:
try:
RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
except ValueError:
pass
RATE_LIMIT_MAX_REQUESTS: int | None = None
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
if _rate_limit_max_requests_str is not None:
try:
RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
except ValueError:
pass
# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
@@ -342,12 +370,17 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)
# Typically set to http://localhost:3000 for OAuth connector development
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
# Egnyte specific configs
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
# Linear specific configs
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
@@ -483,6 +516,21 @@ SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
# allow for custom error messages for different errors returned by litellm
# for example, can specify: {"Violated content safety policy": "EVIL REQUEST!!!"}
# to make it so that if an LLM call returns an error containing "Violated content safety policy"
# the end user will see "EVIL REQUEST!!!" instead of the default error message.
_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = os.environ.get(
"LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS", ""
)
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS: dict[str, str] | None = None
try:
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = cast(
dict[str, str], json.loads(_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS)
)
except json.JSONDecodeError:
pass
#####
# Enterprise Edition Configs
#####
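The custom error message mapping above is only parsed in this file; the diff does not show where it is consulted. As a minimal sketch, assuming a simple substring match, the lookup could work as follows. The helper names `load_error_mappings` and `friendly_error_message` are hypothetical and do not appear in the diff.

```python
from __future__ import annotations

import json


def load_error_mappings(raw: str) -> dict[str, str] | None:
    """Parse the env var value; return None when it is empty or not valid JSON."""
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def friendly_error_message(
    error_text: str, mappings: dict[str, str] | None, default: str
) -> str:
    """Return the first custom message whose key occurs in the error text."""
    for pattern, custom_message in (mappings or {}).items():
        if pattern in error_text:
            return custom_message
    return default
```

An unset variable yields an empty string, which is not valid JSON, so the mapping stays `None` and the default message is used.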

View File

@@ -63,6 +63,10 @@ LANGUAGE_CHAT_NAMING_HINT = (
or "The name of the conversation must be in the same language as the user query."
)
# Number of prompts each persona should have
NUM_PERSONA_PROMPTS = 4
NUM_PERSONA_PROMPT_GENERATION_CHUNKS = 5
# Agentic search takes significantly more tokens and therefore has much higher cost.
# This configuration allows users to get a search-only experience with instant results
# and no involvement from the LLM.

View File

@@ -15,6 +15,9 @@ ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"
NO_AUTH_USER_ID = "__no_auth_user__"
NO_AUTH_USER_EMAIL = "anonymous@onyx.app"
# For chunking/processing chunks
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
@@ -33,6 +36,8 @@ DISABLED_GEN_AI_MSG = (
DEFAULT_PERSONA_ID = 0
DEFAULT_CC_PAIR_ID = 1
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
@@ -46,6 +51,7 @@ POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
SSL_CERT_FILE = "bundle.pem"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
@@ -77,6 +83,9 @@ CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 3 hours
# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
@@ -133,6 +142,7 @@ class DocumentSource(str, Enum):
FRESHDESK = "freshdesk"
FIREFLIES = "fireflies"
EGNYTE = "egnyte"
AIRTABLE = "airtable"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -170,6 +180,10 @@ class AuthType(str, Enum):
CLOUD = "cloud"
# Special characters for password validation
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
class SessionType(str, Enum):
CHAT = "Chat"
SEARCH = "Search"
@@ -207,9 +221,23 @@ class FileOrigin(str, Enum):
CHAT_IMAGE_GEN = "chat_image_gen"
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
MY_DOCUMENTS = "my_documents"
OTHER = "other"
class MilestoneRecordType(str, Enum):
TENANT_CREATED = "tenant_created"
USER_SIGNED_UP = "user_signed_up"
MULTIPLE_USERS = "multiple_users"
VISITED_ADMIN_PAGE = "visited_admin_page"
CREATED_CONNECTOR = "created_connector"
CONNECTOR_SUCCEEDED = "connector_succeeded"
RAN_QUERY = "ran_query"
MULTIPLE_ASSISTANTS = "multiple_assistants"
CREATED_ASSISTANT = "created_assistant"
CREATED_ONYX_BOT = "created_onyx_bot"
class PostgresAdvisoryLocks(Enum):
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
@@ -252,6 +280,11 @@ class OnyxRedisLocks:
SLACK_BOT_LOCK = "da_lock:slack_bot"
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
class OnyxRedisSignals:
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
class OnyxCeleryPriority(int, Enum):

View File

@@ -0,0 +1,268 @@
from io import BytesIO
from typing import Any
import requests
from pyairtable import Api as AirtableApi
from retry import retry
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.utils.logger import setup_logger
logger = setup_logger()
# NOTE: all are made lowercase to avoid case sensitivity issues
# these are the field types that are considered metadata rather
# than sections
_METADATA_FIELD_TYPES = {
"singlecollaborator",
"collaborator",
"createdby",
"singleselect",
"multipleselects",
"checkbox",
"date",
"datetime",
"email",
"phone",
"url",
"number",
"currency",
"duration",
"percent",
"rating",
"createdtime",
"lastmodifiedtime",
"autonumber",
"rollup",
"lookup",
"count",
"formula",
}
class AirtableClientNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__("Airtable Client is not set up, was load_credentials called?")
class AirtableConnector(LoadConnector):
def __init__(
self,
base_id: str,
table_name_or_id: str,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.base_id = base_id
self.table_name_or_id = table_name_or_id
self.batch_size = batch_size
self.airtable_client: AirtableApi | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
return None
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
"""
Extract value(s) from a field regardless of its type.
Always returns a list of strings: empty when there is no value, and one
entry per successfully extracted attachment for attachment fields.
"""
if field_info is None:
return []
# skip references to other records for now (would need to do another
# request to get the actual record name/type)
# TODO: support this
if field_type == "multipleRecordLinks":
return []
if field_type == "multipleAttachments":
attachment_texts: list[str] = []
for attachment in field_info:
url = attachment.get("url")
filename = attachment.get("filename", "")
if not url:
continue
@retry(
tries=5,
delay=1,
backoff=2,
max_delay=10,
)
def get_attachment_with_retry(url: str) -> bytes | None:
attachment_response = requests.get(url)
if attachment_response.status_code == 200:
return attachment_response.content
return None
attachment_content = get_attachment_with_retry(url)
if attachment_content:
try:
file_ext = get_file_ext(filename)
attachment_text = extract_file_text(
BytesIO(attachment_content),
filename,
break_on_unprocessable=False,
extension=file_ext,
)
if attachment_text:
attachment_texts.append(f"{filename}:\n{attachment_text}")
except Exception as e:
logger.warning(
f"Failed to process attachment {filename}: {str(e)}"
)
return attachment_texts
if field_type in ["singleCollaborator", "collaborator", "createdBy"]:
combined = []
collab_name = field_info.get("name")
collab_email = field_info.get("email")
if collab_name:
combined.append(collab_name)
if collab_email:
combined.append(f"({collab_email})")
return [" ".join(combined) if combined else str(field_info)]
if isinstance(field_info, list):
return [str(item) for item in field_info]
return [str(field_info)]
def _should_be_metadata(self, field_type: str) -> bool:
"""Determine if a field type should be treated as metadata."""
return field_type.lower() in _METADATA_FIELD_TYPES
def _process_field(
self,
field_name: str,
field_info: Any,
field_type: str,
table_id: str,
record_id: str,
) -> tuple[list[Section], dict[str, Any]]:
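"""
Process a single Airtable field and return sections or metadata.
Args:
field_name: Name of the field
field_info: Raw field information from Airtable
field_type: Airtable field type
table_id: ID of the table, used to build the section link
record_id: ID of the record, used to build the section link
Returns:
(list of Sections, dict of metadata)
"""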
if field_info is None:
return [], {}
# Get the value(s) for the field
field_values = self._get_field_value(field_info, field_type)
if len(field_values) == 0:
return [], {}
# Determine if it should be metadata or a section
if self._should_be_metadata(field_type):
if len(field_values) > 1:
return [], {field_name: field_values}
return [], {field_name: field_values[0]}
# Otherwise, create relevant sections
sections = [
Section(
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
text=(
f"{field_name}:\n"
"------------------------\n"
f"{text}\n"
"------------------------"
),
)
for text in field_values
]
return sections, {}
def load_from_state(self) -> GenerateDocumentsOutput:
"""
Fetch all records from the table.
NOTE: Airtable does not support filtering by time updated, so
we have to fetch all records every time.
"""
if not self.airtable_client:
raise AirtableClientNotSetUpError()
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
table_id = table.id
# due to https://community.airtable.com/t5/development-apis/pagination-returns-422-error/td-p/54778,
# we can't use the `iterate()` method - we need to get everything up front
# this also means we can't handle tables that won't fit in memory
records = table.all()
table_schema = table.schema()
# have to get the name from the schema, since the table object will
# give back the ID instead of the name if the ID is used to create
# the table object
table_name = table_schema.name
primary_field_name = None
# Find a primary field from the schema
for field in table_schema.fields:
if field.id == table_schema.primary_field_id:
primary_field_name = field.name
break
record_documents: list[Document] = []
for record in records:
record_id = record["id"]
fields = record["fields"]
sections: list[Section] = []
metadata: dict[str, Any] = {}
# Possibly retrieve the primary field's value
primary_field_value = (
fields.get(primary_field_name) if primary_field_name else None
)
for field_schema in table_schema.fields:
field_name = field_schema.name
field_val = fields.get(field_name)
field_type = field_schema.type
field_sections, field_metadata = self._process_field(
field_name=field_name,
field_info=field_val,
field_type=field_type,
table_id=table_id,
record_id=record_id,
)
sections.extend(field_sections)
metadata.update(field_metadata)
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
else table_name
)
record_document = Document(
id=f"airtable__{record_id}",
sections=sections,
source=DocumentSource.AIRTABLE,
semantic_identifier=semantic_id,
metadata=metadata,
)
record_documents.append(record_document)
if len(record_documents) >= self.batch_size:
yield record_documents
record_documents = []
if record_documents:
yield record_documents
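The connector above routes each field either into metadata or into text sections based on its type. A minimal standalone sketch of that routing, assuming plain strings in place of `Section` objects and only a subset of the metadata types; `route_field` and `METADATA_FIELD_TYPES` are illustrative names, not part of the diff:

```python
from __future__ import annotations

from typing import Any

# Subset of the lowercase metadata types listed in the connector (assumption:
# shortened here for illustration).
METADATA_FIELD_TYPES = {"checkbox", "date", "email", "number", "singleselect", "url"}


def route_field(
    field_name: str, field_values: list[str], field_type: str
) -> tuple[list[str], dict[str, Any]]:
    """Return (section texts, metadata) for one field, mirroring _process_field."""
    if not field_values:
        return [], {}
    # Field types are compared in lowercase, as in _should_be_metadata.
    if field_type.lower() in METADATA_FIELD_TYPES:
        value: Any = field_values if len(field_values) > 1 else field_values[0]
        return [], {field_name: value}
    # Everything else becomes one text section per value.
    return [f"{field_name}:\n{text}" for text in field_values], {}
```

A single metadata value is stored as a bare string and multiple values as a list, which matches the branch on `len(field_values) > 1` in the diff.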

View File

@@ -56,6 +56,23 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"png",
"jpg",
"jpeg",
"gif",
"mp4",
"mov",
"mp3",
"wav",
]
_FULL_EXTENSION_FILTER_STRING = "".join(
[
f" and title!~'*.{extension}'"
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
]
)
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
@@ -64,7 +81,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
is_cloud: bool,
space: str = "",
page_id: str = "",
index_recursively: bool = True,
index_recursively: bool = False,
cql_query: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
@@ -82,23 +99,25 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
# if nothing is provided, we will fetch all pages
cql_page_query = "type=page"
"""
If nothing is provided, we default to fetching all pages.
At most one of the following options should be specified, so
the order shouldn't matter.
However, we use elif to ensure that only one of them is applied.
"""
base_cql_page_query = "type=page"
if cql_query:
# if a cql_query is provided, we will use it to fetch the pages
cql_page_query = cql_query
base_cql_page_query = cql_query
elif page_id:
# if a cql_query is not provided, we will use the page_id to fetch the page
if index_recursively:
cql_page_query += f" and ancestor='{page_id}'"
base_cql_page_query += f" and (ancestor='{page_id}' or id='{page_id}')"
else:
cql_page_query += f" and id='{page_id}'"
base_cql_page_query += f" and id='{page_id}'"
elif space:
# if no cql_query or page_id is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"
uri_safe_space = quote(space)
base_cql_page_query += f" and space='{uri_safe_space}'"
self.cql_page_query = cql_page_query
self.cql_time_filter = ""
self.base_cql_page_query = base_cql_page_query
self.cql_label_filter = ""
if labels_to_skip:
@@ -126,6 +145,33 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def _construct_page_query(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> str:
page_query = self.base_cql_page_query + self.cql_label_filter
# Add time filters
if start:
formatted_start_time = datetime.fromtimestamp(
start, tz=self.timezone
).strftime("%Y-%m-%d %H:%M")
page_query += f" and lastmodified >= '{formatted_start_time}'"
if end:
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
page_query += f" and lastmodified <= '{formatted_end_time}'"
return page_query
def _construct_attachment_query(self, confluence_page_id: str) -> str:
attachment_query = f"type=attachment and container='{confluence_page_id}'"
attachment_query += self.cql_label_filter
attachment_query += _FULL_EXTENSION_FILTER_STRING
return attachment_query
def _get_comment_string_for_page_id(self, page_id: str) -> str:
comment_string = ""
@@ -205,11 +251,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
metadata=doc_metadata,
)
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
def _fetch_document_batches(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
page_query = self._construct_page_query(start, end)
logger.debug(f"page_query: {page_query}")
# Fetch pages as Documents
for page in self.confluence_client.paginated_cql_retrieval(
@@ -228,11 +278,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
attachment_cql += self.cql_label_filter
attachment_query = self._construct_attachment_query(confluence_page_id)
# TODO: maybe should add time filter as well?
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_cql,
cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
doc = self._convert_object_to_document(attachment)
@@ -248,17 +297,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_document_batches()
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
# Add time filters
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
return self._fetch_document_batches()
def poll_source(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
return self._fetch_document_batches(start, end)
def retrieve_all_slim_documents(
self,
@@ -269,7 +313,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
page_query = self.cql_page_query + self.cql_label_filter
page_query = self.base_cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
@@ -294,10 +338,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
perm_sync_data=page_perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
attachment_query = self._construct_attachment_query(page["id"])
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
cql=attachment_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
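The refactor above builds CQL strings per call instead of mutating `self.cql_time_filter`. A standalone sketch of the two query builders, assuming UTC as the timezone and a shortened extension list; the free functions and their parameters are illustrative stand-ins for the methods in the diff:

```python
from __future__ import annotations

from datetime import datetime, timezone, tzinfo

# Shortened from _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT for illustration.
FILTERED_EXTENSIONS = ["png", "jpg", "mp4"]
_TIME_FORMAT = "%Y-%m-%d %H:%M"


def construct_page_query(
    base_query: str,
    label_filter: str = "",
    start: float | None = None,
    end: float | None = None,
    tz: tzinfo = timezone.utc,
) -> str:
    """Append optional lastmodified bounds to the base CQL query."""
    query = base_query + label_filter
    if start:
        formatted_start = datetime.fromtimestamp(start, tz=tz).strftime(_TIME_FORMAT)
        query += f" and lastmodified >= '{formatted_start}'"
    if end:
        formatted_end = datetime.fromtimestamp(end, tz=tz).strftime(_TIME_FORMAT)
        query += f" and lastmodified <= '{formatted_end}'"
    return query


def construct_attachment_query(page_id: str, label_filter: str = "") -> str:
    """Select a page's attachments, excluding the filtered file extensions."""
    query = f"type=attachment and container='{page_id}'" + label_filter
    query += "".join(f" and title!~'*.{ext}'" for ext in FILTERED_EXTENSIONS)
    return query
```

Because the truthiness check `if start:` is kept from the diff, a start of `0` is treated the same as no lower bound.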

View File

@@ -153,7 +153,7 @@ class OnyxConfluence(Confluence):
try:
response = self.get(url, params=params)
except HTTPError as e:
if e.response.status_code == 403:
if e.response is not None and e.response.status_code == 403:
raise ApiPermissionError(
"The calling user does not have permission", reason=e
)

View File

@@ -6,6 +6,7 @@ from typing import TypeVar
from dateutil.parser import parse
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
from onyx.configs.constants import IGNORE_FOR_QA
from onyx.connectors.models import BasicExpertInfo
from onyx.utils.text_processing import is_valid_email
@@ -71,3 +72,10 @@ def process_in_batches(
def get_metadata_keys_to_ignore() -> list[str]:
return [IGNORE_FOR_QA]
def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
if CONNECTOR_LOCALHOST_OVERRIDE:
# Used for development
base_domain = CONNECTOR_LOCALHOST_OVERRIDE
return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}"
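To illustrate how the callback URI is assembled, here is a self-contained variant of the helper above. It assumes the override is passed as a parameter rather than read from `CONNECTOR_LOCALHOST_OVERRIDE`; the name `build_oauth_callback_uri` and the example domains are hypothetical.

```python
from __future__ import annotations


def build_oauth_callback_uri(
    base_domain: str, connector_id: str, localhost_override: str | None = None
) -> str:
    """Build the OAuth callback URI, preferring the development override."""
    if localhost_override:
        base_domain = localhost_override
    # strip('/') removes slashes at both ends, so a trailing slash on the
    # domain does not produce a double slash in the result.
    return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}"


print(build_oauth_callback_uri("https://example.com/", "egnyte"))
print(build_oauth_callback_uri("https://example.com", "egnyte", "http://localhost:3000"))
```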

View File

@@ -190,7 +190,7 @@ class DiscourseConnector(PollConnector):
start: datetime,
end: datetime,
) -> GenerateDocumentsOutput:
page = 1
page = 0
while topic_ids := self._get_latest_topics(start, end, page):
doc_batch: list[Document] = []
for topic_id in topic_ids:

View File

@@ -3,20 +3,19 @@ import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from logging import Logger
from typing import Any
from typing import cast
from typing import IO
from urllib.parse import quote
import requests
from retry import retry
from pydantic import Field
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
from onyx.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_oauth_callback_uri,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import OAuthConnector
@@ -33,53 +32,13 @@ from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import read_text_file
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
logger = setup_logger()
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
_TIMEOUT = 60
def _request_with_retries(
method: str,
url: str,
data: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: int = _TIMEOUT,
stream: bool = False,
tries: int = 8,
delay: float = 1,
backoff: float = 2,
) -> requests.Response:
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
def _make_request() -> requests.Response:
response = requests.request(
method,
url,
data=data,
headers=headers,
params=params,
timeout=timeout,
stream=stream,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
if e.response.status_code != 403:
logger.exception(
f"Failed to call Egnyte API.\n"
f"URL: {url}\n"
f"Headers: {headers}\n"
f"Data: {data}\n"
f"Params: {params}"
)
raise e
return response
return _make_request()
def _parse_last_modified(last_modified: str) -> datetime:
@@ -166,6 +125,15 @@ def _process_egnyte_file(
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
egnyte_domain: str = Field(
title="Egnyte Domain",
description=(
"The domain for the Egnyte instance "
"(e.g. 'company' for company.egnyte.com)"
),
)
def __init__(
self,
folder_path: str | None = None,
@@ -181,18 +149,20 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
return DocumentSource.EGNYTE
@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls,
base_domain: str,
state: str,
additional_kwargs: dict[str, str],
) -> str:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
if EGNYTE_LOCALHOST_OVERRIDE:
base_domain = EGNYTE_LOCALHOST_OVERRIDE
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
return (
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
f"?client_id={EGNYTE_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&scope=Egnyte.filesystem"
@@ -201,17 +171,23 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
)
@classmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls,
base_domain: str,
code: str,
additional_kwargs: dict[str, str],
) -> dict[str, Any]:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_CLIENT_SECRET:
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
# Exchange code for token
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
redirect_uri = f"{EGNYTE_LOCALHOST_OVERRIDE or base_domain}/connector/oauth/callback/egnyte"
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
data = {
"client_id": EGNYTE_CLIENT_ID,
"client_secret": EGNYTE_CLIENT_SECRET,
@@ -222,7 +198,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = _request_with_retries(
response = request_with_retries(
method="POST",
url=url,
data=data,
@@ -236,7 +212,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
token_data = response.json()
return {
"domain": EGNYTE_BASE_DOMAIN,
"domain": oauth_kwargs.egnyte_domain,
"access_token": token_data["access_token"],
}
@@ -260,9 +236,10 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
"list_content": True,
}
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
response = _request_with_retries(
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
url_encoded_path = quote(path or "")
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
response = request_with_retries(
method="GET", url=url, headers=headers, params=params
)
if not response.ok:
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
@@ -315,12 +292,12 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
headers = {
"Authorization": f"Bearer {self.access_token}",
}
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
response = _request_with_retries(
url_encoded_path = quote(file["path"])
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
response = request_with_retries(
method="GET",
url=url,
headers=headers,
timeout=_TIMEOUT,
stream=True,
)
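The Egnyte changes above URL-encode the file path before it is placed in the request URL. A small sketch of the effect, using the standard library's `urllib.parse.quote`, which leaves `/` unescaped by default; `egnyte_fs_url` is an illustrative helper, not a function from the diff:

```python
from __future__ import annotations

from urllib.parse import quote


def egnyte_fs_url(domain: str, path: str | None) -> str:
    """Build the /fs listing URL with a percent-encoded path."""
    api_base = f"https://{domain}.egnyte.com/pubapi/v1"
    # quote() keeps "/" as-is, so folder separators survive while spaces
    # and characters such as "#" are escaped.
    return f"{api_base}/fs/{quote(path or '')}"


print(egnyte_fs_url("company", "Shared/My Folder/report #1.pdf"))
```

Without the encoding, a `#` in a file name would be interpreted as the start of a URL fragment and the rest of the path would not reach the server.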

View File

@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceRequiringTenantContext
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
from onyx.connectors.blob.connector import BlobStorageConnector
@@ -103,6 +104,7 @@ def identify_connector_class(
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
DocumentSource.EGNYTE: EgnyteConnector,
DocumentSource.AIRTABLE: AirtableConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -4,6 +4,7 @@ from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@@ -249,17 +250,36 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
return new_creds_dict
def _get_all_user_emails(self) -> list[str]:
admin_service = get_admin_service(self.creds, self.primary_admin_email)
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
"""
List all user emails if we are on a Google Workspace domain.
If the domain is gmail.com, or if we attempt to call the Admin SDK and
get a 404, fall back to using the single user.
"""
try:
admin_service = get_admin_service(self.creds, self.primary_admin_email)
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
except HttpError as e:
if e.resp.status == 404:
logger.warning(
"Received 404 from Admin SDK; this may indicate a personal Gmail account "
"with no Workspace domain. Falling back to single user."
)
return [self.primary_admin_email]
raise
def _fetch_threads(
self,
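The fallback in `_get_all_user_emails` above can be pictured in isolation as follows. This is a sketch only: `FakeHttpError` is a hypothetical stand-in for `googleapiclient.errors.HttpError`, and the user listing is injected as a callable so the example runs without Google credentials.

```python
from collections.abc import Callable


class FakeHttpError(Exception):
    """Hypothetical stand-in for googleapiclient.errors.HttpError."""

    def __init__(self, status: int) -> None:
        super().__init__(f"HTTP {status}")
        self.status = status


def list_user_emails(
    fetch_all_emails: Callable[[], list[str]], primary_admin_email: str
) -> list[str]:
    """Return all domain users, or only the primary user on a 404."""
    try:
        return fetch_all_emails()
    except FakeHttpError as e:
        if e.status == 404:
            # No Workspace domain behind this account: use the single user.
            return [primary_admin_email]
        # Any other HTTP error is a real failure and is re-raised.
        raise
```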

View File

@@ -8,6 +8,7 @@ from typing import cast
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
@@ -20,6 +21,7 @@ from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.google_auth import get_google_creds
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
@@ -41,6 +43,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
@@ -286,13 +289,30 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
logger.info(f"Impersonating user {user_email}")
drive_service = get_drive_service(self.creds, user_email)
# validate that the user has access to the drive APIs by performing a simple
# request and checking for a 401
try:
retry_builder()(get_root_folder_id)(drive_service)
except HttpError as e:
if e.status_code == 401:
# fail gracefully, let the other impersonations continue
# one user without access shouldn't block the entire connector
logger.exception(
f"User '{user_email}' does not have access to the drive APIs."
)
return
raise
# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - include_my_drives is true
# - the current user's email is in the requested emails
if self.include_my_drives or user_email in self._requested_my_drive_emails:
logger.info(f"Getting all files in my drive as '{user_email}'")
yield from get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
@@ -303,6 +323,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
for drive_id in remaining_drive_ids:
logger.info(f"Getting files in shared drive '{drive_id}' as '{user_email}'")
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
@@ -314,6 +335,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
remaining_folders = filtered_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,
@@ -344,6 +366,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
# checkpoint - we've found all users and drives, now time to actually start
# fetching stuff
logger.info(f"Found {len(all_org_emails)} users to impersonate")
logger.debug(f"Users: {all_org_emails}")
logger.info(f"Found {len(drive_ids_to_retrieve)} drives to retrieve")
logger.debug(f"Drives: {drive_ids_to_retrieve}")
logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve")
logger.debug(f"Folders: {folder_ids_to_retrieve}")
# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
@@ -380,6 +411,13 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
drive_service = get_drive_service(self.creds, self.primary_admin_email)
if self.include_files_shared_with_me or self.include_my_drives:
logger.info(
f"Getting shared files/my drive files for OAuth "
f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
f"include_my_drives={self.include_my_drives}, "
f"include_shared_drives={self.include_shared_drives}. "
f"Using '{self.primary_admin_email}' as the account."
)
yield from get_all_files_for_oauth(
service=drive_service,
include_files_shared_with_me=self.include_files_shared_with_me,
@@ -412,6 +450,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
drive_ids_to_retrieve = all_drive_ids
for drive_id in drive_ids_to_retrieve:
logger.info(
f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'"
)
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
@@ -425,6 +466,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
# that could be folders.
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
for folder_id in remaining_folders:
logger.info(
f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'"
)
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,

@@ -2,6 +2,8 @@ import abc
from collections.abc import Iterator
from typing import Any
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
@@ -66,6 +68,10 @@ class SlimConnector(BaseConnector):
class OAuthConnector(BaseConnector):
class AdditionalOauthKwargs(BaseModel):
# if overridden, all fields should be str type
pass
@classmethod
@abc.abstractmethod
def oauth_id(cls) -> DocumentSource:
@@ -73,12 +79,22 @@ class OAuthConnector(BaseConnector):
@classmethod
@abc.abstractmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls,
base_domain: str,
state: str,
additional_kwargs: dict[str, str],
) -> str:
raise NotImplementedError
@classmethod
@abc.abstractmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls,
base_domain: str,
code: str,
additional_kwargs: dict[str, str],
) -> dict[str, Any]:
raise NotImplementedError
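The new `additional_kwargs` parameter lets a connector take user-supplied values into account when it builds its OAuth URLs. The following is a minimal sketch of how an implementation might use it, not code from this change: `ExampleOAuthConnector`, the `workspace` key, the client ID, the callback path and the `example.com` hosts are all hypothetical.

```python
from urllib.parse import urlencode


class ExampleOAuthConnector:
    """Hypothetical connector; not part of the diff above."""

    # Hypothetical values, for illustration only.
    CLIENT_ID = "example-client-id"
    AUTHORIZE_PATH = "/oauth/authorize"

    @classmethod
    def oauth_authorization_url(
        cls,
        base_domain: str,
        state: str,
        additional_kwargs: dict[str, str],
    ) -> str:
        # User input arrives as plain strings, matching the "all fields
        # should be str type" note on AdditionalOauthKwargs.
        workspace = additional_kwargs.get("workspace")
        if not workspace:
            raise ValueError("'workspace' must be provided")

        # urlencode percent-encodes the redirect URI and the state value.
        query = urlencode(
            {
                "client_id": cls.CLIENT_ID,
                "redirect_uri": f"{base_domain}/oauth/callback/example",
                "response_type": "code",
                "state": state,
            }
        )
        return f"https://{workspace}.example.com{cls.AUTHORIZE_PATH}?{query}"


url = ExampleOAuthConnector.oauth_authorization_url(
    base_domain="https://onyx.example.com",
    state="abc123",
    additional_kwargs={"workspace": "acme"},
)
```

Connectors that need no extra input, such as the Linear one in this change, can accept the argument and ignore it.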

@@ -7,16 +7,23 @@ from typing import cast
import requests
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import LINEAR_CLIENT_ID
from onyx.configs.app_configs import LINEAR_CLIENT_SECRET
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_oauth_callback_uri,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import OAuthConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
logger = setup_logger()
@@ -57,7 +64,7 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
)
class LinearConnector(LoadConnector, PollConnector):
class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@@ -65,8 +72,68 @@ class LinearConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self.linear_api_key: str | None = None
@classmethod
def oauth_id(cls) -> DocumentSource:
return DocumentSource.LINEAR
@classmethod
def oauth_authorization_url(
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
) -> str:
if not LINEAR_CLIENT_ID:
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
callback_uri = get_oauth_callback_uri(base_domain, DocumentSource.LINEAR.value)
return (
f"https://linear.app/oauth/authorize"
f"?client_id={LINEAR_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&response_type=code"
f"&scope=read"
f"&state={state}"
)
@classmethod
def oauth_code_to_token(
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
) -> dict[str, Any]:
data = {
"code": code,
"redirect_uri": get_oauth_callback_uri(
base_domain, DocumentSource.LINEAR.value
),
"client_id": LINEAR_CLIENT_ID,
"client_secret": LINEAR_CLIENT_SECRET,
"grant_type": "authorization_code",
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = request_with_retries(
method="POST",
url="https://api.linear.app/oauth/token",
data=data,
headers=headers,
backoff=0,
delay=0.1,
)
if not response.ok:
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
token_data = response.json()
return {
"access_token": token_data["access_token"],
}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.linear_api_key = cast(str, credentials["linear_api_key"])
if "linear_api_key" in credentials:
self.linear_api_key = cast(str, credentials["linear_api_key"])
elif "access_token" in credentials:
self.linear_api_key = "Bearer " + cast(str, credentials["access_token"])
else:
# May need to handle case in the future if the OAuth flow expires
raise ConnectorMissingCredentialError("Linear")
return None
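The credential selection in `load_credentials` can be read as a small pure function. This is a sketch under assumptions: `resolve_linear_auth_value` is a hypothetical name, and `MissingCredentialError` stands in for `ConnectorMissingCredentialError` so the example runs without the onyx package.

```python
from typing import Any


class MissingCredentialError(Exception):
    """Stand-in for ConnectorMissingCredentialError from the diff."""


def resolve_linear_auth_value(credentials: dict[str, Any]) -> str:
    # A personal API key is used as-is.
    if "linear_api_key" in credentials:
        return str(credentials["linear_api_key"])
    # An OAuth access token gets the "Bearer " prefix.
    if "access_token" in credentials:
        return "Bearer " + str(credentials["access_token"])
    # Neither credential type is present.
    raise MissingCredentialError("Linear")
```

Because the API key is checked first, a credentials dict that carries both keys resolves to the API key.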
def _process_issues(

@@ -1,11 +1,7 @@
import os
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@@ -19,24 +15,25 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.salesforce.utils import extract_dict_text
from onyx.connectors.salesforce.doc_conversion import extract_section
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.utils.logger import setup_logger
# TODO: this connector does not work well at large scales
# the large query against a large Salesforce instance has been reported to take 1.5 hours.
# Additionally it seems to eat up more memory over time if the connection is long running (again a scale issue).
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
MAX_QUERY_LENGTH = 10000 # max query length is 20,000 characters
ID_PREFIX = "SALESFORCE_"
logger = setup_logger()
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
_ID_PREFIX = "SALESFORCE_"
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
@@ -44,200 +41,170 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
requested_objects: list[str] | None = None,
) -> None:
self.batch_size = batch_size
self.sf_client: Salesforce | None = None
self._sf_client: Salesforce | None = None
self.parent_object_list = (
[obj.capitalize() for obj in requested_objects]
if requested_objects
else DEFAULT_PARENT_OBJECT_TYPES
else _DEFAULT_PARENT_OBJECT_TYPES
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.sf_client = Salesforce(
def load_credentials(
self,
credentials: dict[str, Any],
) -> dict[str, Any] | None:
self._sf_client = Salesforce(
username=credentials["sf_username"],
password=credentials["sf_password"],
security_token=credentials["sf_security_token"],
)
return None
def _get_sf_type_object_json(self, type_name: str) -> Any:
if self.sf_client is None:
@property
def sf_client(self) -> Salesforce:
if self._sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
sf_object = SFType(
type_name, self.sf_client.session_id, self.sf_client.sf_instance
)
return sf_object.describe()
return self._sf_client
def _get_name_from_id(self, id: str) -> str:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
try:
user_object_info = self.sf_client.query(
f"SELECT Name FROM User WHERE Id = '{id}'"
)
name = user_object_info.get("Records", [{}])[0].get("Name", "Null User")
return name
except Exception:
logger.warning(f"Couldn't find name for object id: {id}")
return "Null User"
def _extract_primary_owners(
self, sf_object: SalesforceObject
) -> list[BasicExpertInfo] | None:
object_dict = sf_object.data
if not (last_modified_by_id := object_dict.get("LastModifiedById")):
return None
if not (last_modified_by := get_record(last_modified_by_id)):
return None
if not (last_modified_by_name := last_modified_by.data.get("Name")):
return None
primary_owners = [BasicExpertInfo(display_name=last_modified_by_name)]
return primary_owners
def _convert_object_instance_to_document(
self, object_dict: dict[str, Any]
self, sf_object: SalesforceObject
) -> Document:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
object_dict = sf_object.data
salesforce_id = object_dict["Id"]
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}"
base_url = f"https://{self.sf_client.sf_instance}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_object_text = extract_dict_text(object_dict)
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
extracted_primary_owners = [
BasicExpertInfo(
display_name=self._get_name_from_id(object_dict["LastModifiedById"])
)
]
sections = [extract_section(sf_object, base_url)]
for id in get_child_ids(sf_object.id):
if not (child_object := get_record(id)):
continue
sections.append(extract_section(child_object, base_url))
doc = Document(
id=onyx_salesforce_id,
sections=[Section(link=extracted_link, text=extracted_object_text)],
sections=sections,
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,
primary_owners=extracted_primary_owners,
primary_owners=self._extract_primary_owners(sf_object),
metadata={},
)
return doc
def _is_valid_child_object(self, child_relationship: dict) -> bool:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
if not child_relationship["childSObject"]:
return False
if not child_relationship["relationshipName"]:
return False
sf_type = child_relationship["childSObject"]
object_description = self._get_sf_type_object_json(sf_type)
if not object_description["queryable"]:
return False
try:
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
result = self.sf_client.query(query)
if result["totalSize"] == 0:
return False
except Exception as e:
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
return False
if child_relationship["field"]:
if child_relationship["field"] == "RelatedToId":
return False
else:
return False
return True
def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
object_description = self._get_sf_type_object_json(sf_type)
children_objects: list[dict] = []
for child_relationship in object_description["childRelationships"]:
if self._is_valid_child_object(child_relationship):
children_objects.append(
{
"relationship_name": child_relationship["relationshipName"],
"object_type": child_relationship["childSObject"],
}
)
return children_objects
def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
object_description = self._get_sf_type_object_json(sf_type)
fields = [
field.get("name")
for field in object_description["fields"]
if field.get("type", "base64") != "base64"
]
return fields
def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
"""
This function takes in an object_type and generates query(s) designed to grab
information associated to objects of that type.
It does that by getting all the fields of the parent object type.
Then it gets all the child objects of that object type and all the fields of
those children as well.
"""
parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)
query = f"SELECT {', '.join(parent_fields)}"
for child_object_dict in child_sf_types:
fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"
if len(query_addition) + len(query) > MAX_QUERY_LENGTH:
query += f"\n FROM {parent_sf_type}"
yield query
query = "SELECT Id" + query_addition
else:
query += query_addition
query += f"\n FROM {parent_sf_type}"
yield query
def _fetch_from_salesforce(
self,
start: datetime | None = None,
end: datetime | None = None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
init_db()
all_object_types: set[str] = set(self.parent_object_list)
doc_batch: list[Document] = []
logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
logger.debug(f"Parent object types: {self.parent_object_list}")
# This takes like 20 seconds
for parent_object_type in self.parent_object_list:
logger.debug(f"Processing: {parent_object_type}")
query_results: dict = {}
for query in self._generate_query_per_parent_type(parent_object_type):
if start is not None and end is not None:
if start and start.tzinfo is None:
start = start.replace(tzinfo=timezone.utc)
if end and end.tzinfo is None:
end = end.replace(tzinfo=timezone.utc)
query += f" WHERE LastModifiedDate > {start.isoformat()} AND LastModifiedDate < {end.isoformat()}"
query_result = self.sf_client.query_all(query)
for record_dict in query_result["records"]:
query_results.setdefault(record_dict["Id"], {}).update(record_dict)
logger.info(
f"Number of {parent_object_type} Objects processed: {len(query_results)}"
child_types = get_all_children_of_sf_type(
self.sf_client, parent_object_type
)
all_object_types.update(child_types)
logger.debug(
f"Found {len(child_types)} child types for {parent_object_type}"
)
for combined_object_dict in query_results.values():
doc_batch.append(
self._convert_object_instance_to_document(combined_object_dict)
)
logger.info(f"Found total of {len(all_object_types)} object types to fetch")
logger.debug(f"All object types: {all_object_types}")
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
yield doc_batch
# checkpoint - we've found all object types, now time to fetch the data
logger.info("Starting to fetch CSVs for all object types")
# This takes like 30 minutes first time and <2 minutes for updates
object_type_to_csv_path = fetch_all_csvs_in_parallel(
sf_client=self.sf_client,
object_types=all_object_types,
start=start,
end=end,
)
updated_ids: set[str] = set()
# This takes like 10 seconds
# This is for testing the rest of the functionality if data has
# already been fetched and put in sqlite
# from onyx.connectors.salesforce.sf_db.sqlite_functions import find_ids_by_type
# for object_type in self.parent_object_list:
# updated_ids.update(list(find_ids_by_type(object_type)))
# This takes 10-70 minutes first time (idk why the range is so big)
total_types = len(object_type_to_csv_path)
logger.info(f"Starting to process {total_types} object types")
for i, (object_type, csv_paths) in enumerate(
object_type_to_csv_path.items(), 1
):
logger.info(f"Processing object type {object_type} ({i}/{total_types})")
# If path is None, it means it failed to fetch the csv
if csv_paths is None:
continue
# Go through each csv path and use it to update the db
for csv_path in csv_paths:
logger.debug(f"Updating {object_type} with {csv_path}")
new_ids = update_sf_db_with_csv(
object_type=object_type,
csv_download_path=csv_path,
)
updated_ids.update(new_ids)
logger.debug(
f"Added {len(new_ids)} new/updated records for {object_type}"
)
# Remove the csv file after it has been used
# to successfully update the db
os.remove(csv_path)
logger.info(f"Found {len(updated_ids)} total updated records")
logger.info(
f"Starting to process parent objects of types: {self.parent_object_list}"
)
docs_to_yield: list[Document] = []
docs_processed = 0
# Takes 15-20 seconds per batch
for parent_type, parent_id_batch in get_affected_parent_ids_by_type(
updated_ids=list(updated_ids),
parent_types=self.parent_object_list,
):
logger.info(
f"Processing batch of {len(parent_id_batch)} {parent_type} objects"
)
for parent_id in parent_id_batch:
if not (parent_object := get_record(parent_id, parent_type)):
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
)
continue
docs_to_yield.append(
self._convert_object_instance_to_document(parent_object)
)
docs_processed += 1
if len(docs_to_yield) >= self.batch_size:
yield docs_to_yield
docs_to_yield = []
yield docs_to_yield
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_salesforce()
@@ -245,26 +212,20 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
return self._fetch_from_salesforce(start=start, end=end)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
doc_metadata_list: list[SlimDocument] = []
for parent_object_type in self.parent_object_list:
query = f"SELECT Id FROM {parent_object_type}"
query_result = self.sf_client.query_all(query)
doc_metadata_list.extend(
SlimDocument(
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
id=f"{_ID_PREFIX}{instance_dict.get('Id', '')}",
perm_sync_data={},
)
for instance_dict in query_result["records"]
@@ -274,9 +235,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
if __name__ == "__main__":
connector = SalesforceConnector(
requested_objects=os.environ["REQUESTED_OBJECTS"].split(",")
)
import time
connector = SalesforceConnector(requested_objects=["Account"])
connector.load_credentials(
{
@@ -285,5 +246,20 @@ if __name__ == "__main__":
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
}
)
document_batches = connector.load_from_state()
print(next(document_batches))
start_time = time.time()
doc_count = 0
section_count = 0
text_count = 0
for doc_batch in connector.load_from_state():
doc_count += len(doc_batch)
print(f"doc_count: {doc_count}")
for doc in doc_batch:
section_count += len(doc.sections)
for section in doc.sections:
text_count += len(section.text)
end_time = time.time()
print(f"Doc count: {doc_count}")
print(f"Section count: {section_count}")
print(f"Text count: {text_count}")
print(f"Time taken: {end_time - start_time}")
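The batching at the end of `_fetch_from_salesforce` follows a common generator pattern, sketched here in isolation. `iter_batches` is a hypothetical helper, not something this change adds. It also differs from the loop above in one respect: it skips the final yield when nothing is left over, whereas the connector yields its last list unconditionally.

```python
from collections.abc import Iterable, Iterator
from typing import TypeVar

T = TypeVar("T")


def iter_batches(items: Iterable[T], batch_size: int) -> Iterator[list[T]]:
    """Yield lists of at most batch_size items, in input order."""
    batch: list[T] = []
    for item in items:
        batch.append(item)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    # Flush the remainder, but never emit an empty batch.
    if batch:
        yield batch
```

Whether an empty trailing batch matters depends on the consumer, so the guard is a design choice rather than a required fix.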

@@ -0,0 +1,156 @@
import re
from collections import OrderedDict
from onyx.connectors.models import Section
from onyx.connectors.salesforce.utils import SalesforceObject
# All of these types of keys are handled by specific fields in the doc
# conversion process (e.g. URLs) or are not useful for the user (e.g. UUIDs)
_SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
def _clean_salesforce_dict(data: dict | list) -> dict | list:
"""Clean and transform Salesforce API response data by recursively:
1. Extracting records from the response if present
2. Merging attributes into the main dictionary
3. Filtering out keys matching certain patterns (Id, Date, stamp, url)
4. Removing '__c' suffix from custom field names
5. Removing None values and empty containers
Args:
data: A dictionary or list from Salesforce API response
Returns:
Cleaned dictionary or list with transformed keys and filtered values
"""
if isinstance(data, dict):
if "records" in data.keys():
data = data["records"]
if isinstance(data, dict):
if "attributes" in data.keys():
if isinstance(data["attributes"], dict):
data.update(data.pop("attributes"))
if isinstance(data, dict):
filtered_dict = {}
for key, value in data.items():
if not re.search(_SF_JSON_FILTER, key, re.IGNORECASE):
# remove the custom object indicator for display
if key.endswith("__c"):
key = key[:-3]
if isinstance(value, (dict, list)):
filtered_value = _clean_salesforce_dict(value)
# Only add non-empty dictionaries or lists
if filtered_value:
filtered_dict[key] = filtered_value
elif value is not None:
filtered_dict[key] = value
return filtered_dict
elif isinstance(data, list):
filtered_list = []
for item in data:
if isinstance(item, (dict, list)):
filtered_item = _clean_salesforce_dict(item)
# Only add non-empty dictionaries or lists
if filtered_item:
filtered_list.append(filtered_item)
elif item is not None:
filtered_list.append(item)
return filtered_list
else:
return data
def _json_to_natural_language(data: dict | list, indent: int = 0) -> str:
"""Convert a nested dictionary or list into a human-readable string format.
Recursively traverses the data structure and formats it with:
- Key-value pairs on separate lines
- Nested structures indented for readability
- Lists and dictionaries handled with appropriate formatting
Args:
data: The dictionary or list to convert
indent: Number of spaces to indent (default: 0)
Returns:
A formatted string representation of the data structure
"""
result = []
indent_str = " " * indent
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, (dict, list)):
result.append(f"{indent_str}{key}:")
result.append(_json_to_natural_language(value, indent + 2))
else:
result.append(f"{indent_str}{key}: {value}")
elif isinstance(data, list):
for item in data:
result.append(_json_to_natural_language(item, indent + 2))
return "\n".join(result)
def _extract_dict_text(raw_dict: dict) -> str:
"""Extract text from a Salesforce API response dictionary by:
1. Cleaning the dictionary
2. Converting the cleaned dictionary to natural language
"""
processed_dict = _clean_salesforce_dict(raw_dict)
natural_language_for_dict = _json_to_natural_language(processed_dict)
return natural_language_for_dict
def extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
return Section(
text=_extract_dict_text(salesforce_object.data),
link=f"{base_url}/{salesforce_object.id}",
)
def _field_value_is_child_object(field_value: dict) -> bool:
"""
Checks if the field value is a child object.
"""
return (
isinstance(field_value, OrderedDict)
and "records" in field_value.keys()
and isinstance(field_value["records"], list)
and len(field_value["records"]) > 0
and "Id" in field_value["records"][0].keys()
)
def _extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
"""
This goes through the salesforce_object and extracts the top level fields as a Section.
It also goes through the child objects and extracts them as Sections.
"""
top_level_dict = {}
child_object_sections = []
for field_name, field_value in salesforce_object.items():
# If the field value is not a child object, add it to the top level dict
# to turn into text for the top level section
if not _field_value_is_child_object(field_value):
top_level_dict[field_name] = field_value
continue
# If the field value is a child object, extract the child objects and add them as sections
for record in field_value["records"]:
child_object_id = record["Id"]
child_object_sections.append(
Section(
text=f"Child Object(s): {field_name}\n{_extract_dict_text(record)}",
link=f"{base_url}/{child_object_id}",
)
)
top_level_id = salesforce_object["Id"]
top_level_section = Section(
text=_extract_dict_text(top_level_dict),
link=f"{base_url}/{top_level_id}",
)
return [top_level_section, *child_object_sections]
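To make the effect of the key filter concrete, here is a small standalone sketch that applies the same regular expression to a flat record. The record and its values are made up, and the loop is a simplified stand-in for `_clean_salesforce_dict` (no recursion, no `attributes` merging).

```python
import re

# Same pattern as _SF_JSON_FILTER in the diff.
SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"

# Made-up record for illustration.
sample = {
    "Id": "001000000000001AAA",
    "Name": "Acme",
    "LastModifiedDate": "2025-01-02T00:00:00.000+0000",
    "SystemModstamp": "2025-01-02T00:00:00.000+0000",
    "PhotoUrl": "/services/images/photo/001000000000001AAA",
    "Region__c": "EMEA",
    "IsValid": True,
}

kept = {}
for key, value in sample.items():
    # Case-insensitive suffix match: drops ids, dates, timestamps and URLs.
    if re.search(SF_JSON_FILTER, key, re.IGNORECASE):
        continue
    # Strip the custom-field suffix for display.
    if key.endswith("__c"):
        key = key[: -len("__c")]
    kept[key] = value

text = "\n".join(f"{key}: {value}" for key, value in kept.items())
```

Only `Name` and `Region` survive. Because the match is case-insensitive, a key such as `IsValid` also ends in "id" and is filtered out, which is worth keeping in mind when a field seems to be missing from the indexed text.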

@@ -0,0 +1,210 @@
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any
from pytz import UTC
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from simple_salesforce.bulk2 import SFBulk2Handler
from simple_salesforce.bulk2 import SFBulk2Type
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.salesforce.sqlite_functions import has_at_least_one_object_of_type
from onyx.connectors.salesforce.utils import get_object_type_path
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _build_time_filter_for_salesforce(
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
if start is None or end is None:
return ""
start_datetime = datetime.fromtimestamp(start, UTC)
end_datetime = datetime.fromtimestamp(end, UTC)
return (
f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
f"AND LastModifiedDate < {end_datetime.isoformat()}"
)
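As an illustration of the clause this produces, the sketch below re-creates the filter with the standard library only. `build_time_filter` is a hypothetical name, and `timezone.utc` is swapped in for `pytz.UTC` so the example has no third-party dependency.

```python
from datetime import datetime, timezone
from typing import Optional


def build_time_filter(start: Optional[float], end: Optional[float]) -> str:
    # Without both bounds there is no filter, so everything is fetched.
    if start is None or end is None:
        return ""
    start_dt = datetime.fromtimestamp(start, timezone.utc)
    end_dt = datetime.fromtimestamp(end, timezone.utc)
    return (
        f" WHERE LastModifiedDate > {start_dt.isoformat()} "
        f"AND LastModifiedDate < {end_dt.isoformat()}"
    )


clause = build_time_filter(0, 86400)
```

For the epoch and one day later, `clause` is ` WHERE LastModifiedDate > 1970-01-01T00:00:00+00:00 AND LastModifiedDate < 1970-01-02T00:00:00+00:00`. The leading space lets it be appended directly after the object name in the query.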
def _get_sf_type_object_json(sf_client: Salesforce, type_name: str) -> Any:
sf_object = SFType(type_name, sf_client.session_id, sf_client.sf_instance)
return sf_object.describe()
def _is_valid_child_object(
sf_client: Salesforce, child_relationship: dict[str, Any]
) -> bool:
if not child_relationship["childSObject"]:
return False
if not child_relationship["relationshipName"]:
return False
sf_type = child_relationship["childSObject"]
object_description = _get_sf_type_object_json(sf_client, sf_type)
if not object_description["queryable"]:
return False
if child_relationship["field"]:
if child_relationship["field"] == "RelatedToId":
return False
else:
return False
return True
def get_all_children_of_sf_type(sf_client: Salesforce, sf_type: str) -> set[str]:
object_description = _get_sf_type_object_json(sf_client, sf_type)
child_object_types = set()
for child_relationship in object_description["childRelationships"]:
if _is_valid_child_object(sf_client, child_relationship):
logger.debug(
f"Found valid child object {child_relationship['childSObject']}"
)
child_object_types.add(child_relationship["childSObject"])
return child_object_types
def _get_all_queryable_fields_of_sf_type(
sf_client: Salesforce,
sf_type: str,
) -> list[str]:
object_description = _get_sf_type_object_json(sf_client, sf_type)
fields: list[dict[str, Any]] = object_description["fields"]
valid_fields: set[str] = set()
compound_field_names: set[str] = set()
for field in fields:
if compound_field_name := field.get("compoundFieldName"):
compound_field_names.add(compound_field_name)
if field.get("type", "base64") == "base64":
continue
if field_name := field.get("name"):
valid_fields.add(field_name)
return list(valid_fields - compound_field_names)
def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool:
"""
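Send a small query to check if the object type is empty so we don't
perform extra bulk queries.
Despite the name, this returns False when the object type has no records
or does not support the count query, and True otherwise.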
"""
try:
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
result = sf_client.query(query)
if result["totalSize"] == 0:
return False
except Exception as e:
if "OPERATION_TOO_LARGE" not in str(e):
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
return False
return True
def _check_for_existing_csvs(sf_type: str) -> list[str] | None:
# Check if the csv already exists
if os.path.exists(get_object_type_path(sf_type)):
existing_csvs = [
os.path.join(get_object_type_path(sf_type), f)
for f in os.listdir(get_object_type_path(sf_type))
if f.endswith(".csv")
]
# If the csv already exists, return the path
# This is likely due to a previous run that failed
# after downloading the csv but before the data was
# written to the db
if existing_csvs:
return existing_csvs
return None
def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) -> str:
queryable_fields = _get_all_queryable_fields_of_sf_type(sf_client, sf_type)
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
return query
def _bulk_retrieve_from_salesforce(
sf_client: Salesforce,
sf_type: str,
time_filter: str,
) -> tuple[str, list[str] | None]:
if not _check_if_object_type_is_empty(sf_client, sf_type):
return sf_type, None
if existing_csvs := _check_for_existing_csvs(sf_type):
return sf_type, existing_csvs
query = _build_bulk_query(sf_client, sf_type, time_filter)
bulk_2_handler = SFBulk2Handler(
session_id=sf_client.session_id,
bulk2_url=sf_client.bulk2_url,
proxies=sf_client.proxies,
session=sf_client.session,
)
bulk_2_type = SFBulk2Type(
object_name=sf_type,
bulk2_url=bulk_2_handler.bulk2_url,
headers=bulk_2_handler.headers,
session=bulk_2_handler.session,
)
logger.info(f"Downloading {sf_type}")
logger.info(f"Query: {query}")
try:
# This downloads the file to a file in the target path with a random name
results = bulk_2_type.download(
query=query,
path=get_object_type_path(sf_type),
max_records=1000000,
)
all_download_paths = [result["file"] for result in results]
logger.info(f"Downloaded {sf_type} to {all_download_paths}")
return sf_type, all_download_paths
except Exception as e:
logger.info(f"Failed to download salesforce csv for object type {sf_type}: {e}")
return sf_type, None
def fetch_all_csvs_in_parallel(
sf_client: Salesforce,
object_types: set[str],
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
) -> dict[str, list[str] | None]:
"""
Fetches the CSVs for the given object types in parallel.
Returns a dict mapping each sf_type to the list of downloaded CSV paths,
or to None if the type was skipped or the download failed.
"""
time_filter = _build_time_filter_for_salesforce(start, end)
time_filter_for_each_object_type = {}
# We do this outside of the thread pool executor because this requires
# a database connection and we don't want to block the thread pool
# executor from running
for sf_type in object_types:
# Only add a time filter if there is at least one object of this type
# in the database. We aren't worried about partially completed object update runs
# because this occurs after we check for existing csvs, which covers this case.
if has_at_least_one_object_of_type(sf_type):
time_filter_for_each_object_type[sf_type] = time_filter
else:
time_filter_for_each_object_type[sf_type] = ""
# Run the bulk retrieve in parallel
with ThreadPoolExecutor() as executor:
results = executor.map(
lambda object_type: _bulk_retrieve_from_salesforce(
sf_client=sf_client,
sf_type=object_type,
time_filter=time_filter_for_each_object_type[object_type],
),
object_types,
)
return dict(results)
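The `executor.map` plus `dict(...)` idiom used here works because each worker returns a `(key, value)` tuple. A minimal sketch, with a hypothetical `fake_download` standing in for `_bulk_retrieve_from_salesforce` and made-up paths:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Optional


def fake_download(object_type: str) -> tuple[str, Optional[list[str]]]:
    # Hypothetical stand-in for the bulk download; "Empty" simulates a
    # type that was skipped, which is reported as None.
    if object_type == "Empty":
        return object_type, None
    return object_type, [f"/tmp/{object_type}/part-0.csv"]


object_types = {"Account", "Contact", "Empty"}

with ThreadPoolExecutor() as executor:
    # map() yields one (type, paths) tuple per input; dict() turns those
    # pairs into a lookup keyed by object type.
    results = dict(executor.map(fake_download, object_types))
```

Note that `executor.map` re-raises a worker's exception when its result is consumed, so the worker in the diff catching its own errors and returning `None` is what keeps one failed type from aborting the whole fetch.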

@@ -0,0 +1,209 @@
import csv
import shelve
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
get_child_to_parent_shelf_path,
)
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_id_type_shelf_path
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_object_shelf_path
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
get_parent_to_child_shelf_path,
)
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _update_relationship_shelves(
child_id: str,
parent_ids: set[str],
) -> None:
"""Update the relationship shelf when a record is updated."""
try:
# Convert child_id to string once
str_child_id = str(child_id)
# First update child to parent mapping
with shelve.open(
get_child_to_parent_shelf_path(),
flag="c",
protocol=None,
writeback=True,
) as child_to_parent_db:
old_parent_ids = set(child_to_parent_db.get(str_child_id, []))
child_to_parent_db[str_child_id] = list(parent_ids)
# Calculate differences outside the next context manager
parent_ids_to_remove = old_parent_ids - parent_ids
parent_ids_to_add = parent_ids - old_parent_ids
# Only sync once at the end
child_to_parent_db.sync()
# Then update parent to child mapping in a single transaction
if not parent_ids_to_remove and not parent_ids_to_add:
return
with shelve.open(
get_parent_to_child_shelf_path(),
flag="c",
protocol=None,
writeback=True,
) as parent_to_child_db:
# Process all removals first
for parent_id in parent_ids_to_remove:
str_parent_id = str(parent_id)
existing_children = set(parent_to_child_db.get(str_parent_id, []))
if str_child_id in existing_children:
existing_children.remove(str_child_id)
parent_to_child_db[str_parent_id] = list(existing_children)
# Then process all additions
for parent_id in parent_ids_to_add:
str_parent_id = str(parent_id)
existing_children = set(parent_to_child_db.get(str_parent_id, []))
existing_children.add(str_child_id)
parent_to_child_db[str_parent_id] = list(existing_children)
# Single sync at the end
parent_to_child_db.sync()
except Exception as e:
logger.error(f"Error updating relationship shelves: {e}")
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
raise
def get_child_ids(parent_id: str) -> set[str]:
"""Get all child IDs for a given parent ID.
Args:
parent_id: The ID of the parent object
Returns:
A set of child object IDs
"""
with shelve.open(get_parent_to_child_shelf_path()) as parent_to_child_db:
return set(parent_to_child_db.get(parent_id, []))
def update_sf_db_with_csv(
object_type: str,
csv_download_path: str,
) -> list[str]:
"""Update the SF DB with a CSV file using shelve storage."""
updated_ids = []
shelf_path = get_object_shelf_path(object_type)
# First read the CSV to get all the data
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
id = row["Id"]
parent_ids = set()
field_to_remove: set[str] = set()
# Update relationship shelves for any parent references
for field, value in row.items():
if validate_salesforce_id(value) and field != "Id":
parent_ids.add(value)
field_to_remove.add(field)
if not value:
field_to_remove.add(field)
_update_relationship_shelves(id, parent_ids)
for field in field_to_remove:
# We use this to extract the Primary Owner later
if field != "LastModifiedById":
del row[field]
# Update the main object shelf
with shelve.open(shelf_path) as object_type_db:
object_type_db[id] = row
# Update the ID-to-type mapping shelf
with shelve.open(get_id_type_shelf_path()) as id_type_db:
id_type_db[id] = object_type
updated_ids.append(id)
# os.remove(csv_download_path)
return updated_ids
def get_type_from_id(object_id: str) -> str | None:
"""Get the type of an object from its ID."""
# Look up the object type from the ID-to-type mapping
with shelve.open(get_id_type_shelf_path()) as id_type_db:
if object_id not in id_type_db:
logger.warning(f"Object ID {object_id} not found in ID-to-type mapping")
return None
return id_type_db[object_id]
def get_record(
object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
"""
Retrieve the record and return it as a SalesforceObject.
The object type will be looked up from the ID-to-type mapping shelf.
"""
if object_type is None:
if not (object_type := get_type_from_id(object_id)):
return None
shelf_path = get_object_shelf_path(object_type)
with shelve.open(shelf_path) as db:
if object_id not in db:
logger.warning(f"Object ID {object_id} not found in {shelf_path}")
return None
data = db[object_id]
return SalesforceObject(
id=object_id,
type=object_type,
data=data,
)
def find_ids_by_type(object_type: str) -> list[str]:
"""
Find all object IDs for rows of the specified type.
"""
shelf_path = get_object_shelf_path(object_type)
try:
with shelve.open(shelf_path) as db:
return list(db.keys())
except FileNotFoundError:
return []
def get_affected_parent_ids_by_type(
updated_ids: set[str], parent_types: list[str]
) -> dict[str, set[str]]:
"""Get IDs of objects that are of the specified parent types and are either in the updated_ids
or have children in the updated_ids.
Args:
updated_ids: Set of IDs that were updated
parent_types: List of object types to filter by
Returns:
A dictionary mapping each parent type to the set of affected IDs of that type
"""
affected_ids_by_type: dict[str, set[str]] = {}
# Check each updated ID
for updated_id in updated_ids:
# Add the ID itself if it's of a parent type
updated_type = get_type_from_id(updated_id)
if updated_type in parent_types:
affected_ids_by_type.setdefault(updated_type, set()).add(updated_id)
continue
# Get parents of this ID and add them if they're of a parent type
with shelve.open(get_child_to_parent_shelf_path()) as child_to_parent_db:
parent_ids = child_to_parent_db.get(updated_id, [])
for parent_id in parent_ids:
parent_type = get_type_from_id(parent_id)
if parent_type in parent_types:
affected_ids_by_type.setdefault(parent_type, set()).add(parent_id)
return affected_ids_by_type
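The relationship bookkeeping in this file comes down to a set difference between a child's old and new parents, applied to two mirrored mappings. The sketch below shows that idea on plain dictionaries; `relink_child` is a hypothetical name, and it accepts any `MutableMapping`, so an open `shelve` shelf should work the same way under that assumption.

```python
from collections.abc import MutableMapping


def relink_child(
    child_to_parent: MutableMapping[str, list[str]],
    parent_to_child: MutableMapping[str, list[str]],
    child_id: str,
    new_parent_ids: set[str],
) -> None:
    """Point child_id at new_parent_ids and keep the reverse mapping in sync."""
    old_parent_ids = set(child_to_parent.get(child_id, []))
    child_to_parent[child_id] = sorted(new_parent_ids)

    # Parents that no longer reference this child lose it
    for parent_id in old_parent_ids - new_parent_ids:
        children = set(parent_to_child.get(parent_id, []))
        children.discard(child_id)
        parent_to_child[parent_id] = sorted(children)

    # Newly referenced parents gain it
    for parent_id in new_parent_ids - old_parent_ids:
        children = set(parent_to_child.get(parent_id, []))
        children.add(child_id)
        parent_to_child[parent_id] = sorted(children)


c2p: dict[str, list[str]] = {}
p2c: dict[str, list[str]] = {}
relink_child(c2p, p2c, "contact-1", {"acme"})
relink_child(c2p, p2c, "contact-1", {"globex"})  # move the contact
```

Values are reassigned rather than mutated in place, so the update is stored even on a shelf opened without `writeback=True`.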

@@ -0,0 +1,29 @@
import os
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path
def get_object_shelf_path(object_type: str) -> str:
"""Get the path to the shelf file for a specific object type."""
base_path = get_object_type_path(object_type)
os.makedirs(base_path, exist_ok=True)
return os.path.join(base_path, "data.shelf")
def get_id_type_shelf_path() -> str:
"""Get the path to the ID-to-type mapping shelf."""
os.makedirs(BASE_DATA_PATH, exist_ok=True)
return os.path.join(BASE_DATA_PATH, "id_type_mapping.shelf.4g")
def get_parent_to_child_shelf_path() -> str:
"""Get the path to the parent-to-child mapping shelf."""
os.makedirs(BASE_DATA_PATH, exist_ok=True)
return os.path.join(BASE_DATA_PATH, "parent_to_child_mapping.shelf.4g")
def get_child_to_parent_shelf_path() -> str:
"""Get the path to the child-to-parent mapping shelf."""
os.makedirs(BASE_DATA_PATH, exist_ok=True)
return os.path.join(BASE_DATA_PATH, "child_to_parent_mapping.shelf.4g")
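As a rough illustration of how such a path helper pairs with `shelve`, the sketch below writes and reads an ID-to-type entry in a temporary directory. `get_demo_shelf_path` and the directory layout are assumptions made for the example; they are not the connector's real `BASE_DATA_PATH`.

```python
import os
import shelve
import tempfile


def get_demo_shelf_path(base_dir: str, name: str) -> str:
    """Hypothetical helper: ensure the directory exists, then build the shelf path."""
    os.makedirs(base_dir, exist_ok=True)
    return os.path.join(base_dir, name)


with tempfile.TemporaryDirectory() as tmp_dir:
    shelf_path = get_demo_shelf_path(
        os.path.join(tmp_dir, "salesforce"), "id_type_mapping.shelf"
    )
    # The default flag "c" creates the shelf if it does not exist yet
    with shelve.open(shelf_path) as id_type_db:
        id_type_db["001bm00000fd9Z3AAI"] = "Account"
    # Reopen to confirm the entry was persisted
    with shelve.open(shelf_path) as id_type_db:
        looked_up = id_type_db.get("001bm00000fd9Z3AAI")
        missing = id_type_db.get("unknown-id")
```

Creating the directory inside the path helper means callers can open the shelf without a separate setup step, at the cost of an `os.makedirs` call on every lookup.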

@@ -0,0 +1,737 @@
import csv
import os
import shutil
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
get_affected_parent_ids_by_type,
)
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
update_sf_db_with_csv,
)
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path
_VALID_SALESFORCE_IDS = [
"001bm00000fd9Z3AAI",
"001bm00000fdYTdAAM",
"001bm00000fdYTeAAM",
"001bm00000fdYTfAAM",
"001bm00000fdYTgAAM",
"001bm00000fdYThAAM",
"001bm00000fdYTiAAM",
"001bm00000fdYTjAAM",
"001bm00000fdYTkAAM",
"001bm00000fdYTlAAM",
"001bm00000fdYTmAAM",
"001bm00000fdYTnAAM",
"001bm00000fdYToAAM",
"500bm00000XoOxtAAF",
"500bm00000XoOxuAAF",
"500bm00000XoOxvAAF",
"500bm00000XoOxwAAF",
"500bm00000XoOxxAAF",
"500bm00000XoOxyAAF",
"500bm00000XoOxzAAF",
"500bm00000XoOy0AAF",
"500bm00000XoOy1AAF",
"500bm00000XoOy2AAF",
"500bm00000XoOy3AAF",
"500bm00000XoOy4AAF",
"500bm00000XoOy5AAF",
"500bm00000XoOy6AAF",
"500bm00000XoOy7AAF",
"500bm00000XoOy8AAF",
"500bm00000XoOy9AAF",
"500bm00000XoOyAAAV",
"500bm00000XoOyBAAV",
"500bm00000XoOyCAAV",
"500bm00000XoOyDAAV",
"500bm00000XoOyEAAV",
"500bm00000XoOyFAAV",
"500bm00000XoOyGAAV",
"500bm00000XoOyHAAV",
"500bm00000XoOyIAAV",
"003bm00000EjHCjAAN",
"003bm00000EjHCkAAN",
"003bm00000EjHClAAN",
"003bm00000EjHCmAAN",
"003bm00000EjHCnAAN",
"003bm00000EjHCoAAN",
"003bm00000EjHCpAAN",
"003bm00000EjHCqAAN",
"003bm00000EjHCrAAN",
"003bm00000EjHCsAAN",
"003bm00000EjHCtAAN",
"003bm00000EjHCuAAN",
"003bm00000EjHCvAAN",
"003bm00000EjHCwAAN",
"003bm00000EjHCxAAN",
"003bm00000EjHCyAAN",
"003bm00000EjHCzAAN",
"003bm00000EjHD0AAN",
"003bm00000EjHD1AAN",
"003bm00000EjHD2AAN",
"550bm00000EXc2tAAD",
"006bm000006kyDpAAI",
"006bm000006kyDqAAI",
"006bm000006kyDrAAI",
"006bm000006kyDsAAI",
"006bm000006kyDtAAI",
"006bm000006kyDuAAI",
"006bm000006kyDvAAI",
"006bm000006kyDwAAI",
"006bm000006kyDxAAI",
"006bm000006kyDyAAI",
"006bm000006kyDzAAI",
"006bm000006kyE0AAI",
"006bm000006kyE1AAI",
"006bm000006kyE2AAI",
"006bm000006kyE3AAI",
"006bm000006kyE4AAI",
"006bm000006kyE5AAI",
"006bm000006kyE6AAI",
"006bm000006kyE7AAI",
"006bm000006kyE8AAI",
"006bm000006kyE9AAI",
"006bm000006kyEAAAY",
"006bm000006kyEBAAY",
"006bm000006kyECAAY",
"006bm000006kyEDAAY",
"006bm000006kyEEAAY",
"006bm000006kyEFAAY",
"006bm000006kyEGAAY",
"006bm000006kyEHAAY",
"006bm000006kyEIAAY",
"006bm000006kyEJAAY",
"005bm000009zy0TAAQ",
"005bm000009zy25AAA",
"005bm000009zy26AAA",
"005bm000009zy28AAA",
"005bm000009zy29AAA",
"005bm000009zy2AAAQ",
"005bm000009zy2BAAQ",
]
def clear_sf_db() -> None:
"""
Clears the SF DB by deleting all files in the data directory.
"""
# ignore_errors so that a first run with no data directory does not raise
shutil.rmtree(BASE_DATA_PATH, ignore_errors=True)
def create_csv_file(
object_type: str, records: list[dict], filename: str = "test_data.csv"
) -> None:
"""
Creates a CSV file for the given object type and records.
Args:
object_type: The Salesforce object type (e.g. "Account", "Contact")
records: List of dictionaries containing the record data
filename: Name of the CSV file to create (default: test_data.csv)
"""
if not records:
return
# Get all unique fields from records
fields: set[str] = set()
for record in records:
fields.update(record.keys())
# Keep the sorted result as a list; converting back to a set would lose the order
fieldnames = sorted(fields)
# Create CSV file
csv_path = os.path.join(get_object_type_path(object_type), filename)
with open(csv_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for record in records:
writer.writerow(record)
# Update the database with the CSV
update_sf_db_with_csv(object_type, csv_path)
def create_csv_with_example_data() -> None:
"""
Creates CSV files with example data, organized by object type.
"""
example_data: dict[str, list[dict]] = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
{
"Id": _VALID_SALESFORCE_IDS[3],
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
{
"Id": _VALID_SALESFORCE_IDS[4],
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
{
"Id": _VALID_SALESFORCE_IDS[5],
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
{
"Id": _VALID_SALESFORCE_IDS[6],
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
{
"Id": _VALID_SALESFORCE_IDS[7],
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"FirstName": "John",
"LastName": "Doe",
"Email": "john.doe@acme.com",
"Title": "CEO",
},
{
"Id": _VALID_SALESFORCE_IDS[41],
"FirstName": "Jane",
"LastName": "Smith",
"Email": "jane.smith@acme.com",
"Title": "CTO",
},
{
"Id": _VALID_SALESFORCE_IDS[42],
"FirstName": "Bob",
"LastName": "Johnson",
"Email": "bob.j@globex.com",
"Title": "Sales Director",
},
{
"Id": _VALID_SALESFORCE_IDS[43],
"FirstName": "Sarah",
"LastName": "Chen",
"Email": "sarah.chen@techcorp.com",
"Title": "Product Manager",
"Phone": "415-555-0101",
},
{
"Id": _VALID_SALESFORCE_IDS[44],
"FirstName": "Michael",
"LastName": "Rodriguez",
"Email": "m.rodriguez@biomed.com",
"Title": "Research Director",
"Phone": "617-555-0202",
},
{
"Id": _VALID_SALESFORCE_IDS[45],
"FirstName": "Emily",
"LastName": "Green",
"Email": "emily.g@greenenergy.com",
"Title": "Sustainability Lead",
"Phone": "503-555-0303",
},
{
"Id": _VALID_SALESFORCE_IDS[46],
"FirstName": "David",
"LastName": "Kim",
"Email": "david.kim@dataflow.com",
"Title": "Data Scientist",
"Phone": "206-555-0404",
},
{
"Id": _VALID_SALESFORCE_IDS[47],
"FirstName": "Rachel",
"LastName": "Taylor",
"Email": "r.taylor@cloudnine.com",
"Title": "Cloud Architect",
"Phone": "303-555-0505",
},
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"Name": "Acme Server Upgrade",
"Amount": 50000,
"Stage": "Prospecting",
"CloseDate": "2024-06-30",
},
{
"Id": _VALID_SALESFORCE_IDS[63],
"Name": "Globex Manufacturing Line",
"Amount": 150000,
"Stage": "Negotiation",
"CloseDate": "2024-03-15",
},
{
"Id": _VALID_SALESFORCE_IDS[64],
"Name": "Initech Software License",
"Amount": 75000,
"Stage": "Closed Won",
"CloseDate": "2024-01-30",
},
{
"Id": _VALID_SALESFORCE_IDS[65],
"Name": "TechCorp AI Implementation",
"Amount": 250000,
"Stage": "Needs Analysis",
"CloseDate": "2024-08-15",
"Probability": 60,
},
{
"Id": _VALID_SALESFORCE_IDS[66],
"Name": "BioMed Lab Equipment",
"Amount": 500000,
"Stage": "Value Proposition",
"CloseDate": "2024-09-30",
"Probability": 75,
},
{
"Id": _VALID_SALESFORCE_IDS[67],
"Name": "Green Energy Solar Project",
"Amount": 750000,
"Stage": "Proposal",
"CloseDate": "2024-07-15",
"Probability": 80,
},
{
"Id": _VALID_SALESFORCE_IDS[68],
"Name": "DataFlow Analytics Platform",
"Amount": 180000,
"Stage": "Negotiation",
"CloseDate": "2024-05-30",
"Probability": 90,
},
{
"Id": _VALID_SALESFORCE_IDS[69],
"Name": "Cloud Nine Infrastructure",
"Amount": 300000,
"Stage": "Qualification",
"CloseDate": "2024-10-15",
"Probability": 40,
},
],
}
# Create CSV files for each object type
for object_type, records in example_data.items():
create_csv_file(object_type, records)
def test_query() -> None:
"""
Tests querying functionality by verifying:
1. All expected Account IDs are found
2. Each Account's data matches what was inserted
"""
# Expected test data for verification
expected_accounts: dict[str, dict[str, str | int]] = {
_VALID_SALESFORCE_IDS[0]: {
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
_VALID_SALESFORCE_IDS[1]: {
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
_VALID_SALESFORCE_IDS[2]: {
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
_VALID_SALESFORCE_IDS[3]: {
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
_VALID_SALESFORCE_IDS[4]: {
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
_VALID_SALESFORCE_IDS[5]: {
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
_VALID_SALESFORCE_IDS[6]: {
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
_VALID_SALESFORCE_IDS[7]: {
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
}
# Get all Account IDs
account_ids = find_ids_by_type("Account")
# Verify we found all expected accounts
assert len(account_ids) == len(
expected_accounts
), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
assert set(account_ids) == set(
expected_accounts.keys()
), "Found account IDs don't match expected IDs"
# Verify each account's data
for acc_id in account_ids:
combined = get_record(acc_id)
assert combined is not None, f"Could not find account {acc_id}"
expected = expected_accounts[acc_id]
# Verify account data matches
for key, value in expected.items():
value = str(value)
assert (
combined.data[key] == value
), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"
print("All query tests passed successfully!")
def test_upsert() -> None:
"""
Tests upsert functionality by:
1. Updating an existing account
2. Creating a new account
3. Verifying both operations were successful
"""
# Create CSV for updating an existing account and adding a new one
update_data: list[dict[str, str | int]] = [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc. Updated",
"BillingCity": "New York",
"Industry": "Technology",
"Description": "Updated company info",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "New Company Inc.",
"BillingCity": "Miami",
"Industry": "Finance",
"AnnualRevenue": 1000000,
},
]
create_csv_file("Account", update_data, "update_data.csv")
# Verify the update worked
updated_record = get_record(_VALID_SALESFORCE_IDS[0])
assert updated_record is not None, "Updated record not found"
assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
assert (
updated_record.data["Description"] == "Updated company info"
), "Description not added"
# Verify the new record was created
new_record = get_record(_VALID_SALESFORCE_IDS[2])
assert new_record is not None, "New record not found"
assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"
print("All upsert tests passed successfully!")
def test_relationships() -> None:
"""
Tests relationship shelf updates and queries by:
1. Creating test data with relationships
2. Verifying the relationships are correctly stored
3. Testing relationship queries
"""
# Create test data for each object type
test_data: dict[str, list[dict[str, str | int]]] = {
"Case": [
{
"Id": _VALID_SALESFORCE_IDS[13],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 1",
},
{
"Id": _VALID_SALESFORCE_IDS[14],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 2",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[48],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Name": "Test Opportunity",
"Amount": 100000,
}
],
}
# Create and update CSV files for each object type
for object_type, records in test_data.items():
create_csv_file(object_type, records, "relationship_test.csv")
# Test relationship queries
# All these objects should be children of Acme Inc.
child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
assert (
_VALID_SALESFORCE_IDS[62] in child_ids
), "Opportunity not found in relationship"
# Test querying relationships for a different account (should be empty)
other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert (
len(other_account_children) == 0
), "Expected no children for different account"
print("All relationship tests passed successfully!")
def test_account_with_children() -> None:
"""
Tests querying all accounts and retrieving their child objects.
This test verifies that:
1. All accounts can be retrieved
2. Child objects are correctly linked
3. Child object data is complete and accurate
"""
# First get all account IDs
account_ids = find_ids_by_type("Account")
assert len(account_ids) > 0, "No accounts found"
# For each account, get its children and verify the data
for account_id in account_ids:
account = get_record(account_id)
assert account is not None, f"Could not find account {account_id}"
# Get all child objects
child_ids = get_child_ids(account_id)
# For Acme Inc., verify specific relationships
if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc.
assert (
len(child_ids) == 4
), f"Expected 4 children for Acme Inc., found {len(child_ids)}"
# Get all child records
child_records = []
for child_id in child_ids:
child_record = get_record(child_id)
if child_record is not None:
child_records.append(child_record)
# Verify Cases
cases = [r for r in child_records if r.type == "Case"]
assert (
len(cases) == 2
), f"Expected 2 cases for Acme Inc., found {len(cases)}"
case_subjects = {case.data["Subject"] for case in cases}
assert "Test Case 1" in case_subjects, "Test Case 1 not found"
assert "Test Case 2" in case_subjects, "Test Case 2 not found"
# Verify Contacts
contacts = [r for r in child_records if r.type == "Contact"]
assert (
len(contacts) == 1
), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
contact = contacts[0]
assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
assert contact.data["LastName"] == "Contact", "Contact last name mismatch"
# Verify Opportunities
opportunities = [r for r in child_records if r.type == "Opportunity"]
assert (
len(opportunities) == 1
), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
opportunity = opportunities[0]
assert (
opportunity.data["Name"] == "Test Opportunity"
), "Opportunity name mismatch"
assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"
print("All account with children tests passed successfully!")
def test_relationship_updates() -> None:
"""
Tests that relationships are properly updated when a child object's parent reference changes.
This test verifies:
1. Initial relationship is created correctly
2. When parent reference is updated, old relationship is removed
3. New relationship is created correctly
"""
# Create initial test data - Contact linked to Acme Inc.
initial_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
]
create_csv_file("Contact", initial_contact, "initial_contact.csv")
# Verify initial relationship
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] in acme_children
), "Initial relationship not created"
# Update contact to be linked to Globex Corp instead
updated_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[1],
"FirstName": "Test",
"LastName": "Contact",
}
]
create_csv_file("Contact", updated_contact, "updated_contact.csv")
# Verify old relationship is removed
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] not in acme_children
), "Old relationship not removed"
# Verify new relationship is created
globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"
print("All relationship update tests passed successfully!")
def test_get_affected_parent_ids() -> None:
"""
Tests get_affected_parent_ids functionality by verifying:
1. IDs that are directly in the parent_types list are included
2. IDs that have children in the updated_ids list are included
3. IDs that are neither of the above are not included
"""
# Create test data with relationships
test_data = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Parent Account 1",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Parent Account 2",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Not Affected Account",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Child",
"LastName": "Contact",
}
],
}
# Create and update CSV files for test data
for object_type, records in test_data.items():
create_csv_file(object_type, records)
# The result is keyed by parent type, so each case looks up the "Account" entry
# Test Case 1: Account directly in updated_ids and parent_types
updated_ids = {_VALID_SALESFORCE_IDS[1]} # Parent Account 2
parent_types = ["Account"]
affected_ids_by_type = get_affected_parent_ids_by_type(updated_ids, parent_types)
affected_ids = affected_ids_by_type.get("Account", set())
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
# Test Case 2: Account with child in updated_ids
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
parent_types = ["Account"]
affected_ids_by_type = get_affected_parent_ids_by_type(updated_ids, parent_types)
affected_ids = affected_ids_by_type.get("Account", set())
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids
), "Parent of updated child not included"
# Test Case 3: Both direct and indirect affects
updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]} # Both cases
parent_types = ["Account"]
affected_ids_by_type = get_affected_parent_ids_by_type(updated_ids, parent_types)
affected_ids = affected_ids_by_type.get("Account", set())
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
assert (
_VALID_SALESFORCE_IDS[2] not in affected_ids
), "Unaffected ID incorrectly included"
# Test Case 4: No matches
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
parent_types = ["Opportunity"] # Wrong type
affected_ids_by_type = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert len(affected_ids_by_type) == 0, "Should return empty dict when no matches"
print("All get_affected_parent_ids tests passed successfully!")
def main_build() -> None:
clear_sf_db()
create_csv_with_example_data()
test_query()
test_upsert()
test_relationships()
test_account_with_children()
test_relationship_updates()
test_get_affected_parent_ids()
if __name__ == "__main__":
main_build()
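Several assertions in these tests compare numeric fields against strings such as `"1000000"`. That follows from the CSV round trip: `csv.DictWriter` serializes every value as text and `csv.DictReader` returns only strings. A small in-memory sketch, using made-up IDs rather than the fixture data above:

```python
import csv
import io

records = [
    {"Id": "001", "Name": "Acme Inc."},
    {"Id": "002", "Name": "New Company Inc.", "AnnualRevenue": 1000000},
]

# A sorted list (not a set) gives DictWriter a stable column order
fieldnames = sorted({key for record in records for key in record})

buffer = io.StringIO(newline="")
writer = csv.DictWriter(buffer, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(records)  # missing keys are written as empty strings

buffer.seek(0)
rows = list(csv.DictReader(buffer))
```

After the round trip, `rows[1]["AnnualRevenue"]` is the string `"1000000"` and `rows[0]["AnnualRevenue"]` is an empty string, which is why the loader drops empty fields before storing a row.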

@@ -0,0 +1,386 @@
import csv
import json
import os
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
from shared_configs.utils import batch_list
logger = setup_logger()
@contextmanager
def get_db_connection(
isolation_level: str | None = None,
) -> Iterator[sqlite3.Connection]:
"""Get a database connection with proper isolation level and error handling.
Args:
isolation_level: SQLite isolation level. None = default "DEFERRED",
can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
"""
# 60 second timeout for locks
conn = sqlite3.connect(get_sqlite_db_path(), timeout=60.0)
if isolation_level is not None:
conn.isolation_level = isolation_level
try:
yield conn
except Exception:
conn.rollback()
raise
finally:
conn.close()
def init_db() -> None:
"""Initialize the SQLite database with required tables if they don't exist."""
if os.path.exists(get_sqlite_db_path()):
return
# Create database directory if it doesn't exist
os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)
with get_db_connection("EXCLUSIVE") as conn:
cursor = conn.cursor()
# Enable WAL mode for better concurrent access and write performance
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA temp_store=MEMORY")
cursor.execute("PRAGMA cache_size=-2000000") # Use 2GB memory for cache
# Main table for storing Salesforce objects
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS salesforce_objects (
id TEXT PRIMARY KEY,
object_type TEXT NOT NULL,
data TEXT NOT NULL, -- JSON serialized data
last_modified INTEGER DEFAULT (strftime('%s', 'now')) -- Add timestamp for better cache management
) WITHOUT ROWID -- Optimize for primary key lookups
"""
)
# Table for parent-child relationships with covering index
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS relationships (
child_id TEXT NOT NULL,
parent_id TEXT NOT NULL,
PRIMARY KEY (child_id, parent_id)
) WITHOUT ROWID -- Optimize for primary key lookups
"""
)
# New table for caching parent-child relationships with object types
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS relationship_types (
child_id TEXT NOT NULL,
parent_id TEXT NOT NULL,
parent_type TEXT NOT NULL,
PRIMARY KEY (child_id, parent_id, parent_type)
) WITHOUT ROWID
"""
)
# Always recreate indexes to ensure they exist
cursor.execute("DROP INDEX IF EXISTS idx_object_type")
cursor.execute("DROP INDEX IF EXISTS idx_parent_id")
cursor.execute("DROP INDEX IF EXISTS idx_child_parent")
cursor.execute("DROP INDEX IF EXISTS idx_object_type_id")
cursor.execute("DROP INDEX IF EXISTS idx_relationship_types_lookup")
# Create covering indexes for common queries
cursor.execute(
"""
CREATE INDEX idx_object_type
ON salesforce_objects(object_type, id)
WHERE object_type IS NOT NULL
"""
)
cursor.execute(
"""
CREATE INDEX idx_parent_id
ON relationships(parent_id, child_id)
"""
)
cursor.execute(
"""
CREATE INDEX idx_child_parent
ON relationships(child_id)
WHERE child_id IS NOT NULL
"""
)
# New composite index for fast parent type lookups
cursor.execute(
"""
CREATE INDEX idx_relationship_types_lookup
ON relationship_types(parent_type, child_id, parent_id)
"""
)
# Analyze tables to help query planner
cursor.execute("ANALYZE relationships")
cursor.execute("ANALYZE salesforce_objects")
cursor.execute("ANALYZE relationship_types")
conn.commit()
def _update_relationship_tables(
conn: sqlite3.Connection, child_id: str, parent_ids: set[str]
) -> None:
"""Update the relationship tables when a record is updated.
Args:
conn: The database connection to use (must be in a transaction)
child_id: The ID of the child record
parent_ids: Set of parent IDs to link to
"""
try:
cursor = conn.cursor()
# Get existing parent IDs
cursor.execute(
"SELECT parent_id FROM relationships WHERE child_id = ?", (child_id,)
)
old_parent_ids = {row[0] for row in cursor.fetchall()}
# Calculate differences
parent_ids_to_remove = old_parent_ids - parent_ids
parent_ids_to_add = parent_ids - old_parent_ids
# Remove old relationships
if parent_ids_to_remove:
cursor.executemany(
"DELETE FROM relationships WHERE child_id = ? AND parent_id = ?",
[(child_id, pid) for pid in parent_ids_to_remove],
)
# Also remove from relationship_types
cursor.executemany(
"DELETE FROM relationship_types WHERE child_id = ? AND parent_id = ?",
[(child_id, pid) for pid in parent_ids_to_remove],
)
# Add new relationships
if parent_ids_to_add:
# First add to relationships table
cursor.executemany(
"INSERT INTO relationships (child_id, parent_id) VALUES (?, ?)",
[(child_id, pid) for pid in parent_ids_to_add],
)
# Then get the types of the parent objects and add to relationship_types
for parent_id in parent_ids_to_add:
cursor.execute(
"SELECT object_type FROM salesforce_objects WHERE id = ?",
(parent_id,),
)
result = cursor.fetchone()
if result:
parent_type = result[0]
cursor.execute(
"""
INSERT INTO relationship_types (child_id, parent_id, parent_type)
VALUES (?, ?, ?)
""",
(child_id, parent_id, parent_type),
)
except Exception as e:
logger.error(f"Error updating relationship tables: {e}")
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
raise
def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]:
"""Update the SF DB with a CSV file using SQLite storage."""
updated_ids = []
# Use IMMEDIATE to get a write lock at the start of the transaction
with get_db_connection("IMMEDIATE") as conn:
cursor = conn.cursor()
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
if "Id" not in row:
logger.warning(
f"Row {row} does not have an Id field in {csv_download_path}"
)
continue
id = row["Id"]
parent_ids = set()
field_to_remove: set[str] = set()
# Process relationships and clean data
for field, value in row.items():
if validate_salesforce_id(value) and field != "Id":
parent_ids.add(value)
field_to_remove.add(field)
if not value:
field_to_remove.add(field)
# Remove unwanted fields
for field in field_to_remove:
if field != "LastModifiedById":
del row[field]
# Update main object data
cursor.execute(
"""
INSERT OR REPLACE INTO salesforce_objects (id, object_type, data)
VALUES (?, ?, ?)
""",
(id, object_type, json.dumps(row)),
)
# Update relationships using the same connection
_update_relationship_tables(conn, id, parent_ids)
updated_ids.append(id)
conn.commit()
return updated_ids
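The upsert above stores each row as JSON text and leans on `INSERT OR REPLACE` so that a second write for the same ID overwrites the first. A minimal sketch against an in-memory database, with a trimmed-down table and a hypothetical `upsert` helper, might look like this:

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE salesforce_objects (
        id TEXT PRIMARY KEY,
        object_type TEXT NOT NULL,
        data TEXT NOT NULL  -- JSON serialized data
    ) WITHOUT ROWID
    """
)


def upsert(object_id: str, object_type: str, row: dict[str, str]) -> None:
    """Hypothetical helper: insert the row, or replace it if the ID already exists."""
    conn.execute(
        "INSERT OR REPLACE INTO salesforce_objects (id, object_type, data) "
        "VALUES (?, ?, ?)",
        (object_id, object_type, json.dumps(row)),
    )


upsert("001", "Account", {"Name": "Acme Inc."})
upsert("001", "Account", {"Name": "Acme Inc. Updated"})  # same ID, replaces the row
conn.commit()

row_count = conn.execute("SELECT COUNT(*) FROM salesforce_objects").fetchone()[0]
stored = json.loads(
    conn.execute(
        "SELECT data FROM salesforce_objects WHERE id = ?", ("001",)
    ).fetchone()[0]
)
conn.close()
```

Note that `INSERT OR REPLACE` deletes the conflicting row and inserts a new one, so any column not listed in the statement falls back to its default instead of keeping its old value.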
def get_child_ids(parent_id: str) -> set[str]:
"""Get all child IDs for a given parent ID."""
with get_db_connection() as conn:
cursor = conn.cursor()
# Force index usage with INDEXED BY
cursor.execute(
"SELECT child_id FROM relationships INDEXED BY idx_parent_id WHERE parent_id = ?",
(parent_id,),
)
child_ids = {row[0] for row in cursor.fetchall()}
return child_ids
def get_type_from_id(object_id: str) -> str | None:
"""Get the type of an object from its ID."""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT object_type FROM salesforce_objects WHERE id = ?", (object_id,)
)
result = cursor.fetchone()
if not result:
logger.warning(f"Object ID {object_id} not found")
return None
return result[0]
def get_record(
object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
"""Retrieve the record and return it as a SalesforceObject."""
if object_type is None:
object_type = get_type_from_id(object_id)
if not object_type:
return None
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT data FROM salesforce_objects WHERE id = ?", (object_id,))
result = cursor.fetchone()
if not result:
logger.warning(f"Object ID {object_id} not found")
return None
data = json.loads(result[0])
return SalesforceObject(id=object_id, type=object_type, data=data)
def find_ids_by_type(object_type: str) -> list[str]:
"""Find all object IDs for rows of the specified type."""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT id FROM salesforce_objects WHERE object_type = ?", (object_type,)
)
return [row[0] for row in cursor.fetchall()]
def get_affected_parent_ids_by_type(
updated_ids: list[str],
parent_types: list[str],
batch_size: int = 500,
) -> Iterator[tuple[str, set[str]]]:
"""Get IDs of objects that are of the specified parent types and are either in the
updated_ids or have children in the updated_ids. Yields tuples of (parent_type, affected_ids).
"""
# SQLite typically has a limit of 999 variables
updated_ids_batches = batch_list(updated_ids, batch_size)
updated_parent_ids: set[str] = set()
with get_db_connection() as conn:
cursor = conn.cursor()
for batch_ids in updated_ids_batches:
id_placeholders = ",".join(["?" for _ in batch_ids])
for parent_type in parent_types:
affected_ids: set[str] = set()
# Get directly updated objects of parent types - using index on object_type
cursor.execute(
f"""
SELECT id FROM salesforce_objects
WHERE id IN ({id_placeholders})
AND object_type = ?
""",
batch_ids + [parent_type],
)
affected_ids.update(row[0] for row in cursor.fetchall())
# Get parent objects of updated objects - using optimized relationship_types table
cursor.execute(
f"""
SELECT DISTINCT parent_id
FROM relationship_types
INDEXED BY idx_relationship_types_lookup
WHERE parent_type = ?
AND child_id IN ({id_placeholders})
""",
[parent_type] + batch_ids,
)
affected_ids.update(row[0] for row in cursor.fetchall())
# Remove any parent IDs that have already been processed
new_affected_ids = affected_ids - updated_parent_ids
# Add the new affected IDs to the set of updated parent IDs
updated_parent_ids.update(new_affected_ids)
if new_affected_ids:
yield parent_type, new_affected_ids
def has_at_least_one_object_of_type(object_type: str) -> bool:
"""Check if there is at least one object of the specified type in the database.
Args:
object_type: The Salesforce object type to check
Returns:
bool: True if at least one object exists, False otherwise
"""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT COUNT(*) FROM salesforce_objects WHERE object_type = ?",
(object_type,),
)
count = cursor.fetchone()[0]
return count > 0
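Review note: `get_affected_parent_ids_by_type` above batches `updated_ids` so that each `IN (...)` clause stays under SQLite's bound-variable limit. A minimal sketch of that batching pattern follows, run against an in-memory database. The `batch_list` helper here is a hypothetical stand-in for the one the diff calls (its real implementation is not shown), and `find_ids_of_type` and the table contents are invented for illustration.

```python
import sqlite3
from collections.abc import Iterator


def batch_list(items: list[str], batch_size: int) -> Iterator[list[str]]:
    """Hypothetical stand-in for the batch_list helper referenced in the diff."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


def find_ids_of_type(
    conn: sqlite3.Connection,
    ids: list[str],
    object_type: str,
    batch_size: int = 500,
) -> set[str]:
    """Return the subset of ids whose stored object_type matches."""
    found: set[str] = set()
    for batch in batch_list(ids, batch_size):
        # One "?" per ID keeps the query parameterized; the extra "?" is the type.
        placeholders = ",".join("?" for _ in batch)
        cursor = conn.execute(
            f"SELECT id FROM salesforce_objects "
            f"WHERE id IN ({placeholders}) AND object_type = ?",
            batch + [object_type],
        )
        found.update(row[0] for row in cursor.fetchall())
    return found


# Illustrative data only: 1200 rows, alternating between two object types.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE salesforce_objects (id TEXT PRIMARY KEY, object_type TEXT, data TEXT)"
)
rows = [(f"id{i}", "Account" if i % 2 == 0 else "Contact", "{}") for i in range(1200)]
conn.executemany("INSERT INTO salesforce_objects VALUES (?, ?, ?)", rows)

all_ids = [row[0] for row in rows]
accounts = find_ids_of_type(conn, all_ids, "Account")
print(len(accounts))  # 600
```

With a batch size of 500, each statement binds at most 501 variables, which stays below the 999 limit mentioned in the code comment.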


@@ -1,66 +1,72 @@
import re
from typing import Union
SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
import os
from dataclasses import dataclass
from typing import Any
def _clean_salesforce_dict(data: Union[dict, list]) -> Union[dict, list]:
if isinstance(data, dict):
if "records" in data.keys():
data = data["records"]
if isinstance(data, dict):
if "attributes" in data.keys():
if isinstance(data["attributes"], dict):
data.update(data.pop("attributes"))
@dataclass
class SalesforceObject:
id: str
type: str
data: dict[str, Any]
if isinstance(data, dict):
filtered_dict = {}
for key, value in data.items():
if not re.search(SF_JSON_FILTER, key, re.IGNORECASE):
if "__c" in key: # remove the custom object indicator for display
key = key[:-3]
if isinstance(value, (dict, list)):
filtered_value = _clean_salesforce_dict(value)
if filtered_value: # Only add non-empty dictionaries or lists
filtered_dict[key] = filtered_value
elif value is not None:
filtered_dict[key] = value
return filtered_dict
elif isinstance(data, list):
filtered_list = []
for item in data:
if isinstance(item, (dict, list)):
filtered_item = _clean_salesforce_dict(item)
if filtered_item: # Only add non-empty dictionaries or lists
filtered_list.append(filtered_item)
elif item is not None:
filtered_list.append(filtered_item)
return filtered_list
else:
return data
def to_dict(self) -> dict[str, Any]:
return {
"ID": self.id,
"Type": self.type,
"Data": self.data,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "SalesforceObject":
return cls(
id=data["Id"],
type=data["Type"],
data=data,
)
def _json_to_natural_language(data: Union[dict, list], indent: int = 0) -> str:
result = []
indent_str = " " * indent
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, (dict, list)):
result.append(f"{indent_str}{key}:")
result.append(_json_to_natural_language(value, indent + 2))
else:
result.append(f"{indent_str}{key}: {value}")
elif isinstance(data, list):
for item in data:
result.append(_json_to_natural_language(item, indent))
else:
result.append(f"{indent_str}{data}")
return "\n".join(result)
# This defines the base path for all data files relative to this file
# AKA BE CAREFUL WHEN MOVING THIS FILE
BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")
def extract_dict_text(raw_dict: dict) -> str:
processed_dict = _clean_salesforce_dict(raw_dict)
natural_language_dict = _json_to_natural_language(processed_dict)
return natural_language_dict
def get_sqlite_db_path() -> str:
"""Get the path to the sqlite db file."""
return os.path.join(BASE_DATA_PATH, "salesforce_db.sqlite")
def get_object_type_path(object_type: str) -> str:
"""Get the directory path for a specific object type."""
type_dir = os.path.join(BASE_DATA_PATH, object_type)
os.makedirs(type_dir, exist_ok=True)
return type_dir
_CHECKSUM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
_LOOKUP = {format(i, "05b"): _CHECKSUM_CHARS[i] for i in range(32)}
def validate_salesforce_id(salesforce_id: str) -> bool:
"""Validate the checksum portion of an 18-character Salesforce ID.
Args:
salesforce_id: An 18-character Salesforce ID
Returns:
bool: True if the checksum is valid, False otherwise
"""
if len(salesforce_id) != 18:
return False
chunks = [salesforce_id[0:5], salesforce_id[5:10], salesforce_id[10:15]]
checksum = salesforce_id[15:18]
calculated_checksum = ""
for chunk in chunks:
result_string = "".join(
"1" if char.isupper() else "0" for char in reversed(chunk)
)
calculated_checksum += _LOOKUP[result_string]
return checksum == calculated_checksum
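Review note: the checksum logic in `validate_salesforce_id` can be easier to follow when run in the other direction, deriving the 3-character suffix from a 15-character ID. The sketch below mirrors the chunk/reverse/uppercase-bit steps from the function above; `checksum_suffix` and the sample ID are illustrative and not part of the diff.

```python
CHECKSUM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"


def checksum_suffix(id15: str) -> str:
    """Derive the 3-character suffix for a 15-character, case-sensitive ID."""
    if len(id15) != 15:
        raise ValueError("expected a 15-character ID")
    suffix = ""
    for start in (0, 5, 10):
        chunk = id15[start : start + 5]
        # Reverse the chunk, then mark each uppercase letter with a 1 bit.
        bits = "".join("1" if ch.isupper() else "0" for ch in reversed(chunk))
        # The 5-bit value (0..31) indexes into the checksum alphabet.
        suffix += CHECKSUM_CHARS[int(bits, 2)]
    return suffix


example_id15 = "001D000000IqhSL"  # illustrative ID, not taken from the diff
print(example_id15 + checksum_suffix(example_id15))  # 001D000000IqhSLIAZ
```

Note that the check is purely structural: any 18-character string whose suffix matches its own case pattern passes (for example, a string with no uppercase letters in the first 15 characters followed by `AAA`), so `update_sf_db_with_csv` may treat such a field value as a parent ID.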


@@ -264,24 +264,6 @@ class SlackTextCleaner:
message = message.replace("<!everyone>", "@everyone")
return message
@staticmethod
def replace_links(message: str) -> str:
"""Replaces slack links e.g. `<URL>` -> `URL` and `<URL|DISPLAY>` -> `DISPLAY`"""
# Find possible links in the message
possible_link_matches = re.findall(r"<(.*?)>", message)
for possible_link in possible_link_matches:
if not possible_link:
continue
# Special slack patterns that aren't for links
if possible_link[0] not in ["#", "@", "!"]:
link_display = (
possible_link
if "|" not in possible_link
else possible_link.split("|")[1]
)
message = message.replace(f"<{possible_link}>", link_display)
return message
@staticmethod
def replace_special_catchall(message: str) -> str:
"""Replaces pattern of <!something|another-thing> with another-thing


@@ -33,6 +33,7 @@ from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.utils.logger import setup_logger
from onyx.utils.sitemap import list_pages_for_site
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -241,6 +242,12 @@ class WebConnector(LoadConnector):
self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url))
elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD:
# Explicitly check if running in multi-tenant mode to prevent potential security risks
if MULTI_TENANT:
raise ValueError(
"Upload input for web connector is not supported in cloud environments"
)
logger.warning(
"This is not a UI supported Web Connector flow, "
"are you sure you want to do this?"


@@ -40,6 +40,13 @@ class ZendeskClient:
response = requests.get(
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
)
if response.status_code == 429:
retry_after = response.headers.get("Retry-After")
if retry_after is not None:
# Sleep for the duration indicated by the Retry-After header
time.sleep(int(retry_after))
response.raise_for_status()
return response.json()
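Review note: in the hunk above, the request is not re-sent after sleeping, so `response.raise_for_status()` would still raise on the original 429 response unless a retry wrapper outside this hunk re-invokes the method. A minimal sketch of a loop that re-issues the request after honoring `Retry-After` is below. `FakeResponse`, `get_with_rate_limit_retry`, and the injected `send`/`sleep` callables are all hypothetical, used so the example runs without network access; they are not part of `ZendeskClient`.

```python
import time
from collections.abc import Callable
from dataclasses import dataclass, field


@dataclass
class FakeResponse:
    """Hypothetical stand-in for an HTTP response object."""

    status_code: int
    headers: dict[str, str] = field(default_factory=dict)
    payload: dict = field(default_factory=dict)


def get_with_rate_limit_retry(
    send: Callable[[], FakeResponse],
    max_attempts: int = 3,
    sleep: Callable[[float], None] = time.sleep,
) -> FakeResponse:
    """Call send(), re-sending after a 429 until max_attempts is reached."""
    response = send()
    for _ in range(max_attempts - 1):
        if response.status_code != 429:
            break
        retry_after = response.headers.get("Retry-After")
        # Assumes Retry-After is given in seconds; falls back to 1 second.
        sleep(int(retry_after) if retry_after is not None else 1)
        response = send()
    return response


# Simulate one rate-limited response followed by a success.
responses = iter(
    [FakeResponse(429, {"Retry-After": "2"}), FakeResponse(200, payload={"ok": True})]
)
slept: list[float] = []
result = get_with_rate_limit_retry(lambda: next(responses), sleep=slept.append)
print(result.status_code, slept)  # 200 [2]
```

`Retry-After` may also carry an HTTP date rather than a number of seconds, in which case `int(...)` raises `ValueError`; both the hunk and this sketch assume the numeric form.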


@@ -96,6 +96,8 @@ class Tag(BaseModel):
class BaseFilters(BaseModel):
source_type: list[DocumentSource] | None = None
document_set: list[str] | None = None
user_folders: list[str] | None = None
document_ids: list[str] | None = None
time_cutoff: datetime | None = None
tags: list[Tag] | None = None


@@ -54,9 +54,11 @@ def get_total_users_count(db_session: Session) -> int:
return user_count + invited_users
async def get_user_count() -> int:
async def get_user_count(only_admin_users: bool = False) -> int:
async with get_async_session_with_tenant() as session:
stmt = select(func.count(User.id))
if only_admin_users:
stmt = stmt.where(User.role == UserRole.ADMIN)
result = await session.execute(stmt)
user_count = result.scalar()
if user_count is None:


@@ -141,14 +141,20 @@ def get_valid_messages_from_query_sessions(
return {row.chat_session_id: row.message for row in first_messages}
# Retrieves chat sessions by user
# Chat sessions do not include onyxbot flows
def get_chat_sessions_by_user(
user_id: UUID | None,
deleted: bool | None,
db_session: Session,
include_onyxbot_flows: bool = False,
limit: int = 50,
) -> list[ChatSession]:
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
if deleted is not None:
@@ -310,6 +316,23 @@ def update_chat_session(
return chat_session
def delete_all_chat_sessions_for_user(
user: User | None, db_session: Session, hard_delete: bool = HARD_DELETE_CHATS
) -> None:
user_id = user.id if user is not None else None
query = db_session.query(ChatSession).filter(
ChatSession.user_id == user_id, ChatSession.onyxbot_flow.is_(False)
)
if hard_delete:
query.delete(synchronize_session=False)
else:
query.update({ChatSession.deleted: True}, synchronize_session=False)
db_session.commit()
def delete_chat_session(
user_id: UUID | None,
chat_session_id: UUID,


@@ -7,6 +7,7 @@ from sqlalchemy import exists
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
@@ -90,15 +91,22 @@ def get_connector_credential_pairs(
user: User | None = None,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
) -> list[ConnectorCredentialPair]:
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = _add_user_filters(stmt, user, get_editable)
if not include_disabled:
stmt = stmt.where(
ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE
) # noqa
)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).all())
@@ -310,6 +318,9 @@ def associate_default_cc_pair(db_session: Session) -> None:
if existing_association is not None:
return
# DefaultCCPair has id 1 since it is the first CC pair created
# It is DEFAULT_CC_PAIR_ID, but can't set it explicitly because it messed with the
# auto-incrementing id
association = ConnectorCredentialPair(
connector_id=0,
credential_id=0,
@@ -350,7 +361,12 @@ def add_credential_to_connector(
last_successful_index_time: datetime | None = None,
) -> StatusResponse:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
credential = fetch_credential_by_id(
credential_id,
user,
db_session,
get_editable=False,
)
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")
@@ -427,7 +443,12 @@ def remove_credential_from_connector(
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
credential = fetch_credential_by_id(
credential_id,
user,
db_session,
get_editable=False,
)
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")


@@ -86,7 +86,7 @@ def _add_user_filters(
"""
Filter Credentials by:
- if the user is in the user_group that owns the Credential
- if the user is not a global_curator, they must also have a curator relationship
- if the user is a curator, they must also have a curator relationship
to the user_group
- if editing is being done, we also filter out Credentials that are owned by groups
that the user isn't a curator for
@@ -97,6 +97,7 @@ def _add_user_filters(
where_clause = User__UserGroup.user_id == user.id
if user.role == UserRole.CURATOR:
where_clause &= User__UserGroup.is_curator == True # noqa: E712
if get_editable:
user_groups = select(User__UserGroup.user_group_id).where(
User__UserGroup.user_id == user.id
@@ -152,10 +153,16 @@ def fetch_credential_by_id(
user: User | None,
db_session: Session,
assume_admin: bool = False,
get_editable: bool = True,
) -> Credential | None:
stmt = select(Credential).distinct()
stmt = stmt.where(Credential.id == credential_id)
stmt = _add_user_filters(stmt, user, assume_admin=assume_admin)
stmt = _add_user_filters(
stmt=stmt,
user=user,
assume_admin=assume_admin,
get_editable=get_editable,
)
result = db_session.execute(stmt)
credential = result.scalar_one_or_none()
return credential


@@ -1,5 +1,7 @@
import contextlib
import os
import re
import ssl
import threading
import time
from collections.abc import AsyncGenerator
@@ -10,6 +12,8 @@ from datetime import datetime
from typing import Any
from typing import ContextManager
import asyncpg # type: ignore
import boto3
import jwt
from fastapi import HTTPException
from fastapi import Request
@@ -23,6 +27,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
@@ -37,6 +42,7 @@ from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.configs.constants import SSL_CERT_FILE
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -49,28 +55,87 @@ logger = setup_logger()
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"
# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
# Global so we don't create more than one engine per process
_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory: sessionmaker[Session] | None = None
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
"""Create an SSL context if IAM authentication is enabled, else return None."""
if USE_IAM_AUTH:
return ssl.create_default_context(cafile=SSL_CERT_FILE)
return None
ssl_context = create_ssl_context_if_iam()
def get_iam_auth_token(
host: str, port: str, user: str, region: str = "us-east-2"
) -> str:
"""
Generate an IAM authentication token using boto3.
"""
client = boto3.client("rds", region_name=region)
token = client.generate_db_auth_token(
DBHostname=host, Port=int(port), DBUsername=user
)
return token
def configure_psycopg2_iam_auth(
cparams: dict[str, Any], host: str, port: str, user: str, region: str
) -> None:
"""
Configure cparams for psycopg2 with IAM token and SSL.
"""
token = get_iam_auth_token(host, port, user, region)
cparams["password"] = token
cparams["sslmode"] = "require"
cparams["sslrootcert"] = SSL_CERT_FILE
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
user: str = POSTGRES_USER,
password: str = POSTGRES_PASSWORD,
host: str = POSTGRES_HOST,
port: str = POSTGRES_PORT,
db: str = POSTGRES_DB,
app_name: str | None = None,
use_iam: bool = USE_IAM_AUTH,
region: str = "us-west-2",
) -> str:
if use_iam:
base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
else:
base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
# For asyncpg, do not include application_name in the connection string
if app_name and db_api != "asyncpg":
if "?" in base_conn_str:
return f"{base_conn_str}&application_name={app_name}"
else:
return f"{base_conn_str}?application_name={app_name}"
return base_conn_str
if LOG_POSTGRES_LATENCY:
# Function to log before query execution
@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute( # type: ignore
conn, cursor, statement, parameters, context, executemany
):
conn.info["query_start_time"] = time.time()
# Function to log after query execution
@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute( # type: ignore
conn, cursor, statement, parameters, context, executemany
):
total_time = time.time() - conn.info["query_start_time"]
# don't spam TOO hard
if total_time > 0.1:
logger.debug(
f"Query Complete: {statement}\n\nTotal Time: {total_time:.4f} seconds"
@@ -78,7 +143,6 @@ if LOG_POSTGRES_LATENCY:
if LOG_POSTGRES_CONN_COUNTS:
# Global counter for connection checkouts and checkins
checkout_count = 0
checkin_count = 0
@@ -105,21 +169,13 @@ if LOG_POSTGRES_CONN_COUNTS:
logger.debug(f"Total connection checkins: {checkin_count}")
"""END DEBUGGING LOGGING"""
def get_db_current_time(db_session: Session) -> datetime:
"""Get the current time from Postgres representing the start of the transaction
Within the same transaction this value will not update
This datetime object returned should be timezone aware, default Postgres timezone is UTC
"""
result = db_session.execute(text("SELECT NOW()")).scalar()
if result is None:
raise ValueError("Database did not return a time")
return result
# Regular expression to validate schema names to prevent SQL injection
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
@@ -128,16 +184,9 @@ def is_valid_schema_name(name: str) -> bool:
class SqlEngine:
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
Will eventually subsume most of the standalone functions in this file.
Sync only for now.
"""
_engine: Engine | None = None
_lock: threading.Lock = threading.Lock()
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
# Default parameters for engine creation
DEFAULT_ENGINE_KWARGS = {
"pool_size": 20,
"max_overflow": 5,
@@ -145,33 +194,27 @@ class SqlEngine:
"pool_recycle": POSTGRES_POOL_RECYCLE,
}
def __init__(self) -> None:
pass
@classmethod
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
"""Private helper method to create and return an Engine."""
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
)
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
return create_engine(connection_string, **merged_kwargs)
engine = create_engine(connection_string, **merged_kwargs)
if USE_IAM_AUTH:
event.listen(engine, "do_connect", provide_iam_token)
return engine
@classmethod
def init_engine(cls, **engine_kwargs: Any) -> None:
"""Allow the caller to init the engine with extra params. Different clients
such as the API server and different Celery workers and tasks
need different settings.
"""
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
@classmethod
def get_engine(cls) -> Engine:
"""Gets the SQLAlchemy engine. Will init a default engine if init hasn't
already been called. You probably want to init first!
"""
if not cls._engine:
with cls._lock:
if not cls._engine:
@@ -180,12 +223,10 @@ class SqlEngine:
@classmethod
def set_app_name(cls, app_name: str) -> None:
"""Class method to set the app name."""
cls._app_name = app_name
@classmethod
def get_app_name(cls) -> str:
"""Class method to get current app name."""
if not cls._app_name:
return ""
return cls._app_name
@@ -217,56 +258,71 @@ def get_all_tenant_ids() -> list[str] | list[None]:
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
user: str = POSTGRES_USER,
password: str = POSTGRES_PASSWORD,
host: str = POSTGRES_HOST,
port: str = POSTGRES_PORT,
db: str = POSTGRES_DB,
app_name: str | None = None,
) -> str:
if app_name:
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
def get_sqlalchemy_engine() -> Engine:
return SqlEngine.get_engine()
async def get_async_connection() -> Any:
"""
Custom connection function for async engine when using IAM auth.
"""
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
db = POSTGRES_DB
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
# asyncpg expects ssl="require" when SSL is needed
return await asyncpg.connect(
user=user, password=token, host=host, port=int(port), database=db, ssl="require"
)
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
# Underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
app_name = SqlEngine.get_app_name() + "_async"
connection_string = build_connection_string(
db_api=ASYNC_DB_API,
use_iam=USE_IAM_AUTH,
)
connect_args: dict[str, Any] = {}
if app_name:
connect_args["server_settings"] = {"application_name": app_name}
connect_args["ssl"] = ssl_context
_ASYNC_ENGINE = create_async_engine(
connection_string,
connect_args={
"server_settings": {
"application_name": SqlEngine.get_app_name() + "_async"
}
},
# async engine is only used by API server, so we can use those values
# here as well
connect_args=connect_args,
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
if USE_IAM_AUTH:
@event.listens_for(_ASYNC_ENGINE.sync_engine, "do_connect")
def provide_iam_token_async(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
# For async engine using asyncpg, we still need to set the IAM token here.
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
cparams["password"] = token
cparams["ssl"] = ssl_context
return _ASYNC_ENGINE
# Dependency to get the current tenant ID
# If no token is present, uses the default schema for this use case
def get_current_tenant_id(request: Request) -> str:
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -275,7 +331,6 @@ def get_current_tenant_id(request: Request) -> str:
token = request.cookies.get("fastapiusersauth")
if not token:
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
# If no token is present, use the default schema or handle accordingly
return current_value
try:
@@ -289,7 +344,6 @@ def get_current_tenant_id(request: Request) -> str:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
except jwt.InvalidTokenError:
return CURRENT_TENANT_ID_CONTEXTVAR.get()
@@ -316,7 +370,6 @@ async def get_async_session_with_tenant(
async with async_session_factory() as session:
try:
# Set the search_path to the tenant's schema
await session.execute(text(f'SET search_path = "{tenant_id}"'))
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
@@ -326,8 +379,6 @@ async def get_async_session_with_tenant(
)
except Exception:
logger.exception("Error setting search_path.")
# You can choose to re-raise the exception or handle it
# Here, we'll re-raise to prevent proceeding with an incorrect session
raise
else:
yield session
@@ -335,9 +386,6 @@ async def get_async_session_with_tenant(
@contextmanager
def get_session_with_default_tenant() -> Generator[Session, None, None]:
"""
Get a database session using the current tenant ID from the context variable.
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
@@ -349,7 +397,6 @@ def get_session_with_tenant(
) -> Generator[Session, None, None]:
"""
Generate a database session for a specific tenant.
This function:
1. Sets the database schema to the specified tenant's schema.
2. Preserves the tenant ID across the session.
@@ -357,27 +404,20 @@ def get_session_with_tenant(
4. Uses the default schema if no tenant ID is provided.
"""
engine = get_sqlalchemy_engine()
# Store the previous tenant ID
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
try:
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
@@ -390,21 +430,17 @@ def get_session_with_tenant(
finally:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
finally:
# Restore the previous tenant ID
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
@@ -424,12 +460,9 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:
def get_session() -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
raise BasicAuthenticationError(
detail="User must authenticate",
)
raise BasicAuthenticationError(detail="User must authenticate")
engine = get_sqlalchemy_engine()
@@ -437,20 +470,17 @@ def get_session() -> Generator[Session, None, None]:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
"""Generate an async database session with the appropriate tenant schema set."""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
yield async_session
@@ -461,7 +491,6 @@ def get_session_context_manager() -> ContextManager[Session]:
def get_session_factory() -> sessionmaker[Session]:
"""Get a session factory."""
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
@@ -489,3 +518,13 @@ async def warm_up_connections(
await async_conn.execute(text("SELECT 1"))
for async_conn in async_connections:
await async_conn.close()
def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) -> None:
if USE_IAM_AUTH:
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
region = os.getenv("AWS_REGION_NAME", "us-east-2")
# Configure for psycopg2 with IAM token
configure_psycopg2_iam_auth(cparams, host, port, user, region)


@@ -1,132 +0,0 @@
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.db.chat import delete_chat_session
from onyx.db.models import ChatFolder
from onyx.db.models import ChatSession
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_user_folders(
user_id: UUID | None,
db_session: Session,
) -> list[ChatFolder]:
return db_session.query(ChatFolder).filter(ChatFolder.user_id == user_id).all()
def update_folder_display_priority(
user_id: UUID | None,
display_priority_map: dict[int, int],
db_session: Session,
) -> None:
folders = get_user_folders(user_id=user_id, db_session=db_session)
folder_ids = {folder.id for folder in folders}
if folder_ids != set(display_priority_map.keys()):
raise ValueError("Invalid Folder IDs provided")
for folder in folders:
folder.display_priority = display_priority_map[folder.id]
db_session.commit()
def get_folder_by_id(
user_id: UUID | None,
folder_id: int,
db_session: Session,
) -> ChatFolder:
folder = (
db_session.query(ChatFolder).filter(ChatFolder.id == folder_id).one_or_none()
)
if not folder:
raise ValueError("Folder by specified id does not exist")
if folder.user_id != user_id:
raise PermissionError(f"Folder does not belong to user: {user_id}")
return folder
def create_folder(
user_id: UUID | None, folder_name: str | None, db_session: Session
) -> int:
new_folder = ChatFolder(
user_id=user_id,
name=folder_name,
)
db_session.add(new_folder)
db_session.commit()
return new_folder.id
def rename_folder(
user_id: UUID | None, folder_id: int, folder_name: str | None, db_session: Session
) -> None:
folder = get_folder_by_id(
user_id=user_id, folder_id=folder_id, db_session=db_session
)
folder.name = folder_name
db_session.commit()
def add_chat_to_folder(
user_id: UUID | None, folder_id: int, chat_session: ChatSession, db_session: Session
) -> None:
folder = get_folder_by_id(
user_id=user_id, folder_id=folder_id, db_session=db_session
)
chat_session.folder_id = folder.id
db_session.commit()
def remove_chat_from_folder(
user_id: UUID | None, folder_id: int, chat_session: ChatSession, db_session: Session
) -> None:
folder = get_folder_by_id(
user_id=user_id, folder_id=folder_id, db_session=db_session
)
if chat_session.folder_id != folder.id:
raise ValueError("The chat session is not in the specified folder.")
if folder.user_id != user_id:
raise ValueError(
f"Tried to remove a chat session from a folder that does not belong to "
f"this user, user id: {user_id}"
)
chat_session.folder_id = None
if chat_session in folder.chat_sessions:
folder.chat_sessions.remove(chat_session)
db_session.commit()
def delete_folder(
user_id: UUID | None,
folder_id: int,
including_chats: bool,
db_session: Session,
) -> None:
folder = get_folder_by_id(
user_id=user_id, folder_id=folder_id, db_session=db_session
)
# Assuming there will not be a massive number of chats in any given folder
if including_chats:
for chat_session in folder.chat_sessions:
delete_chat_session(
user_id=user_id,
chat_session_id=chat_session.id,
db_session=db_session,
)
db_session.delete(folder)
db_session.commit()

View File

@@ -0,0 +1,99 @@
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified
from onyx.configs.constants import MilestoneRecordType
from onyx.db.models import Milestone
from onyx.db.models import User
USER_ASSISTANT_PREFIX = "user_assistants_used_"
MULTI_ASSISTANT_USED = "multi_assistant_used"
def create_milestone(
user: User | None,
event_type: MilestoneRecordType,
db_session: Session,
) -> Milestone:
milestone = Milestone(
event_type=event_type,
user_id=user.id if user else None,
)
db_session.add(milestone)
db_session.commit()
return milestone
def create_milestone_if_not_exists(
user: User | None, event_type: MilestoneRecordType, db_session: Session
) -> tuple[Milestone, bool]:
# Check if it exists
milestone = db_session.execute(
select(Milestone).where(Milestone.event_type == event_type)
).scalar_one_or_none()
if milestone is not None:
return milestone, False
# If it doesn't exist, try to create it.
try:
milestone = create_milestone(user, event_type, db_session)
return milestone, True
except IntegrityError:
# Another thread or process inserted it in the meantime
db_session.rollback()
# Fetch again to return the existing record
milestone = db_session.execute(
select(Milestone).where(Milestone.event_type == event_type)
).scalar_one() # Now should exist
return milestone, False
def update_user_assistant_milestone(
milestone: Milestone,
user_id: str | None,
assistant_id: int,
db_session: Session,
) -> None:
event_tracker = milestone.event_tracker
if event_tracker is None:
milestone.event_tracker = event_tracker = {}
if event_tracker.get(MULTI_ASSISTANT_USED):
# No need to keep tracking and populating if the milestone has already been hit
return
user_key = f"{USER_ASSISTANT_PREFIX}{user_id}"
if event_tracker.get(user_key) is None:
event_tracker[user_key] = [assistant_id]
elif assistant_id not in event_tracker[user_key]:
event_tracker[user_key].append(assistant_id)
flag_modified(milestone, "event_tracker")
db_session.commit()
def check_multi_assistant_milestone(
milestone: Milestone,
db_session: Session,
) -> tuple[bool, bool]:
"""Returns if the milestone was hit and if it was just hit for the first time"""
event_tracker = milestone.event_tracker
if event_tracker is None:
return False, False
if event_tracker.get(MULTI_ASSISTANT_USED):
return True, False
for key, value in event_tracker.items():
if key.startswith(USER_ASSISTANT_PREFIX) and len(value) > 1:
event_tracker[MULTI_ASSISTANT_USED] = True
flag_modified(milestone, "event_tracker")
db_session.commit()
return True, True
return False, False
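
The tracker logic in `update_user_assistant_milestone` and `check_multi_assistant_milestone` can be tried out without a database. The following is a minimal sketch that assumes the `event_tracker` JSONB column behaves like a plain dict; `record_assistant_use` and `check_multi_assistant` are hypothetical stand-ins, not functions from this diff, and the `flag_modified`/`commit` steps are omitted.

```python
from typing import Optional

USER_ASSISTANT_PREFIX = "user_assistants_used_"
MULTI_ASSISTANT_USED = "multi_assistant_used"


def record_assistant_use(
    tracker: dict, user_id: Optional[str], assistant_id: int
) -> None:
    """Record that a user used an assistant (hypothetical, dict-only helper)."""
    if tracker.get(MULTI_ASSISTANT_USED):
        # Milestone already hit, so nothing more needs to be tracked
        return
    used = tracker.setdefault(f"{USER_ASSISTANT_PREFIX}{user_id}", [])
    if assistant_id not in used:
        used.append(assistant_id)


def check_multi_assistant(tracker: dict) -> tuple:
    """Return (milestone_hit, hit_for_the_first_time)."""
    if tracker.get(MULTI_ASSISTANT_USED):
        return True, False
    # Iterate over a snapshot because the dict is mutated on a hit
    for key, value in list(tracker.items()):
        if key.startswith(USER_ASSISTANT_PREFIX) and len(value) > 1:
            tracker[MULTI_ASSISTANT_USED] = True
            return True, True
    return False, False


tracker: dict = {}
record_assistant_use(tracker, "user-a", 1)
first_check = check_multi_assistant(tracker)   # one assistant only
record_assistant_use(tracker, "user-a", 2)
second_check = check_multi_assistant(tracker)  # second assistant for the same user
third_check = check_multi_assistant(tracker)   # already flagged
```

In the real code the in-place mutation of the JSONB dict is why `flag_modified(milestone, "event_tracker")` is called before committing; SQLAlchemy does not detect in-place changes to a plain `dict` column value on its own.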

View File

@@ -5,6 +5,8 @@ from typing import Literal
from typing import NotRequired
from typing import Optional
from uuid import uuid4
from pydantic import BaseModel
from typing_extensions import TypedDict # noreorder
from uuid import UUID
@@ -37,7 +39,7 @@ from sqlalchemy.types import TypeDecorator
from onyx.auth.schemas import UserRole
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
@@ -52,6 +54,7 @@ from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.enums import TaskStatus
from onyx.db.pydantic_type import PydanticType
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
@@ -63,6 +66,8 @@ from onyx.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import RerankerProvider
logger = setup_logger()
class Base(DeclarativeBase):
__abstract__ = True
@@ -70,6 +75,8 @@ class Base(DeclarativeBase):
class EncryptedString(TypeDecorator):
impl = LargeBinary
# This type's behavior is fully deterministic and doesn't depend on any external factors.
cache_ok = True
def process_bind_param(self, value: str | None, dialect: Dialect) -> bytes | None:
if value is not None:
@@ -84,6 +91,8 @@ class EncryptedString(TypeDecorator):
class EncryptedJson(TypeDecorator):
impl = LargeBinary
# This type's behavior is fully deterministic and doesn't depend on any external factors.
cache_ok = True
def process_bind_param(self, value: dict | None, dialect: Dialect) -> bytes | None:
if value is not None:
@@ -100,11 +109,76 @@ class EncryptedJson(TypeDecorator):
return value
class NullFilteredString(TypeDecorator):
impl = String
# This type's behavior is fully deterministic and doesn't depend on any external factors.
cache_ok = True
def process_bind_param(self, value: str | None, dialect: Dialect) -> str | None:
if value is not None and "\x00" in value:
logger.warning(f"NUL characters found in value: {value}")
return value.replace("\x00", "")
return value
def process_result_value(self, value: str | None, dialect: Dialect) -> str | None:
return value
"""
Auth/Authz (users, permissions, access) Tables
"""
class UserFolder(Base):
__tablename__ = "user_folder"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), nullable=False)
parent_id: Mapped[int | None] = mapped_column(
ForeignKey("user_folder.id"), nullable=True
)
name: Mapped[str] = mapped_column(nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(
default=datetime.datetime.utcnow
)
user: Mapped["User"] = relationship(back_populates="folders")
parent: Mapped["UserFolder"] = relationship(
remote_side=[id], back_populates="children"
)
children: Mapped[list["UserFolder"]] = relationship(back_populates="parent")
files: Mapped[list["UserFile"]] = relationship(back_populates="folder")
chat_sessions: Mapped[list["ChatSession"]] = relationship(back_populates="folder")
class UserDocument(str, Enum):
CHAT = "chat"
RECENT = "recent"
FILE = "file"
class UserFile(Base):
__tablename__ = "user_file"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"), nullable=False)
parent_folder_id: Mapped[int | None] = mapped_column(
ForeignKey("user_folder.id"), nullable=True
)
file_id: Mapped[str] = mapped_column(nullable=False)
document_id: Mapped[str] = mapped_column(nullable=False)
name: Mapped[str] = mapped_column(nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(
default=datetime.datetime.utcnow
)
ccpair_id: Mapped[int | None] = mapped_column(
ForeignKey("connector_credential_pair.id"), nullable=False
)
user: Mapped["User"] = relationship(back_populates="files")
folder: Mapped["UserFolder"] = relationship(back_populates="files")
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
# even an almost empty token from keycloak will not fit the default 1024 bytes
access_token: Mapped[str] = mapped_column(Text, nullable=False) # type: ignore
@@ -154,9 +228,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
)
chat_folders: Mapped[list["ChatFolder"]] = relationship(
"ChatFolder", back_populates="user"
)
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
@@ -174,6 +245,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
)
folders: Mapped[list["UserFolder"]] = relationship(
"UserFolder", back_populates="user"
)
files: Mapped[list["UserFile"]] = relationship("UserFile", back_populates="user")
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
pass
@@ -449,16 +525,16 @@ class Document(Base):
# this should correspond to the ID of the document
# (as is passed around in Onyx)
id: Mapped[str] = mapped_column(String, primary_key=True)
id: Mapped[str] = mapped_column(NullFilteredString, primary_key=True)
from_ingestion_api: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=True
)
# 0 for neutral, positive for mostly endorse, negative for mostly reject
boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
hidden: Mapped[bool] = mapped_column(Boolean, default=False)
semantic_id: Mapped[str] = mapped_column(String)
semantic_id: Mapped[str] = mapped_column(NullFilteredString)
# First Section's link
link: Mapped[str | None] = mapped_column(String, nullable=True)
link: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
# The updated time is also used as a measure of the last successful state of the doc
# pulled from the source (to help skip reindexing already updated docs in case of
@@ -974,7 +1050,7 @@ class ChatSession(Base):
default=ChatSessionSharedStatus.PRIVATE,
)
folder_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_folder.id"), nullable=True
ForeignKey("user_folder.id", ondelete="SET NULL"), nullable=True
)
current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)
@@ -1004,11 +1080,11 @@ class ChatSession(Base):
DateTime(timezone=True), server_default=func.now()
)
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
folder: Mapped["ChatFolder"] = relationship(
"ChatFolder", back_populates="chat_sessions"
folder: Mapped["UserFolder"] = relationship(
"UserFolder", back_populates="chat_sessions"
)
messages: Mapped[list["ChatMessage"]] = relationship(
"ChatMessage", back_populates="chat_session"
"ChatMessage", back_populates="chat_session", cascade="all, delete-orphan"
)
persona: Mapped["Persona"] = relationship("Persona")
@@ -1076,6 +1152,8 @@ class ChatMessage(Base):
"SearchDoc",
secondary=ChatMessage__SearchDoc.__table__,
back_populates="chat_messages",
cascade="all, delete-orphan",
single_parent=True,
)
tool_call: Mapped["ToolCall"] = relationship(
@@ -1091,33 +1169,6 @@ class ChatMessage(Base):
)
class ChatFolder(Base):
"""For organizing chat sessions"""
__tablename__ = "chat_folder"
id: Mapped[int] = mapped_column(primary_key=True)
# Only null if auth is off
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
name: Mapped[str | None] = mapped_column(String, nullable=True)
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0)
user: Mapped[User] = relationship("User", back_populates="chat_folders")
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="folder"
)
def __lt__(self, other: Any) -> bool:
if not isinstance(other, ChatFolder):
return NotImplemented
if self.display_priority == other.display_priority:
# Bigger ID (created later) shows earlier
return self.id > other.id
return self.display_priority < other.display_priority
"""
Feedback, Logging, Metrics Tables
"""
@@ -1344,6 +1395,11 @@ class StarterMessage(TypedDict):
message: str
class StarterMessageModel(BaseModel):
name: str
message: str
class Persona(Base):
__tablename__ = "persona"
@@ -1534,6 +1590,32 @@ class SlackBot(Base):
)
class Milestone(Base):
# This table is used to track significant events for a deployment towards finding value
# The table is currently not used for features but it may be used in the future to inform
# users about the product features and encourage usage/exploration.
__tablename__ = "milestone"
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), primary_key=True, default=uuid4
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
event_type: Mapped[MilestoneRecordType] = mapped_column(String)
# Need to track counts and specific ids of certain events to know if the Milestone has been reached
event_tracker: Mapped[dict | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
user: Mapped[User | None] = relationship("User")
__table_args__ = (UniqueConstraint("event_type", name="uq_milestone_event_type"),)
class TaskQueueState(Base):
# Currently refers to Celery Tasks
__tablename__ = "task_queue_jobs"

View File

@@ -0,0 +1,36 @@
from typing import List
from fastapi import UploadFile
from sqlalchemy.orm import Session
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.server.documents.connector import upload_files
from onyx.server.documents.models import FileUploadResponse
CHAT_FOLDER_ID = -1
RECENT_DOCUMENTS_FOLDER_ID = -2
def create_user_files(
files: List[UploadFile],
folder_id: int | None,
user: User | None,
db_session: Session,
) -> FileUploadResponse:
upload_response = upload_files(files, db_session)
for file_path, file in zip(upload_response.file_paths, files):
new_file = UserFile(
user_id=user.id if user else None,
parent_folder_id=folder_id,
file_id=file_path,
document_id=file_path, # We'll use the same ID for now
name=file.filename,
)
db_session.add(new_file)
db_session.commit()
return upload_response
# def trigger_document_indexing(db_session: Session, user_id: int) -> None:

View File

@@ -543,6 +543,10 @@ def upsert_persona(
if tools is not None:
existing_persona.tools = tools or []
# We should only update display priority if it is not already set
if existing_persona.display_priority is None:
existing_persona.display_priority = display_priority
persona = existing_persona
else:

View File

@@ -0,0 +1,29 @@
from typing import List
from fastapi import UploadFile
from sqlalchemy.orm import Session
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.server.documents.connector import upload_files
from onyx.server.documents.models import FileUploadResponse
def create_user_files(
files: List[UploadFile],
folder_id: int | None,
user: User,
db_session: Session,
) -> FileUploadResponse:
upload_response = upload_files(files, db_session)
for file_path, file in zip(upload_response.file_paths, files):
new_file = UserFile(
user_id=user.id if user else None,
parent_folder_id=folder_id if folder_id != -1 else None,
file_id=file_path,
document_id=file_path,
name=file.filename,
)
db_session.add(new_file)
db_session.commit()
return upload_response

View File

@@ -7,8 +7,15 @@ from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
@@ -185,3 +192,43 @@ def batch_add_ext_perm_user_if_not_exists(
db_session.commit()
return found_users + new_users
def delete_user_from_db(
user_to_delete: User,
db_session: Session,
) -> None:
for oauth_account in user_to_delete.oauth_accounts:
db_session.delete(oauth_account)
fetch_ee_implementation_or_noop(
"onyx.db.external_perm",
"delete_user__ext_group_for_user__no_commit",
)(
db_session=db_session,
user_id=user_to_delete.id,
)
db_session.query(SamlAccount).filter(
SamlAccount.user_id == user_to_delete.id
).delete()
db_session.query(DocumentSet__User).filter(
DocumentSet__User.user_id == user_to_delete.id
).delete()
db_session.query(Persona__User).filter(
Persona__User.user_id == user_to_delete.id
).delete()
db_session.query(User__UserGroup).filter(
User__UserGroup.user_id == user_to_delete.id
).delete()
db_session.delete(user_to_delete)
db_session.commit()
# NOTE: edge case may exist with race conditions
# with this `invited user` scheme generally.
user_emails = get_invited_users()
remaining_users = [
remaining_user_email
for remaining_user_email in user_emails
if remaining_user_email != user_to_delete.email
]
write_invited_users(remaining_users)

View File

@@ -369,6 +369,19 @@ class AdminCapable(abc.ABC):
raise NotImplementedError
class RandomCapable(abc.ABC):
"""Class must implement random document retrieval capability"""
@abc.abstractmethod
def random_retrieval(
self,
filters: IndexFilters,
num_to_retrieve: int = 10,
) -> list[InferenceChunkUncleaned]:
"""Retrieve random chunks matching the filters"""
raise NotImplementedError
class BaseIndex(
Verifiable,
Indexable,
@@ -376,6 +389,7 @@ class BaseIndex(
Deletable,
AdminCapable,
IdRetrievalCapable,
RandomCapable,
abc.ABC,
):
"""

View File

@@ -112,6 +112,11 @@ schema DANSWER_CHUNK_NAME {
rank: filter
attribute: fast-search
}
field user_folders type weightedset<string> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
}
# If using different tokenization settings, the fieldset has to be removed, and the field must
@@ -218,4 +223,10 @@ schema DANSWER_CHUNK_NAME {
expression: bm25(content) + (5 * bm25(title))
}
}
rank-profile random_ {
first-phase {
expression: random.match
}
}
}

View File

@@ -23,7 +23,7 @@
<resource-limits>
<!-- Default is 75% but this can be increased for Dockerized deployments -->
<!-- https://docs.vespa.ai/en/operations/feed-block.html -->
<disk>0.75</disk>
<disk>0.85</disk>
</resource-limits>
</tuning>
<engine>

View File

@@ -16,6 +16,12 @@ logger = setup_logger()
CONTENT_SUMMARY = "content_summary"
@retry(tries=10, delay=1, backoff=2)
def _retryable_http_delete(http_client: httpx.Client, url: str) -> None:
res = http_client.delete(url)
res.raise_for_status()
@retry(tries=3, delay=1, backoff=2)
def _delete_vespa_doc_chunks(
document_id: str, index_name: str, http_client: httpx.Client
@@ -28,10 +34,10 @@ def _delete_vespa_doc_chunks(
for chunk_id in doc_chunk_ids:
try:
res = http_client.delete(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}"
_retryable_http_delete(
http_client,
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}",
)
res.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(f"Failed to delete chunk, details: {e.response.text}")
raise
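
The `retry` decorator used on `_retryable_http_delete` is imported outside this hunk, so its exact semantics are not shown here. As a rough stdlib-only sketch, assuming it retries on any exception and multiplies the delay by `backoff` after each failure, it could behave like the hypothetical `simple_retry` below.

```python
import functools
import time


def simple_retry(tries: int, delay: float, backoff: float):
    """Hypothetical stand-in for a retry decorator with exponential backoff."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            wait = delay
            for attempt in range(1, tries + 1):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    if attempt == tries:
                        # Out of attempts: surface the last error to the caller
                        raise
                    time.sleep(wait)
                    wait *= backoff

        return wrapper

    return decorator


def worst_case_sleep(tries: int, delay: float, backoff: float) -> float:
    """Total time spent sleeping if every attempt fails (sleeps happen between attempts)."""
    return sum(delay * backoff**i for i in range(tries - 1))
```

Under these assumed semantics, `tries=10, delay=1, backoff=2` sleeps for up to 1 + 2 + ... + 256 seconds per chunk before giving up, and the outer `@retry(tries=3, ...)` on `_delete_vespa_doc_chunks` can multiply that further, which may be worth keeping in mind for large deletions.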

View File

@@ -2,6 +2,7 @@ import concurrent.futures
import io
import logging
import os
import random
import re
import time
import urllib
@@ -312,6 +313,7 @@ class VespaIndex(DocumentIndex):
with updating the associated permissions. Assumes that a document will not be split into
multiple chunk batches calling this function multiple times, otherwise only the last set of
chunks will be kept"""
# IMPORTANT: This must be done one index at a time, do not use secondary index here
cleaned_chunks = [clean_chunk_id_copy(chunk) for chunk in chunks]
@@ -534,7 +536,7 @@ class VespaIndex(DocumentIndex):
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with get_vespa_http_client() as http_client:
with get_vespa_http_client(http2=False) as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
@@ -545,8 +547,12 @@ class VespaIndex(DocumentIndex):
while True:
try:
vespa_url = (
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}"
)
logger.debug(f'update_single PUT on URL "{vespa_url}"')
resp = http_client.put(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}",
vespa_url,
params=params,
headers={"Content-Type": "application/json"},
json=update_dict,
@@ -618,7 +624,7 @@ class VespaIndex(DocumentIndex):
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with get_vespa_http_client() as http_client:
with get_vespa_http_client(http2=False) as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
@@ -629,8 +635,12 @@ class VespaIndex(DocumentIndex):
while True:
try:
vespa_url = (
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}"
)
logger.debug(f'delete_single DELETE on URL "{vespa_url}"')
resp = http_client.delete(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}",
vespa_url,
params=params,
)
resp.raise_for_status()
@@ -697,6 +707,8 @@ class VespaIndex(DocumentIndex):
offset: int = 0,
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
) -> list[InferenceChunkUncleaned]:
logger.debug(f"Retrieval filters: {filters}")
vespa_where_clauses = build_vespa_filters(filters)
# Needs to be at least as much as the value set in Vespa schema config
target_hits = max(10 * num_to_retrieve, 1000)
@@ -903,6 +915,32 @@ class VespaIndex(DocumentIndex):
logger.info("Batch deletion completed")
def random_retrieval(
self,
filters: IndexFilters,
num_to_retrieve: int = 10,
) -> list[InferenceChunkUncleaned]:
"""Retrieve random chunks matching the filters using Vespa's random ranking
This method is currently used for random chunk retrieval in the context of
assistant starter message creation (passed as sample context for usage by the assistant).
"""
vespa_where_clauses = build_vespa_filters(filters, remove_trailing_and=True)
yql = YQL_BASE.format(index_name=self.index_name) + vespa_where_clauses
random_seed = random.randint(0, 1000000)
params: dict[str, str | int | float] = {
"yql": yql,
"hits": num_to_retrieve,
"timeout": VESPA_TIMEOUT,
"ranking.profile": "random_",
"ranking.properties.random.seed": random_seed,
}
return query_vespa(params)
class _VespaDeleteRequest:
def __init__(self, document_id: str, index_name: str) -> None:

View File

@@ -64,10 +64,10 @@ def _does_document_exist(
if doc_fetch_response.status_code != 200:
logger.debug(f"Failed to check for document with URL {doc_url}")
raise RuntimeError(
f"Unexpected fetch document by ID value from Vespa "
f"with error {doc_fetch_response.status_code}"
f"Index name: {index_name}"
f"Doc chunk id: {doc_chunk_id}"
f"Unexpected fetch document by ID value from Vespa: "
f"error={doc_fetch_response.status_code} "
f"index={index_name} "
f"doc_chunk_id={doc_chunk_id}"
)
return True

View File

@@ -55,7 +55,7 @@ def remove_invalid_unicode_chars(text: str) -> str:
return _illegal_xml_chars_RE.sub("", text)
def get_vespa_http_client(no_timeout: bool = False) -> httpx.Client:
def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
"""
Configure and return an HTTP client for communicating with Vespa,
including authentication if needed.
@@ -67,5 +67,5 @@ def get_vespa_http_client(no_timeout: bool = False) -> httpx.Client:
else None,
verify=False if not MANAGED_VESPA else True,
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
http2=True,
http2=http2,
)

View File

@@ -9,17 +9,24 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_IDS
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_FOLDERS
from onyx.utils.logger import setup_logger
logger = setup_logger()
def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) -> str:
def build_vespa_filters(
filters: IndexFilters,
*,
include_hidden: bool = False,
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
) -> str:
def _build_or_filters(key: str, vals: list[str] | None) -> str:
if vals is None:
return ""
@@ -72,12 +79,20 @@ def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->
tags = filters.tags
if tags:
tag_attributes = [tag.tag_key + INDEX_SEPARATOR + tag.tag_value for tag in tags]
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
filter_str += _build_or_filters(USER_FOLDERS, filters.user_folders)
filter_str += _build_or_filters(DOCUMENT_IDS, filters.document_ids)
filter_str += _build_time_filter(filters.time_cutoff)
if remove_trailing_and and filter_str.endswith(" and "):
filter_str = filter_str[:-5] # We remove the trailing " and "
return filter_str
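
To illustrate why `remove_trailing_and` exists: each filter clause is appended with a trailing ` and ` so that a ranking clause can follow it, which leaves a dangling ` and ` when the filter string is used as the whole query (as in `random_retrieval`). The sketch below is an assumption-laden simplification; `or_clause` and `build_filter` are hypothetical helpers, and the clause format is illustrative rather than copied from `_build_or_filters`.

```python
from typing import List, Optional


def or_clause(key: str, vals: Optional[List[str]]) -> str:
    """Build '(key contains "a" or key contains "b") and ' or '' if no values."""
    if not vals:
        return ""
    joined = " or ".join(f'{key} contains "{val}"' for val in vals)
    return f"({joined}) and "


def build_filter(
    document_sets: Optional[List[str]],
    user_folders: Optional[List[str]],
    *,
    remove_trailing_and: bool = False,
) -> str:
    filter_str = or_clause("document_sets", document_sets)
    filter_str += or_clause("user_folders", user_folders)
    suffix = " and "
    if remove_trailing_and and filter_str.endswith(suffix):
        # Strip the dangling connector so the string is a complete condition
        filter_str = filter_str[: -len(suffix)]
    return filter_str


partial = build_filter(["docs"], ["7"])
complete = build_filter(["docs"], ["7"], remove_trailing_and=True)
```

Using `len(suffix)` instead of a literal `-5` keeps the slice correct if the connector string ever changes.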

View File

@@ -64,6 +64,8 @@ EMBEDDINGS = "embeddings"
TITLE_EMBEDDING = "title_embedding"
ACCESS_CONTROL_LIST = "access_control_list"
DOCUMENT_SETS = "document_sets"
USER_FOLDERS = "user_folders"
DOCUMENT_IDS = "document_ids"
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"

View File

@@ -260,6 +260,21 @@ def index_doc_batch_prepare(
def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
# Remove any NUL characters from title/semantic_id
# This is a known issue with the Zendesk connector
# Postgres cannot handle NUL characters in text fields
if document.title:
document.title = document.title.replace("\x00", "")
if document.semantic_identifier:
document.semantic_identifier = document.semantic_identifier.replace(
"\x00", ""
)
# Remove NUL characters from all sections
for section in document.sections:
if section.text is not None:
section.text = section.text.replace("\x00", "")
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())

View File

@@ -266,18 +266,27 @@ class DefaultMultiLLM(LLM):
# )
self._custom_config = custom_config
# Create a dictionary for model-specific arguments if it's None
model_kwargs = model_kwargs or {}
# NOTE: have to set these as environment variables for Litellm since
# not all are able to be passed in but they always support them set as env
# variables. We'll also try passing them in, since litellm just ignores
# additional kwargs (and some kwargs MUST be passed in rather than set as
# env variables)
if custom_config:
for k, v in custom_config.items():
os.environ[k] = v
# Specifically pass in "vertex_credentials" as a model_kwarg to the
# completion call for vertex AI. More details here:
# https://docs.litellm.ai/docs/providers/vertex
vertex_credentials_key = "vertex_credentials"
vertex_credentials = custom_config.get(vertex_credentials_key)
if vertex_credentials and model_provider == "vertex_ai":
model_kwargs[vertex_credentials_key] = vertex_credentials
else:
# standard case
for k, v in custom_config.items():
os.environ[k] = v
model_kwargs = model_kwargs or {}
if custom_config:
model_kwargs.update(custom_config)
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:
@@ -453,7 +462,9 @@ class DefaultMultiLLM(LLM):
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
if DISABLE_LITELLM_STREAMING:
if (
DISABLE_LITELLM_STREAMING or self.config.model_name == "o1-2024-12-17"
): # TODO: remove once litellm supports streaming
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
return

View File

@@ -29,6 +29,7 @@ OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"o1-mini",
"o1-preview",
"o1-2024-12-17",
"gpt-4",
"gpt-4o",
"gpt-4o-mini",

View File

@@ -28,6 +28,7 @@ from litellm.exceptions import RateLimitError # type: ignore
from litellm.exceptions import Timeout # type: ignore
from litellm.exceptions import UnprocessableEntityError # type: ignore
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
from onyx.configs.constants import MessageType
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
@@ -45,10 +46,19 @@ logger = setup_logger()
def litellm_exception_to_error_msg(
e: Exception, llm: LLM, fallback_to_error_msg: bool = False
e: Exception,
llm: LLM,
fallback_to_error_msg: bool = False,
custom_error_msg_mappings: dict[str, str]
| None = LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS,
) -> str:
error_msg = str(e)
if custom_error_msg_mappings:
for error_msg_pattern, custom_error_msg in custom_error_msg_mappings.items():
if error_msg_pattern in error_msg:
return custom_error_msg
if isinstance(e, BadRequestError):
error_msg = "Bad request: The server couldn't process your request. Please check your input."
elif isinstance(e, AuthenticationError):
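
The custom mapping check added at the top of `litellm_exception_to_error_msg` is a plain substring match that short-circuits the exception-type handling. A minimal standalone sketch of that lookup follows; `map_error_message` and the sample mapping are hypothetical, and the real contents of `LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS` are not shown in this diff.

```python
from typing import Dict, Optional


def map_error_message(
    error_msg: str,
    mappings: Optional[Dict[str, str]],
    default: str,
) -> str:
    """Return the first custom message whose pattern occurs in error_msg."""
    if mappings:
        # Dicts preserve insertion order, so earlier patterns win on overlap
        for pattern, custom_msg in mappings.items():
            if pattern in error_msg:
                return custom_msg
    return default


# Hypothetical mapping, for illustration only
sample_mappings = {"quota exceeded": "Your organization has run out of LLM quota."}

mapped = map_error_message(
    "Error 429: quota exceeded for project", sample_mappings, "Unknown error"
)
unmapped = map_error_message("connection reset", sample_mappings, "Unknown error")
```

Because matching is case-sensitive and order-dependent, patterns that are too short or too generic could shadow the more specific built-in messages.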

Some files were not shown because too many files have changed in this diff.