Compare commits

...

98 Commits

Author SHA1 Message Date
2cedaa1537 merge upstream 2025-08-18 08:40:20 +00:00
SubashMohan
fe029eccae chore: add SharePoint sync environment variables to integration test (#5197)
* chore: add SharePoint sync environment variables to integration test workflows

* fix cubic comments

* test: skip SharePoint permission tests for non-enterprise

* test: update SharePoint permission tests to skip for non-enterprise environments
2025-08-18 03:21:04 +00:00
Wenxi
ea72af7698 fix sharepoint tests (#5209) 2025-08-17 22:25:47 +00:00
6f6d7277f6 update local env file with template content
Signed-off-by: chris <chris@regan.co.nz>
2025-08-17 06:49:55 +00:00
c2a84dd231 add local env file
Signed-off-by: chris <chris@regan.co.nz>
2025-08-17 06:49:10 +00:00
Wenxi
17abf85533 fix unpaused user files (#5205) 2025-08-16 01:39:16 +00:00
Wenxi
3bd162acb9 fix: sharepoint tests and indexing logic (#5204)
* don't index onedrive personal sites in sharepoint

* fix sharepoint tests and indexing behavior

* remove print
2025-08-15 18:19:42 -07:00
Evan Lohn
664ce441eb generous timeout between docfetching finishing and docprocessing starting (#5201) 2025-08-15 15:43:01 -07:00
Wenxi
6863fbee54 fix: validate sharepoint connector with validate_connector_settings (#5199)
* validate sharepoint connector with validate_connector_settings

* fix test

* fix tests
2025-08-15 00:38:31 +00:00
Justin Tahara
bb98088b80 fix(infra): Fix Helm Chart Test (#5198) 2025-08-14 23:28:17 +00:00
Justin Tahara
ce8cb1112a feat(infra): Adding new AWS Terraform Template Code (#5194)
* feat(infra): Adding new AWS Terraform Template Code

* Addressing greptile comments

* Applying some updates after the cubic reviews as well

* Adding one detail

* Removing unused var

* Addressing more cubic comments
2025-08-14 16:47:15 -07:00
Nils
a605bd4ca4 feat: make sharepoint documents and sharepoint pages optional (#5183)
* feat: make sharepoint documents and sharepoint pages optional

* fix: address review feedback for PR #5183

* fix: exclude personal sites from sharepoint connector

---------

Co-authored-by: Nils Kleinrahm <nils.kleinrahm@pledoc.de>
2025-08-14 15:17:23 -07:00
Dominic Feliton
0e8b5af619 fix(connector): user file helm start cmd + legacy file connector incompatibility (#5195)
* Fix user file helm start cmd + legacy file connector incompatibility

* typo

* remove unnecessary logic

* undo

* make recommended changes

* keep comment

* cleanup

* format

---------

Co-authored-by: Dominic Feliton <37809476+dominicfeliton@users.noreply.github.com>
2025-08-14 13:20:19 -07:00
SubashMohan
46f3af4f68 enhance file processing with content type handling (#5196) 2025-08-14 08:59:53 +00:00
Evan Lohn
2af64ebf4c fix: ensure exception strings don't get swallowed (#5192)
* ensure exception strings don't get swallowed

* just send exception code
2025-08-13 20:05:16 +00:00
Evan Lohn
0eb1824158 fix: sf connector docs (#5171)
* fix: sf connector docs

* more sf logs

* better logs and new attempt

* add fields to error temporarily

* fix sf

---------

Co-authored-by: Wenxi <wenxi@onyx.app>
2025-08-13 17:52:32 +00:00
Chris Weaver
e0a9a6fb66 feat: okta profile tool (#5184)
* Initial Okta profile tool

* Improve

* Fix

* Improve

* Improve

* Address EL comments
2025-08-13 09:57:31 -07:00
Wenxi
fe194076c2 make default personas hideable (#5190) 2025-08-13 01:12:51 +00:00
Wenxi
55dc24fd27 fix: seeded total doc count (#5188)
* fix seeded total doc count

* fix seeded total doc count
2025-08-13 00:19:06 +00:00
Evan Lohn
da02962a67 fix: thread safe approach to docprocessing logging (#5185)
* thread safe approach to docprocessing logging

* unify approaches

* reset
2025-08-12 02:25:47 +00:00
SubashMohan
9bc62cc803 feat: sharepoint perm sync (#5033)
* sharepoint perm sync first draft

* feat: Implement SharePoint permission synchronization

* mypy fix

* remove commented code

* bot comments fixes and job failure fixes

* introduce generic way to upload certificates in credentials

* mypy fix

* add checkpointing to sharepoint connector

* add sharepoint integration tests

* Refactor SharePoint connector to derive tenant domain from verified domains and remove direct tenant domain input from credentials

* address review comments

* add permission sync to site pages

* mypy fix

* fix tests error

* fix tests and address comments

* Update file extraction behavior in SharePoint connector to continue processing on unprocessable files
2025-08-11 16:59:16 +00:00
Evan Lohn
bf6705a9a5 fix: max tokens param (#5174)
* max tokens param

* fix unit test

* fix unit test
2025-08-11 09:57:44 -07:00
Rei Meguro
df2fef3383 fix: removal of old tags + is_list differentiation (#5147)
* initial migration

* getting metadata from tags

* complete migration

* migration override for cloud

* fix: more robust structured tag gen

* tag and indexing update

* fix: move is_list to tags

* migration rebase

* test cases + bugfix on unique constraint

* fix logging
2025-08-10 22:39:33 +00:00
SubashMohan
8cec3448d7 fix: restrict user file access to current user only (#5177)
* fix: restrict user file access to current user only

* fix: enhance user file access control for recent folder
2025-08-10 19:00:18 +00:00
Justin Tahara
b81687995e fix(infra): Removing invalid helm version (#5176) 2025-08-08 18:40:55 -07:00
Justin Tahara
87c2253451 fix(infra): Update github workflow to not tag latest (#5172)
* fix(infra): Update github workflow to not tag latest

* Cleaned up the code a bit
2025-08-08 23:23:55 +00:00
Wenxi
297c2957b4 add gpt 5 display names (#5175) 2025-08-08 16:58:47 -07:00
Wenxi
bacee0d09d fix: sanitize slack payload before logging (#5167)
* sanitize slack payload before logging

* nit
2025-08-08 02:10:00 +00:00
Evan Lohn
297720c132 refactor: file processing (#5136)
* file processing refactor

* mypy

* CW comments

* address CW
2025-08-08 00:34:35 +00:00
Evan Lohn
bd4bd00cef feat: office parsing markitdown (#5115)
* switch to markitdown untested

* passing tests

* reset file

* dotenv version

* docs

* add test file

* add doc

* fix integration test
2025-08-07 23:26:02 +00:00
Chris Weaver
07c482f727 Make starter messages visible on smaller screens (#5170) 2025-08-07 16:49:18 -07:00
Wenxi
cf193dee29 feat: support gpt5 models (#5169)
* support gpt5 models

* gpt5mini visible
2025-08-07 12:35:46 -07:00
Evan Lohn
1b47fa2700 fix: remove erroneous error case and add valid error (#5163)
* fix: remove erroneous error case and add valid error

* also address docfetching-docprocessing limbo
2025-08-07 18:17:00 +00:00
Wenxi Onyx
e1a305d18a mask llm api key from logs 2025-08-07 00:01:29 -07:00
Evan Lohn
e2233d22c9 feat: salesforce custom query (#5158)
* WIP merged approach untested

* tested custom configs

* JT comments

* fix unit test

* CW comments

* fix unit test
2025-08-07 02:37:23 +00:00
Justin Tahara
20d1175312 feat(infra): Bump Vespa Helm Version (#5161)
* feat(infra): Bump Vespa Helm Version

* Adding the Chart.lock file
2025-08-06 19:06:18 -07:00
justin-tahara
7117774287 Revert that change. Let's do this properly 2025-08-06 18:54:21 -07:00
justin-tahara
77f2660bb2 feat(infra): Update Vespa Helm Chart Version 2025-08-06 18:53:02 -07:00
Wenxi
1b2f4f3b87 fix: slash command slackbot to respond in private msg (#5151)
* fix slash command slackbot to respond in private msg

* rename confusing variable. fix slash message response in DMs
2025-08-05 19:03:38 -07:00
Evan Lohn
d85b55a9d2 no more scheduled stalling (#5154) 2025-08-05 20:17:44 +00:00
Justin Tahara
e2bae5a2d9 fix(infra): Adding helm directory (#5156)
* feat(infra): Adding helm directory

* one more fix
2025-08-05 14:11:57 -07:00
Justin Tahara
cc9c76c4fb feat(infra): Release Charts on Github Pages (#5155) 2025-08-05 14:03:28 -07:00
Chris Weaver
258e08abcd feat: add customization via env vars for curator role (#5150)
* Add customization via env vars for curator role

* Simplify

* Simplify more

* Address comments
2025-08-05 09:58:36 -07:00
Evan Lohn
67047e42a7 fix: preserve error traces (#5152) 2025-08-05 09:44:55 -07:00
SubashMohan
146628e734 fix unsupported character error in minio migration (#5145)
* fix unsupported character error in minio migration

* slash fix
2025-08-04 12:42:07 -07:00
Wenxi
c1d4b08132 fix: minio file names (#5138)
* nit var clarity

* maintain file names in connector config for display

* remove unused util

* migration draft

* optional file names to not break existing instances

* backwards compatible

* backwards compatible

* migration logging

* update file conn tests

* unnecessary none

* mypy + explanatory comments
2025-08-01 20:31:29 +00:00
Justin Tahara
f3f47d0709 feat(infra): Creating new helm chart action workflow (#5137)
* feat(infra) Creating new helm chart action workflow

* Adding the steps

* Adding in dependencies

* One more debug

* Adding a new step to install helm
2025-08-01 09:26:58 -07:00
Justin Tahara
fe26a1bfcc feat(infra): Codeowner for Helm directory (#5139) 2025-07-31 23:05:46 +00:00
Wenxi
554cd0f891 fix: accept multiple zip types and fallback to extension (#5135)
* accept multiple zip types and fallback to extension

* move zip check to util

* mypy nit
2025-07-30 22:21:16 +00:00
Raunak Bhagat
f87d3e9849 fix: Make ungrounded types have a default name when sending to the frontend (#5133)
* Update names in map-comprehension

* Make default name for ungrounded types public

* Return the default name for ungrounded entity-types

* Update backend/onyx/db/entities.py

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

---------

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2025-07-30 20:46:30 +00:00
Rei Meguro
72cdada893 edit link to custom actions (#5129) 2025-07-30 15:08:39 +00:00
SubashMohan
c442ebaff6 Feature/GitHub permission sync (#4996)
* github perm sync initial draft

* introduce github doc sync and perm sync

* remove specific start time check

* Refactor GitHub connector to use SlimCheckpointOutputWrapper for improved document handling

* Update GitHub sync frequency defaults from 30 minutes to 5 minutes

* Add stop signal handling and progress reporting in GitHub document sync

* Refactor tests for Confluence and Google Drive connectors to use a mock fetch function for document access

* change the doc_sync approach

* add static typing for document columns and where clause

* remove prefix logic in connector runner

* mypy fix

* code review changes

* mypy fix

* fix review comments

* add sort order

* Implement merge heads migration for Alembic and update Confluence and Google Drive test

* github unit tests fix

* delete merge head and rebase the docmetadata field migration

---------

Co-authored-by: Subash <subash@onyx.app>
2025-07-30 02:42:18 +00:00
Justin Tahara
56f16d107e feat(infra): Update helm version after new feature (#5120) 2025-07-29 16:31:35 -07:00
Justin Tahara
0157ae099a [Vespa] Update to optimized configuration pt.2 (#5113) 2025-07-28 20:42:31 +00:00
justin-tahara
565fb42457 Let's do this properly 2025-07-28 10:42:31 -07:00
justin-tahara
a50a8b4a12 [Vespa] Update to optimized configuration 2025-07-28 10:38:48 -07:00
Evan Lohn
4baf4e7d96 feat: pruning freq (#5097)
* pruning frequency increase

* add logs
2025-07-26 22:29:43 +00:00
Wenxi
8b7ab2eb66 onyx metadata minio fix + permissive unstructured fail (#5085) 2025-07-25 21:26:02 +00:00
Evan Lohn
1f75f3633e fix: sidebar ranges (#5084) 2025-07-25 19:46:47 +00:00
Evan Lohn
650884d76a fix: preserve error traces (#5083) 2025-07-25 18:56:11 +00:00
Wenxi
8722bdb414 typo (#5082) 2025-07-25 18:26:21 +00:00
Evan Lohn
71037678c3 attempt to fix parsing of tricky template files (#5080) 2025-07-25 02:18:35 +00:00
Chris Weaver
68de1015e1 feat: support aspx files (#5068)
* Support aspx files

* Add fetching of site pages

* Improve

* Small enhancement

* more improvements

* Improvements

* Fix tests
2025-07-24 19:19:24 -07:00
Evan Lohn
e2b3a6e144 fix: drive external links (#5079) 2025-07-24 17:42:12 -07:00
Evan Lohn
4f04b09efa add library to fall back to for tokenizing (#5078) 2025-07-24 11:15:07 -07:00
SubashMohan
5c4f44d258 fix: sharepoint lg files issue (#5065)
* add SharePoint file size threshold check

* Implement retry logic for SharePoint queries to handle rate limiting and server error

* mypy fix

* add content none check

* remove unreachable code from retry logic in sharepoint connector
2025-07-24 14:26:01 +00:00
Evan Lohn
19652ad60e attempt fix for broken excel files (#5071) 2025-07-24 01:21:13 +00:00
Evan Lohn
70c96b6ab3 fix: remove locks from indexing callback (#5070) 2025-07-23 23:05:35 +00:00
Raunak Bhagat
65076b916f refactor: Update location of sidebar (#5067)
* Use props instead of inline type def

* Add new AppProvider

* Remove unused component file

* Move `sessionSidebar` to be inside of `components` instead of `app/chat`

* Change name of `sessionSidebar` to `sidebar`

* Remove `AppModeProvider`

* Fix bug in how the cookies were set
2025-07-23 21:59:34 +00:00
PaulHLiatrio
06bc0e51db fix: adjust template variable from .Chart.AppVersion to .Values.global.version to match versioning pattern. (#5069) 2025-07-23 14:54:32 -07:00
Devin
508b456b40 fix: explicit api_server dependency on minio in docker compose files (#5066) 2025-07-23 13:37:42 -07:00
Evan Lohn
bf1e2a2661 feat: avoid full rerun (#5063)
* fix: remove extra group sync

* second extra task

* minor improvement for non-checkpointed connectors
2025-07-23 18:01:23 +00:00
Evan Lohn
991d5e4203 fix: regen api key (#5064) 2025-07-23 03:36:51 +00:00
Evan Lohn
d21f012b04 fix: remove extra group sync (#5061)
* fix: remove extra group sync

* second extra task
2025-07-22 23:24:42 +00:00
Wenxi
86b7beab01 fix: too many internet chunks (#5060)
* minor internet search env vars

* add limit to internet search chunks

* note

* nits
2025-07-22 23:11:10 +00:00
Evan Lohn
b4eaa81d8b handle empty doc batches (#5058) 2025-07-22 22:35:59 +00:00
Evan Lohn
ff2a4c8723 fix: time discrepancy (#5056)
* fix time discrepancy

* remove log

* remove log
2025-07-22 22:19:02 +00:00
Raunak Bhagat
51027fd259 fix: Make pr-labeler run on edits too 2025-07-22 15:04:37 -07:00
Raunak Bhagat
7e3fd2b12a refactor: Update the error message that is logged when PR title fails Conventional Commits regex (#5062) 2025-07-22 14:46:22 -07:00
Chris Weaver
d2fef6f0b7 Tiny launch.json template improvement (#5055) 2025-07-22 11:15:44 -07:00
Evan Lohn
bd06147d26 feat: connector indexing decoupling (#4893)
* WIP

* renamed and moved tasks (WIP)

* minio migration

* bug fixes and finally add document batch storage

* WIP: can succeed but status is error

* WIP

* import fixes

* working v1 of decoupled

* catastrophe handling

* refactor

* remove unused db session in prep for new approach

* renaming and docstrings (untested)

* renames

* WIP with no more indexing fences

* robustness improvements

* clean up rebase

* migration and salesforce rate limits

* minor tweaks

* test fix

* connector pausing behavior

* correct checkpoint resumption logic

* cleanups in docfetching

* add heartbeat file

* update template jsonc

* deployment fixes

* fix vespa httpx pool

* error handling

* cosmetic fixes

* dumb

* logging improvements and non checkpointed connector fixes

* didnt save

* misc fixes

* fix import

* fix deletion of old files

* add in attempt prefix

* fix attempt prefix

* tiny log improvement

* minor changes

* fixed resumption behavior

* passing int tests

* fix unit test

* fixed unit tests

* trying timeout bump to see if int tests pass

* trying timeout bump to see if int tests pass

* fix autodiscovery

* helm chart fixes

* helm and logging
2025-07-22 03:33:25 +00:00
Raunak Bhagat
1f3cc9ed6e Make from_.user optional (use "Unknown User") if not found (#5051) 2025-07-21 17:50:28 -07:00
Raunak Bhagat
6086d9e51a feat: Updated KG admin page (#5044)
* Update KG admin UI

* Styling changes

* More changes

* Make edits auto-save

* Add more stylings / transitions

* Fix opacity

* Separate out modal into new component

* Revert backend changes

* Update styling

* Add convenience / styling changes to date-picker

* More styling / functional updates to kg admin-page

* Avoid reducing opacity of active-toggle

* Update backend APIs for new KG admin page

* More updates of styling for kg-admin page

* Remove nullability

* Remove console log

* Remove unused imports

* Change type of `children` variable

* Update web/src/app/admin/kg/interfaces.ts

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

* Update web/src/components/CollapsibleCard.tsx

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

* Remove null

* Update web/src/components/CollapsibleCard.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Force non-null

* Fix failing test

---------

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-07-21 15:37:27 -07:00
Raunak Bhagat
e0de24f64e Remove empty tooltip (#5050) 2025-07-21 12:45:48 -07:00
Rei Meguro
08b6b1f8b3 feat: Search and Answer Quality Test Script (#4974)
* aefads

* search quality tests improvement

Co-authored-by: wenxi-onyx <wenxi@onyx.app>

* nits

* refactor: config refactor

* document context + skip genai fix

* feat: answer eval

* more error messages

* mypy ragas

* mypy

* small fixes

* feat: more metrics

* fix

* feat: grab content

* typing

* feat: lazy updates

* mypy

* all at front

* feat: answer correctness

* use api key so it works with auth enabled

* update readme

* feat: auto add path

* feat: rate limit

* fix: readme + remove rerank all

* fix: raise exception immediately

* docs: improved clarity

* feat: federated handling

* fix: mypy

* nits

---------

Co-authored-by: wenxi-onyx <wenxi@onyx.app>
2025-07-19 01:51:51 +00:00
joachim-danswer
afed1a4b37 feat: KG improvements (#5048)
* improvements

* drop views if SQL fails

* mypy fix
2025-07-18 16:15:11 -07:00
Chris Weaver
bca18cacdf fix: improve assistant fetching efficiency (#5047)
* Improve assistant fetching efficiency

* More fix

* Fix weird build stuff

* Improve
2025-07-18 14:16:10 -07:00
Chris Weaver
335db91803 fix: improve check for indexing status (#5042)
* Improve check_for_indexing + check_for_vespa_sync_task

* Remove unused

* Fix

* Simplify query

* Add more logging

* Address bot comments

* Increase # of tasks generated since we're not going cc-pair by cc-pair

* Only index 50 user files at a time
2025-07-17 23:52:51 -07:00
Chris Weaver
67c488ff1f Improve support for non-default postgres schemas (#5046) 2025-07-17 23:51:39 -07:00
Wenxi
deb7f13962 remove chat session necessity from send message simple api (#5040) 2025-07-17 23:23:46 +00:00
Raunak Bhagat
e2d3d65c60 fix: Move around group-sync tests (since they require docker services to be running) (#5041)
* Move around tests

* Add missing fixtures + change directory structure up some more

* Add env variables
2025-07-17 22:41:31 +00:00
Raunak Bhagat
b78a6834f5 fix: Have document show up before message starts streaming back (#5006)
* Have document show up before message starts streaming back

* Add docs
2025-07-17 10:17:57 -07:00
Raunak Bhagat
4abe90aa2c fix: Fix Confluence pagination (#5035)
* Re-implement pagination

* Add note

* Fix invalid integration test configs

* Fix other failing test

* Edit failing test

* Revert test

* Revert pagination size

* Add comment on yielding style

* Use fixture instead of manually initializing sql-engine

* Fix failing tests

* Move code back and copy-paste
2025-07-17 14:02:29 +00:00
Raunak Bhagat
de9568844b Add PR labeller job (#4611) 2025-07-16 18:28:18 -07:00
Evan Lohn
34268f9806 fix bug in index swap (#5036) 2025-07-16 23:09:17 +00:00
Chris Weaver
ed75678837 Add suggested helm resource limits (#5032)
* Add resource suggestions for helm

* Adjust README

* fix

* fix lint
2025-07-15 15:52:16 -07:00
Chris Weaver
3bb58a3dd3 Persona simplification r2 (#5031)
* Revert "Revert "Reduce amount of stuff we fetch on `/persona` (#4988)" (#5024)"

This reverts commit f7ed7cd3cd.

* Enhancements / fix re-render

* re-arrange

* greptile
2025-07-15 14:51:40 -07:00
Chris Weaver
4b02feef31 Add option to disable my documents (#5020)
* Add option to disable my documents

* cleanup
2025-07-14 23:16:14 -07:00
369 changed files with 16285 additions and 5685 deletions

.github/CODEOWNERS

@@ -1 +1,3 @@
* @onyx-dot-app/onyx-core-team
# Helm charts Owners
/helm/ @justin-tahara


@@ -0,0 +1,49 @@
name: Release Onyx Helm Charts
on:
push:
branches:
- main
permissions: write-all
jobs:
release:
permissions:
contents: write
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Install Helm CLI
uses: azure/setup-helm@v4
with:
version: v3.12.1
- name: Add required Helm repositories
run: |
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo update
- name: Build chart dependencies
run: |
set -euo pipefail
for chart_dir in deployment/helm/charts/*; do
if [ -f "$chart_dir/Chart.yaml" ]; then
echo "Building dependencies for $chart_dir"
helm dependency build "$chart_dir"
fi
done
- name: Publish Helm charts to gh-pages
uses: stefanprodan/helm-gh-pages@v1.7.0
with:
token: ${{ secrets.GITHUB_TOKEN }}
charts_dir: deployment/helm/charts
branch: gh-pages
commit_username: ${{ github.actor }}
commit_email: ${{ github.actor }}@users.noreply.github.com


@@ -13,6 +13,14 @@ env:
# MinIO
S3_ENDPOINT_URL: "http://localhost:9004"
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
discover-test-dirs:
runs-on: ubuntu-latest


@@ -55,7 +55,25 @@ jobs:
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
run: ct install --all \
--helm-extra-set-args="\
--set=nginx.enabled=false \
--set=postgresql.enabled=false \
--set=redis.enabled=false \
--set=minio.enabled=false \
--set=vespa.enabled=false \
--set=slackbot.enabled=false \
--set=api.replicaCount=0 \
--set=inferenceCapability.replicaCount=0 \
--set=indexCapability.replicaCount=0 \
--set=celery_beat.replicaCount=0 \
--set=celery_worker_heavy.replicaCount=0 \
--set=celery_worker_docprocessing.replicaCount=0 \
--set=celery_worker_light.replicaCount=0 \
--set=celery_worker_monitoring.replicaCount=0 \
--set=celery_worker_primary.replicaCount=0 \
--set=celery_worker_user_files_indexing.replicaCount=0" \
--debug --config ct.yaml
# the following would install only changed charts, but we only have one chart so
# don't worry about that for now
# run: ct install --target-branch ${{ github.event.repository.default_branch }}


@@ -19,6 +19,10 @@ env:
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
PLATFORM_PAIR: linux-amd64
jobs:
@@ -272,6 +276,10 @@ jobs:
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \

.github/workflows/pr-labeler.yml

@@ -0,0 +1,38 @@
name: PR Labeler
on:
pull_request_target:
branches:
- main
types:
- opened
- reopened
- synchronize
- edited
permissions:
contents: read
pull-requests: write
jobs:
validate_pr_title:
runs-on: ubuntu-latest
steps:
- name: Check PR title for Conventional Commits
env:
PR_TITLE: ${{ github.event.pull_request.title }}
run: |
echo "PR Title: $PR_TITLE"
if [[ ! "$PR_TITLE" =~ ^(feat|fix|docs|test|ci|refactor|perf|chore|revert|build)(\(.+\))?:\ .+ ]]; then
echo "::error::❌ Your PR title does not follow the Conventional Commits format.
This check ensures that all pull requests use clear, consistent titles that help automate changelogs and improve project history.
Please update your PR title to follow the Conventional Commits style.
Here is a link to a blog explaining the reason why we've included the Conventional Commits style into our PR titles: https://xfuture-blog.com/working-with-conventional-commits
**Here are some examples of valid PR titles:**
- feat: add user authentication
- fix(login): handle null password error
- docs(readme): update installation instructions"
exit 1
fi
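
For reference, a minimal standalone check of the same Conventional Commits pattern (a sketch, assuming Python's re module; the sample titles are hypothetical and the regex mirrors the bash =~ test above):

import re

# same prefix/scope/colon pattern enforced by the PR Labeler workflow above
PATTERN = re.compile(r"^(feat|fix|docs|test|ci|refactor|perf|chore|revert|build)(\(.+\))?: .+")

for title in ["feat: add user authentication", "fix(login): handle null password error", "update readme"]:
    print(title, "->", "valid" if PATTERN.match(title) else "rejected")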


@@ -19,6 +19,10 @@ env:
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
PLATFORM_PAIR: linux-amd64
jobs:
integration-tests-mit:
@@ -207,6 +211,10 @@ jobs:
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \


@@ -16,8 +16,8 @@ env:
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}


@@ -24,8 +24,8 @@
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery user files indexing",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring"
],
@@ -46,8 +46,8 @@
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery user files indexing",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring"
],
@@ -226,35 +226,66 @@
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery indexing",
"name": "Celery docfetching",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing"
"-A",
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docfetching@%n",
"-Q",
"connector_doc_fetching,user_files_indexing"
],
"presentation": {
"group": "2"
"group": "2"
},
"consoleTitle": "Celery indexing Console"
},
"consoleTitle": "Celery docfetching Console",
"justMyCode": false
},
{
"name": "Celery docprocessing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.docprocessing",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docprocessing@%n",
"-Q",
"docprocessing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docprocessing Console",
"justMyCode": false
},
{
"name": "Celery monitoring",
"type": "debugpy",
@@ -303,35 +334,6 @@
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery user files indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_files_indexing@%n",
"-Q",
"user_files_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user files indexing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",
@@ -426,7 +428,7 @@
},
"args": [
"--filename",
"generated/openapi.json",
"generated/openapi.json"
]
},
{


@@ -23,7 +23,7 @@ from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import (
MULTI_TENANT,
POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE,
POSTGRES_DEFAULT_SCHEMA,
TENANT_ID_PREFIX,
)
from onyx.db.models import Base
@@ -271,7 +271,7 @@ async def run_async_migrations() -> None:
) = get_schema_options()
if not schemas and not MULTI_TENANT:
schemas = [POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE]
schemas = [POSTGRES_DEFAULT_SCHEMA]
# without init_engine, subsequent engine calls fail hard intentionally
SqlEngine.init_engine(pool_size=20, max_overflow=5)


@@ -96,7 +96,7 @@ def get_google_drive_documents_from_database() -> list[dict]:
result = bind.execute(
sa.text(
"""
SELECT d.id, cc.id as cc_pair_id
SELECT d.id
FROM document d
JOIN document_by_connector_credential_pair dcc ON d.id = dcc.id
JOIN connector_credential_pair cc ON dcc.connector_id = cc.connector_id
@@ -109,7 +109,7 @@ def get_google_drive_documents_from_database() -> list[dict]:
documents = []
for row in result:
documents.append({"document_id": row.id, "cc_pair_id": row.cc_pair_id})
documents.append({"document_id": row.id})
return documents


@@ -0,0 +1,115 @@
"""add_indexing_coordination
Revision ID: 2f95e36923e6
Revises: 0816326d83aa
Create Date: 2025-07-10 16:17:57.762182
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2f95e36923e6"
down_revision = "0816326d83aa"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add database-based coordination fields (replacing Redis fencing)
op.add_column(
"index_attempt", sa.Column("celery_task_id", sa.String(), nullable=True)
)
op.add_column(
"index_attempt",
sa.Column(
"cancellation_requested",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
# Add batch coordination fields (replacing FileStore state)
op.add_column(
"index_attempt", sa.Column("total_batches", sa.Integer(), nullable=True)
)
op.add_column(
"index_attempt",
sa.Column(
"completed_batches", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column(
"total_failures_batch_level",
sa.Integer(),
nullable=False,
server_default="0",
),
)
op.add_column(
"index_attempt",
sa.Column("total_chunks", sa.Integer(), nullable=False, server_default="0"),
)
# Progress tracking for stall detection
op.add_column(
"index_attempt",
sa.Column("last_progress_time", sa.DateTime(timezone=True), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column(
"last_batches_completed_count",
sa.Integer(),
nullable=False,
server_default="0",
),
)
# Heartbeat tracking for worker liveness detection
op.add_column(
"index_attempt",
sa.Column(
"heartbeat_counter", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column(
"last_heartbeat_value", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column("last_heartbeat_time", sa.DateTime(timezone=True), nullable=True),
)
# Add index for coordination queries
op.create_index(
"ix_index_attempt_active_coordination",
"index_attempt",
["connector_credential_pair_id", "search_settings_id", "status"],
)
def downgrade() -> None:
# Remove the new index
op.drop_index("ix_index_attempt_active_coordination", table_name="index_attempt")
# Remove the new columns
op.drop_column("index_attempt", "last_batches_completed_count")
op.drop_column("index_attempt", "last_progress_time")
op.drop_column("index_attempt", "last_heartbeat_time")
op.drop_column("index_attempt", "last_heartbeat_value")
op.drop_column("index_attempt", "heartbeat_counter")
op.drop_column("index_attempt", "total_chunks")
op.drop_column("index_attempt", "total_failures_batch_level")
op.drop_column("index_attempt", "completed_batches")
op.drop_column("index_attempt", "total_batches")
op.drop_column("index_attempt", "cancellation_requested")
op.drop_column("index_attempt", "celery_task_id")


@@ -9,7 +9,7 @@ Create Date: 2025-06-22 17:33:25.833733
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# revision identifiers, used by Alembic.
revision = "36e9220ab794"
@@ -66,7 +66,7 @@ def upgrade() -> None:
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -111,7 +111,7 @@ def upgrade() -> None:
UPDATE "{tenant_id}".kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;


@@ -0,0 +1,30 @@
"""add_doc_metadata_field_in_document_model
Revision ID: 3fc5d75723b3
Revises: 2f95e36923e6
Create Date: 2025-07-28 18:45:37.985406
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "3fc5d75723b3"
down_revision = "2f95e36923e6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"document",
sa.Column(
"doc_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
)
def downgrade() -> None:
op.drop_column("document", "doc_metadata")


@@ -15,7 +15,7 @@ from datetime import datetime, timedelta
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# revision identifiers, used by Alembic.
@@ -478,7 +478,7 @@ def upgrade() -> None:
# Create GIN index for clustering and normalization
op.execute(
"CREATE INDEX IF NOT EXISTS idx_kg_entity_clustering_trigrams "
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.gin_trgm_ops)"
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA}.gin_trgm_ops)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams "
@@ -518,7 +518,7 @@ def upgrade() -> None:
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -563,7 +563,7 @@ def upgrade() -> None:
UPDATE kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;


@@ -0,0 +1,132 @@
"""add file names to file connector config
Revision ID: 62c3a055a141
Revises: 3fc5d75723b3
Create Date: 2025-07-30 17:01:24.417551
"""
from alembic import op
import sqlalchemy as sa
import json
import os
import logging
# revision identifiers, used by Alembic.
revision = "62c3a055a141"
down_revision = "3fc5d75723b3"
branch_labels = None
depends_on = None
SKIP_FILE_NAME_MIGRATION = (
os.environ.get("SKIP_FILE_NAME_MIGRATION", "true").lower() == "true"
)
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
if SKIP_FILE_NAME_MIGRATION:
logger.info(
"Skipping file name migration. Hint: set SKIP_FILE_NAME_MIGRATION=false to run this migration"
)
return
logger.info("Running file name migration")
# Get connection
conn = op.get_bind()
# Get all FILE connectors with their configs
file_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'FILE'
"""
)
).fetchall()
for connector_id, config in file_connectors:
# Parse config if it's a string
if isinstance(config, str):
config = json.loads(config)
# Get file_locations list
file_locations = config.get("file_locations", [])
# Get display names for each file_id
file_names = []
for file_id in file_locations:
result = conn.execute(
sa.text(
"""
SELECT display_name
FROM file_record
WHERE file_id = :file_id
"""
),
{"file_id": file_id},
).fetchone()
if result:
file_names.append(result[0])
else:
file_names.append(file_id) # Should not happen
# Add file_names to config
new_config = dict(config)
new_config["file_names"] = file_names
# Update the connector
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
)
def downgrade() -> None:
# Get connection
conn = op.get_bind()
# Remove file_names from all FILE connectors
file_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'FILE'
"""
)
).fetchall()
for connector_id, config in file_connectors:
# Parse config if it's a string
if isinstance(config, str):
config = json.loads(config)
# Remove file_names if it exists
if "file_names" in config:
new_config = dict(config)
del new_config["file_names"]
# Update the connector
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{
"connector_id": connector_id,
"new_config": json.dumps(new_config),
},
)


@@ -0,0 +1,341 @@
"""tag-fix
Revision ID: 90e3b9af7da4
Revises: 62c3a055a141
Create Date: 2025-08-01 20:58:14.607624
"""
import json
import logging
import os
from typing import cast
from typing import Generator
from alembic import op
import sqlalchemy as sa
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.db.search_settings import SearchSettings
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.constants import AuthType
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "90e3b9af7da4"
down_revision = "62c3a055a141"
branch_labels = None
depends_on = None
SKIP_TAG_FIX = os.environ.get("SKIP_TAG_FIX", "true").lower() == "true"
# override for cloud
if AUTH_TYPE == AuthType.CLOUD:
SKIP_TAG_FIX = True
def set_is_list_for_known_tags() -> None:
"""
Sets is_list to true for all tags that are known to be lists.
"""
LIST_METADATA: list[tuple[str, str]] = [
("CLICKUP", "tags"),
("CONFLUENCE", "labels"),
("DISCOURSE", "tags"),
("FRESHDESK", "emails"),
("GITHUB", "assignees"),
("GITHUB", "labels"),
("GURU", "tags"),
("GURU", "folders"),
("HUBSPOT", "associated_contact_ids"),
("HUBSPOT", "associated_company_ids"),
("HUBSPOT", "associated_deal_ids"),
("HUBSPOT", "associated_ticket_ids"),
("JIRA", "labels"),
("MEDIAWIKI", "categories"),
("ZENDESK", "labels"),
("ZENDESK", "content_tags"),
]
bind = op.get_bind()
for source, key in LIST_METADATA:
bind.execute(
sa.text(
f"""
UPDATE tag
SET is_list = true
WHERE tag_key = '{key}'
AND source = '{source}'
"""
)
)
def set_is_list_for_list_tags() -> None:
"""
Sets is_list to true for all tags which have multiple values for a given
document, key, and source triplet. This only works if we remove old tags
from the database.
"""
bind = op.get_bind()
bind.execute(
sa.text(
"""
UPDATE tag
SET is_list = true
FROM (
SELECT DISTINCT tag.tag_key, tag.source
FROM tag
JOIN document__tag ON tag.id = document__tag.tag_id
GROUP BY tag.tag_key, tag.source, document__tag.document_id
HAVING count(*) > 1
) AS list_tags
WHERE tag.tag_key = list_tags.tag_key
AND tag.source = list_tags.source
"""
)
)
def log_list_tags() -> None:
bind = op.get_bind()
result = bind.execute(
sa.text(
"""
SELECT DISTINCT source, tag_key
FROM tag
WHERE is_list
ORDER BY source, tag_key
"""
)
).fetchall()
logger.info(
"List tags:\n" + "\n".join(f" {source}: {key}" for source, key in result)
)
def remove_old_tags() -> None:
"""
Removes old tags from the database.
Previously, there was a bug where if a document got indexed with a tag and then
the document got reindexed, the old tag would not be removed.
This function removes those old tags by comparing it against the tags in vespa.
"""
current_search_settings, future_search_settings = active_search_settings()
document_index = get_default_document_index(
current_search_settings, future_search_settings
)
# Get the index name
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
else:
# Default index name if we can't get it from the document_index
index_name = "danswer_index"
for batch in _get_batch_documents_with_multiple_tags():
n_deleted = 0
for document_id in batch:
true_metadata = _get_vespa_metadata(document_id, index_name)
tags = _get_document_tags(document_id)
# identify document__tags to delete
to_delete: list[str] = []
for tag_id, tag_key, tag_value in tags:
true_val = true_metadata.get(tag_key, "")
if (isinstance(true_val, list) and tag_value not in true_val) or (
isinstance(true_val, str) and tag_value != true_val
):
to_delete.append(str(tag_id))
if not to_delete:
continue
# delete old document__tags
bind = op.get_bind()
result = bind.execute(
sa.text(
f"""
DELETE FROM document__tag
WHERE document_id = '{document_id}'
AND tag_id IN ({','.join(to_delete)})
"""
)
)
n_deleted += result.rowcount
logger.info(f"Processed {len(batch)} documents and deleted {n_deleted} tags")
def active_search_settings() -> tuple[SearchSettings, SearchSettings | None]:
result = op.get_bind().execute(
sa.text(
"""
SELECT * FROM search_settings WHERE status = 'PRESENT' ORDER BY id DESC LIMIT 1
"""
)
)
search_settings_fetch = result.fetchall()
search_settings = (
SearchSettings(**search_settings_fetch[0]._asdict())
if search_settings_fetch
else None
)
result2 = op.get_bind().execute(
sa.text(
"""
SELECT * FROM search_settings WHERE status = 'FUTURE' ORDER BY id DESC LIMIT 1
"""
)
)
search_settings_future_fetch = result2.fetchall()
search_settings_future = (
SearchSettings(**search_settings_future_fetch[0]._asdict())
if search_settings_future_fetch
else None
)
if not isinstance(search_settings, SearchSettings):
raise RuntimeError(
"current search settings is of type " + str(type(search_settings))
)
if (
not isinstance(search_settings_future, SearchSettings)
and search_settings_future is not None
):
raise RuntimeError(
"future search settings is of type " + str(type(search_settings_future))
)
return search_settings, search_settings_future
def _get_batch_documents_with_multiple_tags(
batch_size: int = 128,
) -> Generator[list[str], None, None]:
"""
Returns a list of document ids which contain a one to many tag.
The document may either contain a list metadata value, or may contain leftover
old tags from reindexing.
"""
offset_clause = ""
bind = op.get_bind()
while True:
batch = bind.execute(
sa.text(
f"""
SELECT DISTINCT document__tag.document_id
FROM tag
JOIN document__tag ON tag.id = document__tag.tag_id
GROUP BY tag.tag_key, tag.source, document__tag.document_id
HAVING count(*) > 1 {offset_clause}
ORDER BY document__tag.document_id
LIMIT {batch_size}
"""
)
).fetchall()
if not batch:
break
doc_ids = [document_id for document_id, in batch]
yield doc_ids
offset_clause = f"AND document__tag.document_id > '{doc_ids[-1]}'"
def _get_vespa_metadata(
document_id: str, index_name: str
) -> dict[str, str | list[str]]:
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
# Document-Selector language
selection = (
f"{index_name}.document_id=='{document_id}' and {index_name}.chunk_id==0"
)
params: dict[str, str | int] = {
"selection": selection,
"wantedDocumentCount": 1,
"fieldSet": f"{index_name}:metadata",
}
with get_vespa_http_client() as client:
resp = client.get(url, params=params)
resp.raise_for_status()
docs = resp.json().get("documents", [])
if not docs:
raise RuntimeError(f"No chunk-0 found for document {document_id}")
# for some reason, metadata is a string
metadata = docs[0]["fields"]["metadata"]
return json.loads(metadata)
def _get_document_tags(document_id: str) -> list[tuple[int, str, str]]:
bind = op.get_bind()
result = bind.execute(
sa.text(
f"""
SELECT tag.id, tag.tag_key, tag.tag_value
FROM tag
JOIN document__tag ON tag.id = document__tag.tag_id
WHERE document__tag.document_id = '{document_id}'
"""
)
).fetchall()
return cast(list[tuple[int, str, str]], result)
def upgrade() -> None:
op.add_column(
"tag",
sa.Column("is_list", sa.Boolean(), nullable=False, server_default="false"),
)
op.drop_constraint(
constraint_name="_tag_key_value_source_uc",
table_name="tag",
type_="unique",
)
op.create_unique_constraint(
constraint_name="_tag_key_value_source_list_uc",
table_name="tag",
columns=["tag_key", "tag_value", "source", "is_list"],
)
set_is_list_for_known_tags()
if SKIP_TAG_FIX:
logger.warning(
"Skipping removal of old tags. "
"This can cause issues when using the knowledge graph, or "
"when filtering for documents by tags."
)
log_list_tags()
return
remove_old_tags()
set_is_list_for_list_tags()
# debug
log_list_tags()
def downgrade() -> None:
# the migration adds and populates the is_list column, and removes old bugged tags
# there isn't a point in adding back the bugged tags, so we just drop the column
op.drop_constraint(
constraint_name="_tag_key_value_source_list_uc",
table_name="tag",
type_="unique",
)
op.create_unique_constraint(
constraint_name="_tag_key_value_source_uc",
table_name="tag",
columns=["tag_key", "tag_value", "source"],
)
op.drop_column("tag", "is_list")


@@ -0,0 +1,33 @@
"""Pause finished user file connectors
Revision ID: b558f51620b4
Revises: 90e3b9af7da4
Create Date: 2025-08-15 17:17:02.456704
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b558f51620b4"
down_revision = "90e3b9af7da4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Set all user file connector credential pairs with ACTIVE status to PAUSED
# This ensures user files don't continue to run indexing tasks after processing
op.execute(
"""
UPDATE connector_credential_pair
SET status = 'PAUSED'
WHERE is_user_file = true
AND status = 'ACTIVE'
"""
)
def downgrade() -> None:
pass


@@ -159,7 +159,7 @@ def _migrate_files_to_postgres() -> None:
# only create external store if we have files to migrate. This line
# makes it so we need to have S3/MinIO configured to run this migration.
external_store = get_s3_file_store(db_session=session)
external_store = get_s3_file_store()
for i, file_id in enumerate(files_to_migrate, 1):
print(f"Migrating file {i}/{total_files}: {file_id}")
@@ -219,7 +219,7 @@ def _migrate_files_to_external_storage() -> None:
# Get database session
bind = op.get_bind()
session = Session(bind=bind)
external_store = get_s3_file_store(db_session=session)
external_store = get_s3_file_store()
# Find all files currently stored in PostgreSQL (lobj_oid is not null)
result = session.execute(


@@ -91,7 +91,7 @@ def export_query_history_task(
with get_session_with_current_tenant() as db_session:
try:
stream.seek(0)
get_default_file_store(db_session).save_file(
get_default_file_store().save_file(
content=stream,
display_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,


@@ -47,6 +47,7 @@ from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document import get_documents_for_connector_credential_pair_limited_columns
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
@@ -58,7 +59,9 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.utils import DocumentRow
from onyx.db.utils import is_retryable_sqlalchemy_error
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
@@ -422,7 +425,7 @@ def connector_permission_sync_generator_task(
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
+ f"_{redis_connector.cc_pair_id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
thread_local=False,
)
@@ -498,16 +501,31 @@ def connector_permission_sync_generator_task(
# this can be used to determine documents that are "missing" and thus
# should no longer be accessible. The decision as to whether we should find
# every document during the doc sync process is connector-specific.
def fetch_all_existing_docs_fn() -> list[str]:
return get_document_ids_for_connector_credential_pair(
def fetch_all_existing_docs_fn(
sort_order: SortOrder | None = None,
) -> list[DocumentRow]:
result = get_documents_for_connector_credential_pair_limited_columns(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
sort_order=sort_order,
)
return list(result)
def fetch_all_existing_docs_ids_fn() -> list[str]:
result = get_document_ids_for_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
return result
doc_sync_func = sync_config.doc_sync_config.doc_sync_func
document_external_accesses = doc_sync_func(
cc_pair, fetch_all_existing_docs_fn, callback
cc_pair,
fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn,
callback,
)
task_logger.info(


@@ -383,7 +383,7 @@ def connector_external_group_sync_generator_task(
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
+ f"_{redis_connector.cc_pair_id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)


@@ -71,6 +71,19 @@ GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
)
#####
# GitHub
#####
# In seconds, default is 5 minutes
GITHUB_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("GITHUB_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
#####
# Slack
#####
@@ -89,6 +102,19 @@ TEAMS_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("TEAMS_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
#####
# SharePoint
#####
# In seconds, default is 30 minutes
SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
)
# In seconds, default is 5 minutes
SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
####
# Celery Job Frequency

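A note on the `or` fallback used in the frequency settings above: an environment variable that is set but empty still falls through to the default, which a plain default argument to os.environ.get would not do. A minimal sketch of that behavior (the variable name comes from the diff above; the override value is hypothetical):

import os

os.environ["SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY"] = ""  # present but empty
# empty string is falsy, so `or` substitutes the 30-minute default
interval = int(os.environ.get("SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60)
print(interval)  # 1800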

@@ -114,7 +114,6 @@ def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:
def get_usage_report_data(
db_session: Session,
report_display_name: str,
) -> IO:
"""
@@ -128,7 +127,7 @@ def get_usage_report_data(
Returns:
The usage report data.
"""
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
# usage report may be very large, so don't load it all into memory
return file_store.read_file(
file_id=report_display_name, mode="b", use_tempfile=True


@@ -18,9 +18,9 @@
<!-- <document type="danswer_chunk" mode="index" /> -->
{{ document_elements }}
</documents>
<nodes count="75">
<resources vcpu="8.0" memory="64.0Gb" architecture="arm64" storage-type="local"
disk="474.0Gb" />
<nodes count="60">
<resources vcpu="8.0" memory="128.0Gb" architecture="arm64" storage-type="local"
disk="475.0Gb" />
</nodes>
<engine>
<proton>


@@ -6,6 +6,7 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -25,6 +26,7 @@ CONFLUENCE_DOC_SYNC_LABEL = "confluence_doc_sync"
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""
@@ -43,7 +45,7 @@ def confluence_doc_sync(
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
callback=callback,
doc_source=DocumentSource.CONFLUENCE,
slim_connector=confluence_connector,


@@ -0,0 +1,294 @@
import json
from collections.abc import Generator
from github import Github
from github.Repository import Repository
from ee.onyx.external_permissions.github.utils import fetch_repository_team_slugs
from ee.onyx.external_permissions.github.utils import form_collaborators_group_id
from ee.onyx.external_permissions.github.utils import form_organization_group_id
from ee.onyx.external_permissions.github.utils import (
form_outside_collaborators_group_id,
)
from ee.onyx.external_permissions.github.utils import get_external_access_permission
from ee.onyx.external_permissions.github.utils import get_repository_visibility
from ee.onyx.external_permissions.github.utils import GitHubVisibility
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.github.connector import DocMetadata
from onyx.connectors.github.connector import GithubConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
GITHUB_DOC_SYNC_LABEL = "github_doc_sync"
def github_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
"""
Sync GitHub documents with external access permissions.
This function checks each repository for visibility/team changes and updates
document permissions accordingly without using checkpoints.
"""
logger.info(f"Starting GitHub document sync for CC pair ID: {cc_pair.id}")
# Initialize GitHub connector with credentials
github_connector: GithubConnector = GithubConnector(
**cc_pair.connector.connector_specific_config
)
github_connector.load_credentials(cc_pair.credential.credential_json)
logger.info("GitHub connector credentials loaded successfully")
if not github_connector.github_client:
logger.error("GitHub client initialization failed")
raise ValueError("github_client is required")
# Get all repositories from GitHub API
logger.info("Fetching all repositories from GitHub API")
try:
repos = []
if github_connector.repositories:
if "," in github_connector.repositories:
# Multiple repositories specified
repos = github_connector.get_github_repos(
github_connector.github_client
)
else:
# Single repository
repos = [
github_connector.get_github_repo(github_connector.github_client)
]
else:
# All repositories
repos = github_connector.get_all_repos(github_connector.github_client)
logger.info(f"Found {len(repos)} repositories to check")
except Exception as e:
logger.error(f"Failed to fetch repositories: {e}")
raise
repo_to_doc_list_map: dict[str, list[DocumentRow]] = {}
# sort order is ascending because we want to get the oldest documents first
existing_docs: list[DocumentRow] = fetch_all_existing_docs_fn(
sort_order=SortOrder.ASC
)
logger.info(f"Found {len(existing_docs)} documents to check")
for doc in existing_docs:
try:
doc_metadata = DocMetadata.model_validate_json(json.dumps(doc.doc_metadata))
if doc_metadata.repo not in repo_to_doc_list_map:
repo_to_doc_list_map[doc_metadata.repo] = []
repo_to_doc_list_map[doc_metadata.repo].append(doc)
except Exception as e:
logger.error(f"Failed to parse doc metadata: {e} for doc {doc.id}")
continue
logger.info(f"Found {len(repo_to_doc_list_map)} documents to check")
# Process each repository individually
for repo in repos:
try:
logger.info(f"Processing repository: {repo.id} (name: {repo.name})")
repo_doc_list: list[DocumentRow] = repo_to_doc_list_map.get(
repo.full_name, []
)
if not repo_doc_list:
logger.warning(
f"No documents found for repository {repo.id} ({repo.name})"
)
continue
current_external_group_ids = repo_doc_list[0].external_user_group_ids or []
# Check if repository has any permission changes
has_changes = _check_repository_for_changes(
repo=repo,
github_client=github_connector.github_client,
current_external_group_ids=current_external_group_ids,
)
if has_changes:
logger.info(
f"Repository {repo.id} ({repo.name}) has changes, updating documents"
)
# Get new external access permissions for this repository
new_external_access = get_external_access_permission(
repo, github_connector.github_client
)
logger.info(
f"Found {len(repo_doc_list)} documents for repository {repo.full_name}"
)
# Yield updated external access for each document
for doc in repo_doc_list:
if callback:
callback.progress(GITHUB_DOC_SYNC_LABEL, 1)
yield DocExternalAccess(
doc_id=doc.id,
external_access=new_external_access,
)
else:
logger.info(
f"Repository {repo.id} ({repo.name}) has no changes, skipping"
)
except Exception as e:
logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")
logger.info(f"GitHub document sync completed for CC pair ID: {cc_pair.id}")
def _check_repository_for_changes(
repo: Repository,
github_client: Github,
current_external_group_ids: list[str],
) -> bool:
"""
Check if repository has any permission changes (visibility or team updates).
"""
logger.info(f"Checking repository {repo.id} ({repo.name}) for changes")
# Check for repository visibility changes using the sample document data
if _is_repo_visibility_changed_from_groups(
repo=repo,
current_external_group_ids=current_external_group_ids,
):
logger.info(f"Repository {repo.id} ({repo.name}) has visibility changes")
return True
# Check for team membership changes if repository is private
if get_repository_visibility(
repo
) == GitHubVisibility.PRIVATE and _teams_updated_from_groups(
repo=repo,
github_client=github_client,
current_external_group_ids=current_external_group_ids,
):
logger.info(f"Repository {repo.id} ({repo.name}) has team changes")
return True
logger.info(f"Repository {repo.id} ({repo.name}) has no changes")
return False
def _is_repo_visibility_changed_from_groups(
repo: Repository,
current_external_group_ids: list[str],
) -> bool:
"""
Check if repository visibility has changed by analyzing existing external group IDs.
Args:
repo: GitHub repository object
current_external_group_ids: List of external group IDs from existing document
Returns:
True if visibility has changed
"""
current_repo_visibility = get_repository_visibility(repo)
logger.info(f"Current repository visibility: {current_repo_visibility.value}")
# Build expected group IDs for current visibility
collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=form_collaborators_group_id(repo.id),
)
org_group_id = None
if repo.organization:
org_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=form_organization_group_id(repo.organization.id),
)
# Determine existing visibility from group IDs
has_collaborators_group = collaborators_group_id in current_external_group_ids
has_org_group = org_group_id and org_group_id in current_external_group_ids
if has_collaborators_group:
existing_repo_visibility = GitHubVisibility.PRIVATE
elif has_org_group:
existing_repo_visibility = GitHubVisibility.INTERNAL
else:
existing_repo_visibility = GitHubVisibility.PUBLIC
logger.info(f"Inferred existing visibility: {existing_repo_visibility.value}")
visibility_changed = existing_repo_visibility != current_repo_visibility
if visibility_changed:
logger.info(
f"Visibility changed for repo {repo.id} ({repo.name}): "
f"{existing_repo_visibility.value} -> {current_repo_visibility.value}"
)
return visibility_changed
def _teams_updated_from_groups(
repo: Repository,
github_client: Github,
current_external_group_ids: list[str],
) -> bool:
"""
Check if repository team memberships have changed using existing group IDs.
"""
# Fetch current team slugs for the repository
current_teams = fetch_repository_team_slugs(repo=repo, github_client=github_client)
logger.info(
f"Current teams for repository {repo.id} (name: {repo.name}): {current_teams}"
)
# Build group IDs to exclude from team comparison (non-team groups)
collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=form_collaborators_group_id(repo.id),
)
outside_collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=form_outside_collaborators_group_id(repo.id),
)
non_team_group_ids = {collaborators_group_id, outside_collaborators_group_id}
# Extract existing team IDs from current external group IDs
existing_team_ids = set()
for group_id in current_external_group_ids:
# Skip all non-team groups, keep only team groups
if group_id not in non_team_group_ids:
existing_team_ids.add(group_id)
# Note: existing_team_ids from DB are already prefixed (e.g., "github__team-slug")
# but current_teams from API are raw team slugs, so we need to add the prefix
current_team_ids = set()
for team_slug in current_teams:
team_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=team_slug,
)
current_team_ids.add(team_group_id)
logger.info(
f"Existing team IDs: {existing_team_ids}, Current team IDs: {current_team_ids}"
)
# Compare actual team IDs to detect changes
teams_changed = current_team_ids != existing_team_ids
if teams_changed:
logger.info(
f"Team changes detected for repo {repo.id} (name: {repo.name}): "
f"existing={existing_team_ids}, current={current_team_ids}"
)
return teams_changed
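A minimal, self-contained sketch of the visibility-inference rule implemented above, assuming already-prefixed group IDs; the ID strings are made up and the real values come from build_ext_group_name_for_onyx:

# Hypothetical group IDs; only the membership checks mirror the logic above.
def infer_visibility(group_ids: list[str], collaborators_gid: str, org_gid: str | None) -> str:
    if collaborators_gid in group_ids:
        return "private"    # a collaborators group is only stored for private repos
    if org_gid is not None and org_gid in group_ids:
        return "internal"   # an org-wide group implies an internal repo
    return "public"         # no restricting groups stored, so the repo was public

assert infer_visibility(["gh_123_collaborators"], "gh_123_collaborators", "gh_9_organization") == "private"
assert infer_visibility(["gh_9_organization"], "gh_123_collaborators", "gh_9_organization") == "internal"
assert infer_visibility([], "gh_123_collaborators", "gh_9_organization") == "public"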

View File

@@ -0,0 +1,46 @@
from collections.abc import Generator
from github import Repository
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.github.utils import get_external_user_group
from onyx.connectors.github.connector import GithubConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def github_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
github_connector: GithubConnector = GithubConnector(
**cc_pair.connector.connector_specific_config
)
github_connector.load_credentials(cc_pair.credential.credential_json)
if not github_connector.github_client:
raise ValueError("github_client is required")
logger.info("Starting GitHub group sync...")
repos: list[Repository.Repository] = []
if github_connector.repositories:
if "," in github_connector.repositories:
# Multiple repositories specified
repos = github_connector.get_github_repos(github_connector.github_client)
else:
# Single repository (backward compatibility)
repos = [github_connector.get_github_repo(github_connector.github_client)]
else:
# All repositories
repos = github_connector.get_all_repos(github_connector.github_client)
for repo in repos:
try:
for external_group in get_external_user_group(
repo, github_connector.github_client
):
logger.info(f"External group: {external_group}")
yield external_group
except Exception as e:
logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")

View File

@@ -0,0 +1,488 @@
from collections.abc import Callable
from enum import Enum
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
from github import Github
from github import RateLimitExceededException
from github.GithubException import GithubException
from github.NamedUser import NamedUser
from github.Organization import Organization
from github.PaginatedList import PaginatedList
from github.Repository import Repository
from github.Team import Team
from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
from onyx.utils.logger import setup_logger
logger = setup_logger()
class GitHubVisibility(Enum):
"""GitHub repository visibility options."""
PUBLIC = "public"
PRIVATE = "private"
INTERNAL = "internal"
MAX_RETRY_COUNT = 3
T = TypeVar("T")
# Higher-order function to wrap GitHub operations with retry and exception handling
def _run_with_retry(
operation: Callable[[], T],
description: str,
github_client: Github,
retry_count: int = 0,
) -> Optional[T]:
"""Execute a GitHub operation with retry on rate limit and exception handling."""
logger.debug(f"Starting operation '{description}', attempt {retry_count + 1}")
try:
result = operation()
logger.debug(f"Operation '{description}' completed successfully")
return result
except RateLimitExceededException:
if retry_count < MAX_RETRY_COUNT:
sleep_after_rate_limit_exception(github_client)
logger.warning(
f"Rate limit exceeded while {description}. Retrying... "
f"(attempt {retry_count + 1}/{MAX_RETRY_COUNT})"
)
return _run_with_retry(
operation, description, github_client, retry_count + 1
)
else:
error_msg = f"Max retries exceeded for {description}"
logger.exception(error_msg)
raise RuntimeError(error_msg)
except GithubException as e:
logger.warning(f"GitHub API error during {description}: {e}")
return None
except Exception as e:
logger.exception(f"Unexpected error during {description}: {e}")
return None
class UserInfo(BaseModel):
"""Represents a GitHub user with their basic information."""
login: str
name: Optional[str] = None
email: Optional[str] = None
class TeamInfo(BaseModel):
"""Represents a GitHub team with its members."""
name: str
slug: str
members: List[UserInfo]
def _fetch_organization_members(
github_client: Github, org_name: str, retry_count: int = 0
) -> List[UserInfo]:
"""Fetch all organization members including owners and regular members."""
org_members: List[UserInfo] = []
logger.info(f"Fetching organization members for {org_name}")
org = _run_with_retry(
lambda: github_client.get_organization(org_name),
f"get organization {org_name}",
github_client,
)
if not org:
logger.error(f"Failed to fetch organization {org_name}")
raise RuntimeError(f"Failed to fetch organization {org_name}")
member_objs: PaginatedList[NamedUser] | list[NamedUser] = (
_run_with_retry(
lambda: org.get_members(filter_="all"),
f"get members for organization {org_name}",
github_client,
)
or []
)
for member in member_objs:
user_info = UserInfo(login=member.login, name=member.name, email=member.email)
org_members.append(user_info)
logger.info(f"Fetched {len(org_members)} members for organization {org_name}")
return org_members
def _fetch_repository_teams_detailed(
repo: Repository, github_client: Github, retry_count: int = 0
) -> List[TeamInfo]:
"""Fetch teams with access to the repository and their members."""
teams_data: List[TeamInfo] = []
logger.info(f"Fetching teams for repository {repo.full_name}")
team_objs: PaginatedList[Team] | list[Team] = (
_run_with_retry(
lambda: repo.get_teams(),
f"get teams for repository {repo.full_name}",
github_client,
)
or []
)
for team in team_objs:
logger.info(
f"Processing team {team.name} (slug: {team.slug}) for repository {repo.full_name}"
)
members: PaginatedList[NamedUser] | list[NamedUser] = (
_run_with_retry(
lambda: team.get_members(),
f"get members for team {team.name}",
github_client,
)
or []
)
team_members = []
for m in members:
user_info = UserInfo(login=m.login, name=m.name, email=m.email)
team_members.append(user_info)
team_info = TeamInfo(name=team.name, slug=team.slug, members=team_members)
teams_data.append(team_info)
logger.info(f"Team {team.name} has {len(team_members)} members")
logger.info(f"Fetched {len(teams_data)} teams for repository {repo.full_name}")
return teams_data
def fetch_repository_team_slugs(
repo: Repository, github_client: Github, retry_count: int = 0
) -> List[str]:
"""Fetch team slugs with access to the repository."""
logger.info(f"Fetching team slugs for repository {repo.full_name}")
teams_data: List[str] = []
team_objs: PaginatedList[Team] | list[Team] = (
_run_with_retry(
lambda: repo.get_teams(),
f"get teams for repository {repo.full_name}",
github_client,
)
or []
)
for team in team_objs:
teams_data.append(team.slug)
logger.info(f"Fetched {len(teams_data)} team slugs for repository {repo.full_name}")
return teams_data
def _get_collaborators_and_outside_collaborators(
github_client: Github,
repo: Repository,
) -> Tuple[List[UserInfo], List[UserInfo]]:
"""Fetch and categorize collaborators into regular and outside collaborators."""
collaborators: List[UserInfo] = []
outside_collaborators: List[UserInfo] = []
logger.info(f"Fetching collaborators for repository {repo.full_name}")
repo_collaborators: PaginatedList[NamedUser] | list[NamedUser] = (
_run_with_retry(
lambda: repo.get_collaborators(),
f"get collaborators for repository {repo.full_name}",
github_client,
)
or []
)
for collaborator in repo_collaborators:
is_outside = False
# Check if collaborator is outside the organization
if repo.organization:
org: Organization | None = _run_with_retry(
lambda: github_client.get_organization(repo.organization.login),
f"get organization {repo.organization.login}",
github_client,
)
if org is not None:
org_obj = org
membership = _run_with_retry(
lambda: org_obj.has_in_members(collaborator),
f"check membership for {collaborator.login} in org {org_obj.login}",
github_client,
)
is_outside = membership is not None and not membership
info = UserInfo(
login=collaborator.login, name=collaborator.name, email=collaborator.email
)
if repo.organization and is_outside:
outside_collaborators.append(info)
else:
collaborators.append(info)
logger.info(
f"Categorized {len(collaborators)} regular and {len(outside_collaborators)} outside collaborators for {repo.full_name}"
)
return collaborators, outside_collaborators
def form_collaborators_group_id(repository_id: int) -> str:
"""Generate group ID for repository collaborators."""
if not repository_id:
logger.exception("Repository ID is required to generate collaborators group ID")
raise ValueError("Repository ID must be set to generate group ID.")
group_id = f"{repository_id}_collaborators"
return group_id
def form_organization_group_id(organization_id: int) -> str:
"""Generate group ID for organization using organization ID."""
if not organization_id:
logger.exception(
"Organization ID is required to generate organization group ID"
)
raise ValueError("Organization ID must be set to generate group ID.")
group_id = f"{organization_id}_organization"
return group_id
def form_outside_collaborators_group_id(repository_id: int) -> str:
"""Generate group ID for outside collaborators."""
if not repository_id:
logger.exception(
"Repository ID is required to generate outside collaborators group ID"
)
raise ValueError("Repository ID must be set to generate group ID.")
group_id = f"{repository_id}_outside_collaborators"
return group_id
def get_repository_visibility(repo: Repository) -> GitHubVisibility:
"""
Get the visibility of a repository.
Returns GitHubVisibility enum member.
"""
if hasattr(repo, "visibility"):
visibility = repo.visibility
logger.info(
f"Repository {repo.full_name} visibility from attribute: {visibility}"
)
try:
return GitHubVisibility(visibility)
except ValueError:
logger.warning(
f"Unknown visibility '{visibility}' for repo {repo.full_name}, defaulting to private"
)
return GitHubVisibility.PRIVATE
logger.info(f"Repository {repo.full_name} is private")
return GitHubVisibility.PRIVATE
def get_external_access_permission(
repo: Repository, github_client: Github, add_prefix: bool = False
) -> ExternalAccess:
"""
Get the external access permission for a repository.
Uses group-based permissions for efficiency and scalability.
add_prefix: When this method is called during the initial permission sync via the connector,
the group ID is not prefixed with the source when the document record is inserted.
In that case, set add_prefix to True so this method handles the prefixing itself.
When the same method is invoked from doc_sync, the system already adds the prefix
to the group ID while processing the ExternalAccess object.
"""
# We maintain collaborators and outside collaborators as two separate groups
# instead of adding individual user emails to ExternalAccess.external_user_emails for two reasons:
# 1. Changes in repo collaborators (additions/removals) would require updating all documents.
# 2. Repo permissions can change without updating the repo's updated_at timestamp,
# forcing full permission syncs for all documents every time, which is inefficient.
repo_visibility = get_repository_visibility(repo)
logger.info(
f"Generating ExternalAccess for {repo.full_name}: visibility={repo_visibility.value}"
)
if repo_visibility == GitHubVisibility.PUBLIC:
logger.info(
f"Repository {repo.full_name} is public - allowing access to all users"
)
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
elif repo_visibility == GitHubVisibility.PRIVATE:
logger.info(
f"Repository {repo.full_name} is private - setting up restricted access"
)
collaborators_group_id = form_collaborators_group_id(repo.id)
outside_collaborators_group_id = form_outside_collaborators_group_id(repo.id)
if add_prefix:
collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=collaborators_group_id,
)
outside_collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=outside_collaborators_group_id,
)
group_ids = {collaborators_group_id, outside_collaborators_group_id}
team_slugs = fetch_repository_team_slugs(repo, github_client)
if add_prefix:
team_slugs = [
build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=slug,
)
for slug in team_slugs
]
group_ids.update(team_slugs)
logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=group_ids,
is_public=False,
)
else:
# Internal repositories - accessible to organization members
logger.info(
f"Repository {repo.full_name} is internal - accessible to org members"
)
org_group_id = form_organization_group_id(repo.organization.id)
if add_prefix:
org_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=org_group_id,
)
group_ids = {org_group_id}
logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=group_ids,
is_public=False,
)
def get_external_user_group(
repo: Repository, github_client: Github
) -> list[ExternalUserGroup]:
"""
Get the external user group for a repository.
Creates ExternalUserGroup objects with actual user emails for each permission group.
"""
repo_visibility = get_repository_visibility(repo)
logger.info(
f"Generating ExternalUserGroups for {repo.full_name}: visibility={repo_visibility.value}"
)
if repo_visibility == GitHubVisibility.PRIVATE:
logger.info(f"Processing private repository {repo.full_name}")
collaborators, outside_collaborators = (
_get_collaborators_and_outside_collaborators(github_client, repo)
)
teams = _fetch_repository_teams_detailed(repo, github_client)
external_user_groups = []
user_emails = set()
for collab in collaborators:
if collab.email:
user_emails.add(collab.email)
else:
logger.error(f"Collaborator {collab.login} has no email")
if user_emails:
collaborators_group = ExternalUserGroup(
id=form_collaborators_group_id(repo.id),
user_emails=list(user_emails),
)
external_user_groups.append(collaborators_group)
logger.info(f"Created collaborators group with {len(user_emails)} emails")
# Create group for outside collaborators
user_emails = set()
for collab in outside_collaborators:
if collab.email:
user_emails.add(collab.email)
else:
logger.error(f"Outside collaborator {collab.login} has no email")
if user_emails:
outside_collaborators_group = ExternalUserGroup(
id=form_outside_collaborators_group_id(repo.id),
user_emails=list(user_emails),
)
external_user_groups.append(outside_collaborators_group)
logger.info(
f"Created outside collaborators group with {len(user_emails)} emails"
)
# Create groups for teams
for team in teams:
user_emails = set()
for member in team.members:
if member.email:
user_emails.add(member.email)
else:
logger.error(f"Team member {member.login} has no email")
if user_emails:
team_group = ExternalUserGroup(
id=team.slug,
user_emails=list(user_emails),
)
external_user_groups.append(team_group)
logger.info(
f"Created team group {team.name} with {len(user_emails)} emails"
)
logger.info(
f"Created {len(external_user_groups)} ExternalUserGroups for private repository {repo.full_name}"
)
return external_user_groups
if repo_visibility == GitHubVisibility.INTERNAL:
logger.info(f"Processing internal repository {repo.full_name}")
org_group_id = form_organization_group_id(repo.organization.id)
org_members = _fetch_organization_members(
github_client, repo.organization.login
)
user_emails = set()
for member in org_members:
if member.email:
user_emails.add(member.email)
else:
logger.error(f"Org member {member.login} has no email")
org_group = ExternalUserGroup(
id=org_group_id,
user_emails=list(user_emails),
)
logger.info(
f"Created organization group with {len(user_emails)} emails for internal repository {repo.full_name}"
)
return [org_group]
logger.info(f"Repository {repo.full_name} is public - no user groups needed")
return []
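A hedged usage sketch of the helpers above; the token and repository name are placeholders, and importing the private _run_with_retry is for illustration only, not a recommended pattern:

from github import Github

from ee.onyx.external_permissions.github.utils import _run_with_retry
from ee.onyx.external_permissions.github.utils import get_external_access_permission

gh = Github("<personal-access-token>")  # placeholder credentials
repo = _run_with_retry(
    lambda: gh.get_repo("octocat/Hello-World"),
    "get repository octocat/Hello-World",
    gh,
)
if repo is not None:
    # add_prefix=True because this call is outside the doc_sync path (see the docstring above)
    access = get_external_access_permission(repo, gh, add_prefix=True)
    print(access.is_public, sorted(access.external_user_group_ids))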

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from datetime import timezone
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
@@ -35,6 +36,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""

View File

@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.google_drive.permission_retrieval import (
get_permissions_by_ids,
)
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.connector import GoogleDriveConnector
@@ -169,6 +170,7 @@ def get_external_access_for_raw_gdrive_file(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""

View File

@@ -1,6 +1,7 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -17,6 +18,7 @@ JIRA_DOC_SYNC_TAG = "jira_doc_sync"
def jira_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
jira_connector = JiraConnector(
@@ -26,7 +28,7 @@ def jira_doc_sync(
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
callback=callback,
doc_source=DocumentSource.JIRA,
slim_connector=jira_connector,

View File

@@ -5,6 +5,8 @@ from typing import Protocol
from typing import TYPE_CHECKING
from onyx.context.search.models import InferenceChunk
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
# Avoid circular imports
if TYPE_CHECKING:
@@ -15,14 +17,34 @@ if TYPE_CHECKING:
class FetchAllDocumentsFunction(Protocol):
"""Protocol for a function that fetches all document IDs for a connector credential pair."""
"""Protocol for a function that fetches documents for a connector credential pair.
def __call__(self) -> list[str]:
This protocol defines the interface for functions that retrieve documents
from the database, typically used in permission synchronization workflows.
"""
def __call__(
self,
sort_order: SortOrder | None,
) -> list[DocumentRow]:
"""
Returns a list of document IDs for a connector credential pair.
Fetches documents for a connector credential pair.
"""
...
This is typically used to determine which documents should no longer be
accessible during the document sync process.
class FetchAllDocumentsIdsFunction(Protocol):
"""Protocol for a function that fetches document IDs for a connector credential pair.
This protocol defines the interface for functions that retrieve document IDs
from the database, typically used in permission synchronization workflows.
"""
def __call__(
self,
) -> list[str]:
"""
Fetches document IDs for a connector credential pair.
"""
...
@@ -32,6 +54,7 @@ DocSyncFuncType = Callable[
[
"ConnectorCredentialPair",
FetchAllDocumentsFunction,
FetchAllDocumentsIdsFunction,
Optional["IndexingHeartbeatInterface"],
],
Generator["DocExternalAccess", None, None],

View File

@@ -0,0 +1,36 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
SHAREPOINT_DOC_SYNC_TAG = "sharepoint_doc_sync"
def sharepoint_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
sharepoint_connector = SharepointConnector(
**cc_pair.connector.connector_specific_config,
)
sharepoint_connector.load_credentials(cc_pair.credential.credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
callback=callback,
doc_source=DocumentSource.SHAREPOINT,
slim_connector=sharepoint_connector,
label=SHAREPOINT_DOC_SYNC_TAG,
)

View File

@@ -0,0 +1,63 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def sharepoint_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
"""Sync SharePoint groups and their members"""
# Get site URLs from connector config
connector_config = cc_pair.connector.connector_specific_config
# Create SharePoint connector instance and load credentials
connector = SharepointConnector(**connector_config)
connector.load_credentials(cc_pair.credential.credential_json)
if not connector.msal_app:
raise RuntimeError("MSAL app not initialized in connector")
if not connector.sp_tenant_domain:
raise RuntimeError("Tenant domain not initialized in connector")
# Get site descriptors from connector (either configured sites or all sites)
site_descriptors = connector.site_descriptors or connector.fetch_sites()
if not site_descriptors:
raise RuntimeError("No SharePoint sites found for group sync")
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
# Process each site
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
# Create client context for the site using connector's MSAL app
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
# Yield each group
for group in external_groups:
logger.debug(
f"Found group: {group.id} with {len(group.user_emails)} members"
)
yield group

View File

@@ -0,0 +1,658 @@
import re
from collections import deque
from typing import Any
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection # type: ignore[import-untyped]
from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
logger = setup_logger()
# These values represent different types of SharePoint principals used in permission assignments
USER_PRINCIPAL_TYPE = 1 # Individual user accounts
ANONYMOUS_USER_PRINCIPAL_TYPE = 3 # Anonymous/unauthenticated users (public access)
AZURE_AD_GROUP_PRINCIPAL_TYPE = 4 # Azure Active Directory security groups
SHAREPOINT_GROUP_PRINCIPAL_TYPE = 8 # SharePoint site groups (local to the site)
MICROSOFT_DOMAIN = ".onmicrosoft"
# Limited Access role types: Limited Access is a traversal (pass-through) permission, not an actual permission
LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
name: str
login_name: str
principal_type: int
class GroupsResult(BaseModel):
groups_to_emails: dict[str, set[str]]
found_public_group: bool
def _get_azuread_group_guid_by_name(
graph_client: GraphClient, group_name: str
) -> str | None:
try:
# Search for groups by display name
groups = sleep_and_retry(
graph_client.groups.filter(f"displayName eq '{group_name}'").get(),
"get_azuread_group_guid_by_name",
)
if groups and len(groups) > 0:
return groups[0].id
return None
except Exception as e:
logger.error(f"Failed to get Azure AD group GUID for name {group_name}: {e}")
return None
def _extract_guid_from_claims_token(claims_token: str) -> str | None:
try:
# Pattern to match GUID in claims token
# Claims tokens often have format: c:0o.c|provider|GUID_suffix
guid_pattern = r"([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})"
match = re.search(guid_pattern, claims_token, re.IGNORECASE)
if match:
return match.group(1)
return None
except Exception as e:
logger.error(f"Failed to extract GUID from claims token {claims_token}: {e}")
return None
def _get_group_guid_from_identifier(
graph_client: GraphClient, identifier: str
) -> str | None:
try:
# Check if it's already a GUID
guid_pattern = r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
if re.match(guid_pattern, identifier, re.IGNORECASE):
return identifier
# Check if it's a SharePoint claims token
if identifier.startswith("c:0") and "|" in identifier:
guid = _extract_guid_from_claims_token(identifier)
if guid:
logger.info(f"Extracted GUID {guid} from claims token {identifier}")
return guid
# Try to search by display name as fallback
return _get_azuread_group_guid_by_name(graph_client, identifier)
except Exception as e:
logger.error(f"Failed to get group GUID from identifier {identifier}: {e}")
return None
def _get_security_group_owners(graph_client: GraphClient, group_id: str) -> list[str]:
try:
# Get group owners using Graph API
group = graph_client.groups[group_id]
owners = sleep_and_retry(
group.owners.get_all(page_loaded=lambda _: None),
"get_security_group_owners",
)
owner_emails: list[str] = []
logger.info(f"Owners: {owners}")
for owner in owners:
owner_data = owner.to_json()
# Extract email from the JSON data
mail: str | None = owner_data.get("mail")
user_principal_name: str | None = owner_data.get("userPrincipalName")
# Check if owner is a user and has an email
if mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
owner_emails.append(mail)
elif user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
owner_emails.append(user_principal_name)
logger.info(
f"Retrieved {len(owner_emails)} owners from security group {group_id}"
)
return owner_emails
except Exception as e:
logger.error(f"Failed to get security group owners for group {group_id}: {e}")
return []
def _get_sharepoint_list_item_id(drive_item: DriveItem) -> str | None:
try:
# First try to get the list item directly from the drive item
if hasattr(drive_item, "listItem"):
list_item = drive_item.listItem
if list_item:
# Load the list item properties to get the ID
sleep_and_retry(list_item.get(), "get_sharepoint_list_item_id")
if hasattr(list_item, "id") and list_item.id:
return str(list_item.id)
# The SharePoint list item ID is typically available in the sharepointIds property
sharepoint_ids = getattr(drive_item, "sharepoint_ids", None)
if sharepoint_ids and hasattr(sharepoint_ids, "listItemId"):
return sharepoint_ids.listItemId
# Alternative: try to get it from the properties
properties = getattr(drive_item, "properties", None)
if properties:
# Sometimes the SharePoint list item ID is in the properties
for prop_name, prop_value in properties.items():
if "listitemid" in prop_name.lower():
return str(prop_value)
return None
except Exception as e:
logger.error(
f"Error getting SharePoint list item ID for item {drive_item.id}: {e}"
)
raise e
def _is_public_item(drive_item: DriveItem) -> bool:
is_public = False
try:
permissions = sleep_and_retry(
drive_item.permissions.get_all(page_loaded=lambda _: None), "is_public_item"
)
for permission in permissions:
if permission.link and (
permission.link.scope == "anonymous"
or permission.link.scope == "organization"
):
is_public = True
break
return is_public
except Exception as e:
logger.error(f"Failed to check if item {drive_item.id} is public: {e}")
return False
def _is_public_login_name(login_name: str) -> bool:
# Patterns that indicate public access
# This list is derived from the link below:
# https://learn.microsoft.com/en-us/answers/questions/2085339/guid-in-the-loginname-of-site-user-everyone-except
public_login_patterns: list[str] = [
"c:0-.f|rolemanager|spo-grid-all-users/",
"c:0(.s|true",
]
for pattern in public_login_patterns:
if pattern in login_name:
logger.info(f"Login name {login_name} is public")
return True
return False
# Azure AD allows multiple groups to share the same display name, so we append the GUID to the name
def _get_group_name_with_suffix(
login_name: str, group_name: str, graph_client: GraphClient
) -> str:
ad_group_suffix = _get_group_guid_from_identifier(graph_client, login_name)
return f"{group_name}_{ad_group_suffix}"
def _get_sharepoint_groups(
client_context: ClientContext, group_name: str, graph_client: GraphClient
) -> tuple[set[SharepointGroup], set[str]]:
groups: set[SharepointGroup] = set()
user_emails: set[str] = set()
def process_users(users: list[Any]) -> None:
nonlocal groups, user_emails
for user in users:
if user.principal_type == USER_PRINCIPAL_TYPE and hasattr(
user, "user_principal_name"
):
if user.user_principal_name:
email = user.user_principal_name
if MICROSOFT_DOMAIN in email:
email = email.replace(MICROSOFT_DOMAIN, "")
user_emails.add(email)
else:
logger.warning(
f"User don't have a user principal name: {user.login_name}"
)
elif user.principal_type in [
AZURE_AD_GROUP_PRINCIPAL_TYPE,
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
]:
name = user.title
if user.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
name = _get_group_name_with_suffix(
user.login_name, name, graph_client
)
groups.add(
SharepointGroup(
login_name=user.login_name,
principal_type=user.principal_type,
name=name,
)
)
group = client_context.web.site_groups.get_by_name(group_name)
sleep_and_retry(
group.users.get_all(page_loaded=process_users), "get_sharepoint_groups"
)
return groups, user_emails
def _get_azuread_groups(
graph_client: GraphClient, group_name: str
) -> tuple[set[SharepointGroup], set[str]]:
group_id = _get_group_guid_from_identifier(graph_client, group_name)
if not group_id:
logger.error(f"Failed to get Azure AD group GUID for name {group_name}")
return set(), set()
group = graph_client.groups[group_id]
groups: set[SharepointGroup] = set()
user_emails: set[str] = set()
def process_members(members: list[Any]) -> None:
nonlocal groups, user_emails
for member in members:
member_data = member.to_json()
# Check for user-specific attributes
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
display_name = member_data.get("displayName") or member_data.get(
"display_name"
)
# Check object attributes directly (if available)
is_user = False
is_group = False
# Users typically have userPrincipalName or mail
if user_principal_name or (mail and "@" in str(mail)):
is_user = True
# Groups typically have displayName but no userPrincipalName
elif display_name and not user_principal_name:
# Additional check: try to access group-specific properties
if (
hasattr(member, "groupTypes")
or member_data.get("groupTypes") is not None
):
is_group = True
# Or check if it has an 'id' field typical for groups
elif member_data.get("id") and not user_principal_name:
is_group = True
# Check the object type name (fallback)
if not is_user and not is_group:
obj_type = type(member).__name__.lower()
if "user" in obj_type:
is_user = True
elif "group" in obj_type:
is_group = True
# Process based on identification
if is_user:
if user_principal_name:
email = user_principal_name
if MICROSOFT_DOMAIN in email:
email = email.replace(MICROSOFT_DOMAIN, "")
user_emails.add(email)
elif mail:
email = mail
if MICROSOFT_DOMAIN in email:
email = email.replace(MICROSOFT_DOMAIN, "")
user_emails.add(email)
logger.info(f"Added user: {user_principal_name or mail}")
elif is_group:
if not display_name:
logger.error(f"No display name for group: {member_data.get('id')}")
continue
name = _get_group_name_with_suffix(
member_data.get("id", ""), display_name, graph_client
)
groups.add(
SharepointGroup(
login_name=member_data.get("id", ""), # Use ID for groups
principal_type=AZURE_AD_GROUP_PRINCIPAL_TYPE,
name=name,
)
)
logger.info(f"Added group: {name}")
else:
# Log unidentified members for debugging
logger.warning(f"Could not identify member type for: {member_data}")
sleep_and_retry(
group.members.get_all(page_loaded=process_members), "get_azuread_groups"
)
owner_emails = _get_security_group_owners(graph_client, group_id)
user_emails.update(owner_emails)
return groups, user_emails
def _get_groups_and_members_recursively(
client_context: ClientContext,
graph_client: GraphClient,
groups: set[SharepointGroup],
) -> GroupsResult:
"""
Get all groups and their members recursively.
"""
group_queue: deque[SharepointGroup] = deque(groups)
visited_groups: set[str] = set()
visited_group_name_to_emails: dict[str, set[str]] = {}
while group_queue:
group = group_queue.popleft()
if group.login_name in visited_groups:
continue
visited_groups.add(group.login_name)
visited_group_name_to_emails[group.name] = set()
logger.info(
f"Processing group: {group.name} principal type: {group.principal_type}"
)
if group.principal_type == SHAREPOINT_GROUP_PRINCIPAL_TYPE:
group_info, user_emails = _get_sharepoint_groups(
client_context, group.login_name, graph_client
)
visited_group_name_to_emails[group.name].update(user_emails)
if group_info:
group_queue.extend(group_info)
if group.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
# if the site is public, we have default groups assigned to it, so we return early
if _is_public_login_name(group.login_name):
return GroupsResult(groups_to_emails={}, found_public_group=True)
group_info, user_emails = _get_azuread_groups(
graph_client, group.login_name
)
visited_group_name_to_emails[group.name].update(user_emails)
if group_info:
group_queue.extend(group_info)
return GroupsResult(
groups_to_emails=visited_group_name_to_emails, found_public_group=False
)
def get_external_access_from_sharepoint(
client_context: ClientContext,
graph_client: GraphClient,
drive_name: str | None,
drive_item: DriveItem | None,
site_page: dict[str, Any] | None,
add_prefix: bool = False,
) -> ExternalAccess:
"""
Get external access information from SharePoint.
"""
groups: set[SharepointGroup] = set()
user_emails: set[str] = set()
group_ids: set[str] = set()
# Add all members to a processing set first
def add_user_and_group_to_sets(
role_assignments: RoleAssignmentCollection,
) -> None:
nonlocal user_emails, groups
for assignment in role_assignments:
if assignment.role_definition_bindings:
is_limited_access = True
for role_definition_binding in assignment.role_definition_bindings:
if (
role_definition_binding.role_type_kind
not in LIMITED_ACCESS_ROLE_TYPES
or role_definition_binding.name not in LIMITED_ACCESS_ROLE_NAMES
):
is_limited_access = False
break
# Skip if the role is only Limited Access, because it is a traversal (pass-through) permission, not an actual permission
if is_limited_access:
logger.info(
"Skipping assignment because it has only Limited Access role"
)
continue
if assignment.member:
member = assignment.member
if member.principal_type == USER_PRINCIPAL_TYPE and hasattr(
member, "user_principal_name"
):
email = member.user_principal_name
if MICROSOFT_DOMAIN in email:
email = email.replace(MICROSOFT_DOMAIN, "")
user_emails.add(email)
elif member.principal_type in [
AZURE_AD_GROUP_PRINCIPAL_TYPE,
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
]:
name = member.title
if member.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
name = _get_group_name_with_suffix(
member.login_name, name, graph_client
)
groups.add(
SharepointGroup(
login_name=member.login_name,
principal_type=member.principal_type,
name=name,
)
)
if drive_item and drive_name:
# Check whether the item has any public links; if so, return early
is_public = _is_public_item(drive_item)
if is_public:
logger.info(f"Item {drive_item.id} is public")
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
item_id = _get_sharepoint_list_item_id(drive_item)
if not item_id:
raise RuntimeError(
f"Failed to get SharePoint list item ID for item {drive_item.id}"
)
if drive_name == "Shared Documents":
drive_name = "Documents"
item = client_context.web.lists.get_by_title(drive_name).items.get_by_id(
item_id
)
sleep_and_retry(
item.role_assignments.expand(["Member", "RoleDefinitionBindings"]).get_all(
page_loaded=add_user_and_group_to_sets,
),
"get_external_access_from_sharepoint",
)
elif site_page:
site_url = site_page.get("webUrl")
site_pages = client_context.web.lists.get_by_title("Site Pages")
client_context.load(site_pages)
client_context.execute_query()
site_pages.items.get_by_url(site_url).role_assignments.expand(
["Member", "RoleDefinitionBindings"]
).get_all(page_loaded=add_user_and_group_to_sets).execute_query()
else:
raise RuntimeError("No drive item or site page provided")
groups_and_members: GroupsResult = _get_groups_and_members_recursively(
client_context, graph_client, groups
)
# If the site is public, we have default groups assigned to it, so we return early
if groups_and_members.found_public_group:
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
for group_name, _ in groups_and_members.groups_to_emails.items():
if add_prefix:
group_name = build_ext_group_name_for_onyx(
group_name, DocumentSource.SHAREPOINT
)
group_ids.add(group_name.lower())
logger.info(f"User emails: {len(user_emails)}")
logger.info(f"Group IDs: {len(group_ids)}")
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_ids,
is_public=False,
)
def get_sharepoint_external_groups(
client_context: ClientContext, graph_client: GraphClient
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
def add_group_to_sets(role_assignments: RoleAssignmentCollection) -> None:
nonlocal groups
for assignment in role_assignments:
if assignment.role_definition_bindings:
is_limited_access = True
for role_definition_binding in assignment.role_definition_bindings:
if (
role_definition_binding.role_type_kind
not in LIMITED_ACCESS_ROLE_TYPES
or role_definition_binding.name not in LIMITED_ACCESS_ROLE_NAMES
):
is_limited_access = False
break
# Skip if the role assignment is only Limited Access, because it is a traversal
# (pass-through) permission, not an actual permission
if is_limited_access:
logger.info(
"Skipping assignment because it has only Limited Access role"
)
continue
if assignment.member:
member = assignment.member
if member.principal_type in [
AZURE_AD_GROUP_PRINCIPAL_TYPE,
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
]:
name = member.title
if member.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
name = _get_group_name_with_suffix(
member.login_name, name, graph_client
)
groups.add(
SharepointGroup(
login_name=member.login_name,
principal_type=member.principal_type,
name=name,
)
)
sleep_and_retry(
client_context.web.role_assignments.expand(
["Member", "RoleDefinitionBindings"]
).get_all(page_loaded=add_group_to_sets),
"get_sharepoint_external_groups",
)
groups_and_members: GroupsResult = _get_groups_and_members_recursively(
client_context, graph_client, groups
)
# We don't have any direct way to check if the site is public, so we check if any public group is present
if groups_and_members.found_public_group:
return []
# Get all Azure AD groups so that any group assigned directly to a drive item is not missed.
# SharePoint groups cannot be assigned to drive items or drives, so we don't need to fetch all SharePoint groups.
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# Azure AD allows multiple groups to share the same display name, so we append the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
)
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
return external_user_groups
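The claims-token handling above hinges on a plain GUID regex; a self-contained check of that pattern against a fabricated claims login name:

import re

GUID_RE = r"([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})"
# Fabricated token following the "c:0o.c|provider|GUID_suffix" shape noted above.
claims_token = "c:0o.c|federateddirectoryclaimprovider|12345678-90ab-cdef-1234-567890abcdef_o"
match = re.search(GUID_RE, claims_token, re.IGNORECASE)
assert match is not None
assert match.group(1) == "12345678-90ab-cdef-1234-567890abcdef"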

View File

@@ -3,6 +3,7 @@ from collections.abc import Generator
from slack_sdk import WebClient
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
@@ -130,6 +131,7 @@ def _get_slack_document_access(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""

View File

@@ -7,12 +7,18 @@ from pydantic import BaseModel
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import TEAMS_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
from ee.onyx.external_permissions.github.doc_sync import github_doc_sync
from ee.onyx.external_permissions.github.group_sync import github_group_sync
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
@@ -20,10 +26,13 @@ from ee.onyx.external_permissions.jira.doc_sync import jira_doc_sync
from ee.onyx.external_permissions.perm_sync_types import CensoringFuncType
from ee.onyx.external_permissions.perm_sync_types import DocSyncFuncType
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
)
from ee.onyx.external_permissions.sharepoint.doc_sync import sharepoint_doc_sync
from ee.onyx.external_permissions.sharepoint.group_sync import sharepoint_group_sync
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from ee.onyx.external_permissions.teams.doc_sync import teams_doc_sync
from onyx.configs.constants import DocumentSource
@@ -63,6 +72,7 @@ class SyncConfig(BaseModel):
def mock_doc_sync(
cc_pair: "ConnectorCredentialPair",
fetch_all_docs_fn: FetchAllDocumentsFunction,
fetch_all_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: Optional["IndexingHeartbeatInterface"],
) -> Generator["DocExternalAccess", None, None]:
"""Mock doc sync function for testing - returns empty list since permissions are fetched during indexing"""
@@ -117,6 +127,18 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
initial_index_should_sync=False,
),
),
DocumentSource.GITHUB: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=GITHUB_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=github_doc_sync,
initial_index_should_sync=True,
),
group_sync_config=GroupSyncConfig(
group_sync_frequency=GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY,
group_sync_func=github_group_sync,
group_sync_is_cc_pair_agnostic=False,
),
),
DocumentSource.SALESFORCE: SyncConfig(
censoring_config=CensoringConfig(
chunk_censoring_func=censor_salesforce_chunks,
@@ -138,6 +160,18 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
initial_index_should_sync=True,
),
),
DocumentSource.SHAREPOINT: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=sharepoint_doc_sync,
initial_index_should_sync=True,
),
group_sync_config=GroupSyncConfig(
group_sync_frequency=SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY,
group_sync_func=sharepoint_group_sync,
group_sync_is_cc_pair_agnostic=False,
),
),
}
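A hedged sketch of how a permission-sync worker might resolve the new SharePoint entry from this registry; the direct dict access assumes the names above are in scope, and the real code may go through an accessor helper instead:

sync_config = _SOURCE_TO_SYNC_CONFIG.get(DocumentSource.SHAREPOINT)
if sync_config is not None and sync_config.doc_sync_config is not None:
    doc_sync_func = sync_config.doc_sync_config.doc_sync_func
    # doc_sync_func(cc_pair, fetch_docs_fn, fetch_doc_ids_fn, callback) would then
    # yield DocExternalAccess objects for the worker to apply.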

View File

@@ -1,6 +1,7 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -18,6 +19,7 @@ TEAMS_DOC_SYNC_LABEL = "teams_doc_sync"
def teams_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
teams_connector = TeamsConnector(
@@ -27,7 +29,7 @@ def teams_doc_sync(
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
callback=callback,
doc_source=DocumentSource.TEAMS,
slim_connector=teams_connector,

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
@@ -14,7 +14,7 @@ logger = setup_logger()
def generic_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
doc_source: DocumentSource,
slim_connector: SlimConnector,
@@ -62,9 +62,9 @@ def generic_doc_sync(
)
logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id=}")
existing_doc_ids = set(fetch_all_existing_docs_fn())
existing_doc_ids: list[str] = fetch_all_existing_docs_ids_fn()
missing_doc_ids = existing_doc_ids - newly_fetched_doc_ids
missing_doc_ids = set(existing_doc_ids) - newly_fetched_doc_ids
if not missing_doc_ids:
return
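The hunk above swaps the document-fetching callback for an ID-fetching callback; the missing-document detection itself is just a set difference. A small self-contained sketch of that step (names are illustrative, not the Onyx helpers):

    def find_missing_doc_ids(existing_doc_ids: list[str], newly_fetched_doc_ids: set[str]) -> set[str]:
        # IDs that are indexed locally but were not returned by the source this run
        return set(existing_doc_ids) - newly_fetched_doc_ids

    assert find_missing_doc_ids(["a", "b", "c"], {"a", "c"}) == {"b"}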

View File

@@ -206,7 +206,7 @@ def _handle_standard_answers(
restate_question_blocks = get_restate_blocks(
msg=query_msg.message,
is_bot_msg=message_info.is_bot_msg,
is_slash_command=message_info.is_slash_command,
)
answer_blocks = build_standard_answer_blocks(

View File

@@ -134,15 +134,14 @@ def ee_fetch_settings() -> EnterpriseSettings:
def put_logo(
file: UploadFile,
is_logotype: bool = False,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> None:
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
upload_logo(file=file, is_logotype=is_logotype)
def fetch_logo_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
@@ -158,7 +157,7 @@ def fetch_logo_helper(db_session: Session) -> Response:
def fetch_logotype_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")

View File

@@ -6,7 +6,6 @@ from typing import IO
from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
@@ -99,9 +98,7 @@ def guess_file_type(filename: str) -> str:
return "application/octet-stream"
def upload_logo(
db_session: Session, file: UploadFile | str, is_logotype: bool = False
) -> bool:
def upload_logo(file: UploadFile | str, is_logotype: bool = False) -> bool:
content: IO[Any]
if isinstance(file, str):
@@ -129,7 +126,7 @@ def upload_logo(
display_name = file.filename
file_type = file.content_type or "image/jpeg"
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_store.save_file(
content=content,
display_name=display_name,

View File

@@ -1,5 +1,6 @@
import re
from typing import cast
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
@@ -73,6 +74,7 @@ def _get_final_context_doc_indices(
def _convert_packet_stream_to_response(
packets: ChatPacketStream,
chat_session_id: UUID,
) -> ChatBasicResponse:
response = ChatBasicResponse()
final_context_docs: list[LlmDoc] = []
@@ -216,6 +218,8 @@ def _convert_packet_stream_to_response(
if answer:
response.answer_citationless = remove_answer_citations(answer)
response.chat_session_id = chat_session_id
return response
@@ -237,13 +241,36 @@ def handle_simplified_chat_message(
if not chat_message_req.message:
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
# Handle chat session creation if chat_session_id is not provided
if chat_message_req.chat_session_id is None:
if chat_message_req.persona_id is None:
raise HTTPException(
status_code=400,
detail="Either chat_session_id or persona_id must be provided",
)
# Create a new chat session with the provided persona_id
try:
new_chat_session = create_chat_session(
db_session=db_session,
description="", # Leave empty for simple API
user_id=user.id if user else None,
persona_id=chat_message_req.persona_id,
)
chat_session_id = new_chat_session.id
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=400, detail="Invalid Persona provided.")
else:
chat_session_id = chat_message_req.chat_session_id
try:
parent_message, _ = create_chat_chain(
chat_session_id=chat_message_req.chat_session_id, db_session=db_session
chat_session_id=chat_session_id, db_session=db_session
)
except Exception:
parent_message = get_or_create_root_message(
chat_session_id=chat_message_req.chat_session_id, db_session=db_session
chat_session_id=chat_session_id, db_session=db_session
)
if (
@@ -258,7 +285,7 @@ def handle_simplified_chat_message(
retrieval_options = chat_message_req.retrieval_options
full_chat_msg_info = CreateChatMessageRequest(
chat_session_id=chat_message_req.chat_session_id,
chat_session_id=chat_session_id,
parent_message_id=parent_message.id,
message=chat_message_req.message,
file_descriptors=[],
@@ -283,7 +310,7 @@ def handle_simplified_chat_message(
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets)
return _convert_packet_stream_to_response(packets, chat_session_id)
@router.post("/send-message-simple-with-history")
@@ -403,4 +430,4 @@ def handle_send_message_simple_with_history(
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets)
return _convert_packet_stream_to_response(packets, chat_session.id)

View File

@@ -41,11 +41,13 @@ class DocumentSearchRequest(ChunkContext):
class BasicCreateChatMessageRequest(ChunkContext):
"""Before creating messages, be sure to create a chat_session and get an id
"""If a chat_session_id is not provided, a persona_id must be provided to automatically create a new chat session
Note, for simplicity this option only allows for a single linear chain of messages
"""
chat_session_id: UUID
chat_session_id: UUID | None = None
# Optional persona_id to create a new chat session if chat_session_id is not provided
persona_id: int | None = None
# New message contents
message: str
# Defaults to using retrieval with no additional filters
@@ -62,6 +64,12 @@ class BasicCreateChatMessageRequest(ChunkContext):
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
if self.chat_session_id is None and self.persona_id is None:
raise ValueError("Either chat_session_id or persona_id must be provided")
return self
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# Last element is the new query. All previous elements are historical context
@@ -171,6 +179,9 @@ class ChatBasicResponse(BaseModel):
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
# Chat session ID for tracking conversation continuity
chat_session_id: UUID | None = None
class OneShotQARequest(ChunkContext):
# Supports simpler APIs that don't deal with chat histories or message edits
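The new model_validator is the key piece of this change: at least one of chat_session_id or persona_id must be supplied. A self-contained sketch of the same pattern (ChatRequestSketch is an illustrative stand-in, not the class above; assumes Pydantic v2):

    from uuid import UUID

    from pydantic import BaseModel
    from pydantic import model_validator

    class ChatRequestSketch(BaseModel):
        chat_session_id: UUID | None = None
        persona_id: int | None = None
        message: str

        @model_validator(mode="after")
        def _require_session_or_persona(self) -> "ChatRequestSketch":
            if self.chat_session_id is None and self.persona_id is None:
                raise ValueError("Either chat_session_id or persona_id must be provided")
            return self

    # Passing only a persona_id is now valid; the endpoint creates the chat session.
    ChatRequestSketch(persona_id=1, message="hello")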

View File

@@ -358,7 +358,7 @@ def get_query_history_export_status(
# If task is None, then it's possible that the task has already finished processing.
# Therefore, we should then check if the export file has already been stored inside of the file-store.
# If that *also* doesn't exist, then we can return a 404.
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
report_name = construct_query_history_report_name(request_id)
has_file = file_store.has_file(
@@ -385,7 +385,7 @@ def download_query_history_csv(
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
report_name = construct_query_history_report_name(request_id)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
has_file = file_store.has_file(
file_id=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,

View File

@@ -53,7 +53,7 @@ def read_usage_report(
db_session: Session = Depends(get_session),
) -> Response:
try:
file = get_usage_report_data(db_session, report_name)
file = get_usage_report_data(report_name)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -67,7 +67,7 @@ def generate_chat_messages_report(
file_id = file_store.save_file(
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)
@@ -99,7 +99,7 @@ def generate_user_report(
file_id = file_store.save_file(
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)
@@ -112,7 +112,7 @@ def create_new_usage_report(
period: tuple[datetime, datetime] | None,
) -> UsageReportMetadata:
report_id = str(uuid.uuid4())
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
messages_file_id = generate_chat_messages_report(
db_session, file_store, report_id, period

View File

@@ -200,10 +200,10 @@ def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
store_ee_settings(final_enterprise_settings)
def _seed_logo(db_session: Session, logo_path: str | None) -> None:
def _seed_logo(logo_path: str | None) -> None:
if logo_path:
logger.notice("Uploading logo")
upload_logo(db_session=db_session, file=logo_path)
upload_logo(file=logo_path)
def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
@@ -245,7 +245,7 @@ def seed_db() -> None:
if seed_config.custom_tools is not None:
_seed_custom_tools(db_session, seed_config.custom_tools)
_seed_logo(db_session, seed_config.seeded_logo_path)
_seed_logo(seed_config.seeded_logo_path)
_seed_enterprise_settings(seed_config)
_seed_analytics_script(seed_config)

View File

@@ -203,6 +203,8 @@ def generate_simple_sql(
if state.kg_entity_temp_view_name is None:
raise ValueError("kg_entity_temp_view_name is not set")
sql_statement_display: str | None = None
## STEP 3 - articulate goals
stream_write_step_activities(writer, _KG_STEP_NR)
@@ -381,7 +383,18 @@ def generate_simple_sql(
raise e
logger.debug(f"A3 - sql_statement after correction: {sql_statement}")
# display sql statement with view names replaced by general view names
sql_statement_display = sql_statement.replace(
state.kg_doc_temp_view_name, "<your_allowed_docs_view_name>"
)
sql_statement_display = sql_statement_display.replace(
state.kg_rel_temp_view_name, "<your_relationship_view_name>"
)
sql_statement_display = sql_statement_display.replace(
state.kg_entity_temp_view_name, "<your_entity_view_name>"
)
logger.debug(f"A3 - sql_statement after correction: {sql_statement_display}")
# Get SQL for source documents
@@ -409,7 +422,20 @@ def generate_simple_sql(
"relationship_table", rel_temp_view
)
logger.debug(f"A3 source_documents_sql: {source_documents_sql}")
if source_documents_sql:
source_documents_sql_display = source_documents_sql.replace(
state.kg_doc_temp_view_name, "<your_allowed_docs_view_name>"
)
source_documents_sql_display = source_documents_sql_display.replace(
state.kg_rel_temp_view_name, "<your_relationship_view_name>"
)
source_documents_sql_display = source_documents_sql_display.replace(
state.kg_entity_temp_view_name, "<your_entity_view_name>"
)
else:
source_documents_sql_display = "(No source documents SQL generated)"
logger.debug(f"A3 source_documents_sql: {source_documents_sql_display}")
scalar_result = None
query_results = None
@@ -435,7 +461,13 @@ def generate_simple_sql(
rows = result.fetchall()
query_results = [dict(row._mapping) for row in rows]
except Exception as e:
# TODO: raise error on frontend
logger.error(f"Error executing SQL query: {e}")
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
raise e
@@ -459,8 +491,14 @@ def generate_simple_sql(
for source_document_result in query_source_document_results
]
except Exception as e:
# No stopping here, the individualized SQL query is not mandatory
# TODO: raise error on frontend
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
logger.error(f"Error executing Individualized SQL query: {e}")
else:
@@ -493,11 +531,11 @@ def generate_simple_sql(
if reasoning:
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning)
if main_sql_statement:
if sql_statement_display:
stream_write_step_answer_explicit(
writer,
step_nr=_KG_STEP_NR,
answer=f" \n Generated SQL: {main_sql_statement}",
answer=f" \n Generated SQL: {sql_statement_display}",
)
stream_close_step_answer(writer, _KG_STEP_NR)
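The display logic above boils down to substituting the per-request temporary view names with stable placeholders before the SQL is streamed back to the user. A hedged sketch of that substitution (the helper name and example mapping are illustrative):

    def sanitize_sql_for_display(sql: str, view_name_to_placeholder: dict[str, str]) -> str:
        # Hide internal temp view names so the displayed SQL stays readable and stable.
        for view_name, placeholder in view_name_to_placeholder.items():
            sql = sql.replace(view_name, placeholder)
        return sql

    print(sanitize_sql_for_display(
        "SELECT * FROM kg_docs_ab12 JOIN kg_rels_ab12 USING (doc_id)",
        {
            "kg_docs_ab12": "<your_allowed_docs_view_name>",
            "kg_rels_ab12": "<your_relationship_view_name>",
        },
    ))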

View File

@@ -24,13 +24,14 @@ from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatt
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -39,6 +40,7 @@ from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import PlainFormatter
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -92,7 +94,13 @@ def on_task_prerun(
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
) -> None:
pass
# Reset any per-task logging context so that prefixes (e.g. pruning_ctx)
# from a previous task executed in the same worker process do not leak
# into the next task's log messages. This fixes incorrect [CC Pair:/Index Attempt]
# prefixes observed when a pruning task finishes and an indexing task
# runs in the same process.
LoggerContextVars.reset()
def on_task_postrun(
@@ -145,8 +153,11 @@ def on_task_postrun(
r = get_redis_client(tenant_id=tenant_id)
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
# NOTE: we want to remove the `Redis*` classes, prefer to just have functions to
# do these things going forward. In short, things should generally be like the doc
# sync task rather than the others below
if task_id.startswith(DOCUMENT_SYNC_PREFIX):
r.srem(DOCUMENT_SYNC_TASKSET_KEY, task_id)
return
if task_id.startswith(RedisDocumentSet.PREFIX):
@@ -470,7 +481,8 @@ class TenantContextFilter(logging.Filter):
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id:
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:5]
# Match the 8 character tenant abbreviation used in OnyxLoggingAdapter
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:8]
record.name = f"[t:{tenant_id}]"
else:
record.name = ""
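The reset in on_task_prerun exists because Celery reuses worker processes, so context-local log prefixes survive between tasks unless explicitly cleared. A minimal illustration of the idea with a plain contextvars variable (log_prefix_ctx is a stand-in for the real LoggerContextVars):

    import contextvars

    log_prefix_ctx: contextvars.ContextVar[str] = contextvars.ContextVar("log_prefix", default="")

    def on_task_prerun_sketch() -> None:
        # Clear any prefix left over from the previous task run in this worker process
        log_prefix_ctx.set("")

    log_prefix_ctx.set("[CC Pair: 7/Index Attempt: 42] ")  # e.g. set by a previous pruning task
    on_task_prerun_sketch()
    assert log_prefix_ctx.get() == ""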

View File

@@ -231,10 +231,7 @@ class DynamicTenantScheduler(PersistentScheduler):
True if equivalent, False if not."""
current_tasks = set(name for name, _ in schedule1)
new_tasks = set(schedule2.keys())
if current_tasks != new_tasks:
return False
return True
return current_tasks == new_tasks
@beat_init.connect

View File

@@ -0,0 +1,102 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.docfetching")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.docfetching",
]
)

View File

@@ -12,7 +12,7 @@ from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -21,7 +21,7 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.indexing")
celery_app.config_from_object("onyx.background.celery.configs.docprocessing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@@ -60,7 +60,7 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME)
# rkuo: Transient errors keep happening in the indexing watchdog threads.
# "SSL connection has been closed unexpectedly"
@@ -108,6 +108,6 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.docprocessing",
]
)

View File

@@ -116,6 +116,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.docprocessing",
]
)

View File

@@ -9,6 +9,7 @@ from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.exceptions import WorkerShutdown
from celery.result import AsyncResult
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@@ -18,9 +19,7 @@ from redis.lock import Lock as RedisLock
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.tasks.indexing.utils import (
get_unfenced_index_attempt_ids,
)
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
@@ -29,13 +28,10 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.redis.redis_connector_credential_pair import (
RedisGlobalConnectorCredentialPair,
)
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
@@ -156,35 +152,63 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
r.delete(OnyxRedisConstants.ACTIVE_FENCES)
RedisGlobalConnectorCredentialPair.reset_all(r)
# NOTE: we want to remove the `Redis*` classes, prefer to just have functions
# This is the preferred way to do this going forward
reset_document_sync(r)
RedisDocumentSet.reset_all(r)
RedisUserGroup.reset_all(r)
RedisConnectorDelete.reset_all(r)
RedisConnectorPrune.reset_all(r)
RedisConnectorIndex.reset_all(r)
RedisConnectorStop.reset_all(r)
RedisConnectorPermissionSync.reset_all(r)
RedisConnectorExternalGroupSync.reset_all(r)
# mark orphaned index attempts as failed
# This uses database coordination instead of Redis fencing
with get_session_with_current_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
# Get potentially orphaned attempts (those with active status and task IDs)
potentially_orphaned_ids = IndexingCoordination.get_orphaned_index_attempt_ids(
db_session
)
for attempt_id in potentially_orphaned_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
# handle case where not started or docfetching is done but indexing is not
if (
not attempt
or not attempt.celery_task_id
or attempt.total_batches is not None
):
continue
failure_reason = (
f"Canceling leftover index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
logger.exception(
f"Marking attempt {attempt.id} as canceled due to validation error 2"
)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
# Check if the Celery task actually exists
try:
result: AsyncResult = AsyncResult(attempt.celery_task_id)
# If the task is not in PENDING state, it exists in Celery
if result.state != "PENDING":
continue
# Task is orphaned - mark as failed
failure_reason = (
f"Orphaned index attempt found on startup - Celery task not found: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id} "
f"celery_task_id={attempt.celery_task_id}"
)
logger.warning(failure_reason)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
except Exception:
# If we can't check the task status, be conservative and continue
logger.warning(
f"Could not verify Celery task status on startup for attempt {attempt.id}, "
f"task_id={attempt.celery_task_id}"
)
@worker_ready.connect
@@ -291,7 +315,7 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",

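The new startup check relies on the fact that Celery reports unknown task ids as PENDING, so a supposedly-running attempt whose task id comes back PENDING is treated as orphaned. A hedged sketch of that test (assumes a Celery app with a result backend is configured; the helper name is illustrative):

    from celery.result import AsyncResult

    def celery_task_appears_missing(celery_task_id: str) -> bool:
        # Unknown task ids show up as PENDING; for an attempt that should already
        # be running, PENDING therefore means the task was likely lost.
        return AsyncResult(celery_task_id).state == "PENDING"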
View File

@@ -26,7 +26,7 @@ def celery_get_unacked_length(r: Redis) -> int:
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
"""Gets the set of task id's matching the given queue in the unacked hash.
Unacked entries belonging to the indexing queue are "prefetched", so this gives
Unacked entries belonging to the indexing queues are "prefetched", so this gives
us crucial visibility as to what tasks are in that state.
"""
tasks: set[str] = set()

View File

@@ -1,3 +1,5 @@
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from pathlib import Path
@@ -8,10 +10,12 @@ import httpx
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.connector_runner import batched_doc_ids
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
@@ -22,12 +26,14 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
def document_batch_to_ids(
doc_batch: list[Document],
) -> set[str]:
return {doc.id for doc in doc_batch}
doc_batch: Iterator[list[Document]],
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {doc.id for doc in doc_list}
def extract_ids_from_runnable_connector(
@@ -46,33 +52,50 @@ def extract_ids_from_runnable_connector(
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
doc_batch_generator = None
doc_batch_id_generator = None
if isinstance(runnable_connector, LoadConnector):
doc_batch_generator = runnable_connector.load_from_state()
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.load_from_state()
)
elif isinstance(runnable_connector, PollConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.poll_source(start=start, end=end)
)
elif isinstance(runnable_connector, CheckpointedConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
checkpoint = runnable_connector.build_dummy_checkpoint()
checkpoint_generator = runnable_connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
doc_batch_id_generator = batched_doc_ids(
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
doc_batch_processing_func = document_batch_to_ids
# this function is called per batch for rate limiting
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
return doc_batch_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
)(lambda x: x)
for doc_batch_ids in doc_batch_id_generator:
if callback:
if callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))
if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))
return all_connector_doc_ids
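The refactor above streams sets of document IDs batch by batch instead of holding full Document batches, and throttles per batch when MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE is set. A rough self-contained sketch of that shape (the throttle stands in for rate_limit_builder and is not the Onyx implementation):

    import time
    from collections.abc import Iterator

    def batches_to_id_sets(doc_batches: Iterator[list[str]]) -> Iterator[set[str]]:
        # Yield one set of IDs per batch instead of materialising all documents.
        for batch in doc_batches:
            yield set(batch)

    def throttled(id_batches: Iterator[set[str]], max_batches_per_minute: int) -> Iterator[set[str]]:
        min_interval = 60.0 / max_batches_per_minute
        last_yield = 0.0
        for ids in id_batches:
            wait = min_interval - (time.monotonic() - last_yield)
            if wait > 0:
                time.sleep(wait)
            last_yield = time.monotonic()
            yield ids

    all_ids: set[str] = set()
    for ids in throttled(batches_to_id_sets(iter([["a", "b"], ["c"]])), max_batches_per_minute=120):
        all_ids.update(ids)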

View File

@@ -0,0 +1,22 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_DOCFETCHING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Docfetching worker configuration
worker_concurrency = CELERY_WORKER_DOCFETCHING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -1,5 +1,5 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_INDEXING_CONCURRENCY
from onyx.configs.app_configs import CELERY_WORKER_DOCPROCESSING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
@@ -24,6 +24,6 @@ task_acks_late = shared_config.task_acks_late
# which means a duplicate run might change the task state unexpectedly
# task_track_started = True
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
worker_concurrency = CELERY_WORKER_DOCPROCESSING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -100,24 +100,6 @@ beat_task_templates: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,

View File

@@ -40,9 +40,11 @@ from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
@@ -69,13 +71,21 @@ def revoke_tasks_blocking_deletion(
) -> None:
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
try:
index_payload = redis_connector_index.payload
if index_payload and index_payload.celery_task_id:
app.control.revoke(index_payload.celery_task_id)
recent_index_attempts = get_recent_attempts_for_cc_pair(
cc_pair_id=redis_connector.cc_pair_id,
search_settings_id=search_settings.id,
limit=1,
db_session=db_session,
)
if (
recent_index_attempts
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
and recent_index_attempts[0].celery_task_id
):
app.control.revoke(recent_index_attempts[0].celery_task_id)
task_logger.info(
f"Revoked indexing task {index_payload.celery_task_id}."
f"Revoked indexing task {recent_index_attempts[0].celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking indexing task")
@@ -183,12 +193,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
task_logger.info(
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
)
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
redis_connector_index.reset()
redis_connector.prune.reset()
redis_connector.permissions.reset()
redis_connector.external_group_sync.reset()
@@ -281,8 +286,16 @@ def try_generate_document_cc_pair_cleanup_tasks(
# do not proceed if connector indexing or connector pruning are running
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
recent_index_attempts = get_recent_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=search_settings.id,
limit=1,
db_session=db_session,
)
if (
recent_index_attempts
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
):
raise TaskDependencyError(
"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "

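With the Redis index fence gone, the deletion path now consults the most recent index attempt in Postgres and revokes its Celery task only if it is still running. A condensed sketch of that decision (attempt objects are assumed to expose .status and .celery_task_id as in this diff):

    from onyx.db.enums import IndexingStatus

    def revoke_if_indexing_in_progress(app, recent_index_attempts) -> None:
        if not recent_index_attempts:
            return
        attempt = recent_index_attempts[0]
        if attempt.status == IndexingStatus.IN_PROGRESS and attempt.celery_task_id:
            app.control.revoke(attempt.celery_task_id)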
View File

@@ -0,0 +1,638 @@
import multiprocessing
import os
import time
import traceback
from time import sleep
import sentry_sdk
from celery import Celery
from celery import shared_task
from celery import Task
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat
from onyx.background.celery.tasks.docprocessing.tasks import ConnectorIndexingLogBuilder
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback
from onyx.background.celery.tasks.models import DocProcessingContext
from onyx.background.celery.tasks.models import IndexingWatchdogTerminalStatus
from onyx.background.celery.tasks.models import SimpleJobResult
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.job_client import SimpleJobException
from onyx.background.indexing.run_docfetching import run_docfetching_entrypoint
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector import RedisConnector
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
def _verify_indexing_attempt(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
) -> None:
"""
Verify that the indexing attempt exists and is in the correct state.
"""
with get_session_with_current_tenant() as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise SimpleJobException(
f"docfetching_task - IndexAttempt not found: attempt_id={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
)
if attempt.connector_credential_pair_id != cc_pair_id:
raise SimpleJobException(
f"docfetching_task - CC pair mismatch: "
f"expected={cc_pair_id} actual={attempt.connector_credential_pair_id}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
if attempt.search_settings_id != search_settings_id:
raise SimpleJobException(
f"docfetching_task - Search settings mismatch: "
f"expected={search_settings_id} actual={attempt.search_settings_id}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
if attempt.status not in [
IndexingStatus.NOT_STARTED,
IndexingStatus.IN_PROGRESS,
]:
raise SimpleJobException(
f"docfetching_task - Invalid attempt status: "
f"attempt_id={index_attempt_id} status={attempt.status}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
# Check for cancellation
if IndexingCoordination.check_cancellation_requested(
db_session, index_attempt_id
):
raise SimpleJobException(
f"docfetching_task - Cancellation requested: attempt_id={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
)
logger.info(
f"docfetching_task - IndexAttempt verified: "
f"attempt_id={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
def docfetching_task(
app: Celery,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
is_ee: bool,
tenant_id: str,
) -> None:
"""
This function is run in a SimpleJob as a new process. It performs some validation,
but essentially it just calls run_docfetching_entrypoint.
NOTE: if an exception is raised out of this task, the primary worker will detect
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
This will cause the primary worker to abort the indexing attempt and clean up.
"""
# Start heartbeat for this indexing attempt
heartbeat_thread, stop_event = start_heartbeat(index_attempt_id)
try:
_docfetching_task(
app, index_attempt_id, cc_pair_id, search_settings_id, is_ee, tenant_id
)
finally:
stop_heartbeat(heartbeat_thread, stop_event) # Stop heartbeat before exiting
def _docfetching_task(
app: Celery,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
is_ee: bool,
tenant_id: str,
) -> None:
# Since connector_indexing_proxy_task spawns a new process using this function as
# the entrypoint, we init Sentry here.
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
logger.info(
f"Indexing spawned task starting: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
# TODO: remove all fences, cause all signals to be set in postgres
if redis_connector.delete.fenced:
raise SimpleJobException(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
)
if redis_connector.stop.fenced:
raise SimpleJobException(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
)
# Verify the indexing attempt exists and is valid
# This replaces the Redis fence payload waiting
_verify_indexing_attempt(index_attempt_id, cc_pair_id, search_settings_id)
try:
with get_session_with_current_tenant() as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise SimpleJobException(
f"Index attempt not found: index_attempt={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
raise SimpleJobException(
f"cc_pair not found: cc_pair={cc_pair_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
# define a callback class
callback = IndexingCallback(
redis_connector,
)
logger.info(
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# This is where the heavy/real work happens
run_docfetching_entrypoint(
app,
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
)
except ConnectorValidationError:
raise SimpleJobException(
f"Indexing task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}",
code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
)
except Exception as e:
logger.exception(
f"Indexing spawned task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# special bulletproofing ... truncate long exception messages
# for exception types that require more args, this will fail
# thus the try/except
try:
sanitized_e = type(e)(str(e)[:1024])
sanitized_e.__traceback__ = e.__traceback__
raise sanitized_e
except Exception:
raise e
logger.info(
f"Indexing spawned task finished: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
os._exit(0) # ensure process exits cleanly
def process_job_result(
job: SimpleJob,
connector_source: str | None,
index_attempt_id: int,
log_builder: ConnectorIndexingLogBuilder,
) -> SimpleJobResult:
result = SimpleJobResult()
result.connector_source = connector_source
if job.process:
result.exit_code = job.process.exitcode
if job.status != "error":
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
return result
ignore_exitcode = False
# In EKS, there is an edge case where successful tasks return exit
# code 1 in the cloud due to the set_spawn_method not sticking.
# Workaround: check that the total number of batches is set, since this only
# happens when docfetching completed successfully
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(db_session, index_attempt_id)
if index_attempt and index_attempt.total_batches is not None:
ignore_exitcode = True
if ignore_exitcode:
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
task_logger.warning(
log_builder.build(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...",
exit_code=str(result.exit_code),
)
)
else:
if result.exit_code is not None:
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
job_level_exception = job.exception()
result.exception_str = (
f"Docfetching returned exit code {result.exit_code} "
f"with exception: {job_level_exception}"
)
return result
@shared_task(
name=OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
bind=True,
acks_late=False,
track_started=True,
)
def docfetching_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str,
) -> None:
"""
This task is the entrypoint for the full indexing pipeline, which is composed of two tasks:
docfetching and docprocessing.
This task is spawned by "try_creating_indexing_task" which is called in the "check_for_indexing" task.
This task spawns a new process for a new scheduled index attempt. That
new process (which runs the docfetching_task function) does the following:
1) determines parameters of the indexing attempt (which connector indexing function to run,
start and end time, from prev checkpoint or not), then run that connector. Specifically,
connectors are responsible for reading data from an outside source and converting it to Onyx documents.
At the moment these two steps (reading external data and converting to an Onyx document)
are not parallelized in most connectors; that's a subject for future work.
Each document batch produced by step 1 is stored in the file store, and a docprocessing task is spawned
to process it. docprocessing involves the steps listed below.
2) upserts documents to postgres (index_doc_batch_prepare)
3) chunks each document (optionally adds context for contextual rag)
4) embeds chunks (embed_chunks_with_failure_handling) via a call to the model server
5) write chunks to vespa (write_chunks_to_vector_db_with_backoff)
6) update document and indexing metadata in postgres
7) pulls all document IDs from the source and compares those IDs to locally stored documents and deletes
all locally stored IDs missing from the most recently pulled document ID list
Some important notes:
Invariants:
- docfetching proxy tasks are spawned by check_for_indexing. The proxy then runs the docfetching_task wrapped in a watchdog.
The watchdog is responsible for monitoring the docfetching_task and marking the index attempt as failed
if it is not making progress.
- All docprocessing tasks are spawned by a docfetching task.
- all docfetching tasks, docprocessing tasks, and document batches in the file store are
associated with a specific index attempt.
- the index attempt status is the source of truth for what is currently happening with the index attempt.
It is coupled with the creation/running of docfetching and docprocessing tasks as much as possible.
How we deal with failures / partial indexing:
- non-checkpointed connectors / new runs in general => delete the old document batches from the file store and do the new run
- checkpointed connectors + resuming from checkpoint => reissue the old document batches and do a new run
Misc:
- most inter-process communication is handled in postgres; some is still in redis and we're trying to remove it
- the heartbeat spawned in docfetching and docprocessing is how check_for_indexing monitors liveness
- progress-based liveness check: if nothing is done in 3-6 hours, mark the attempt as failed
- TODO: task-level timeouts (i.e. a connector stuck in an infinite loop)
Comments below are from the old version and some may no longer be valid.
TODO(rkuo): refactor this so that there is a single return path where we canonically
log the result of running this function.
Some more Richard notes:
celery out of process task execution strategy is pool=prefork, but it uses fork,
and forking is inherently unstable.
To work around this, we use pool=threads and proxy our work to a spawned task.
acks_late must be set to False. Otherwise, celery's visibility timeout will
cause any task that runs longer than the timeout to be redispatched by the broker.
There appears to be no good workaround for this, so we need to handle redispatching
manually.
NOTE: we try/except all db access in this function because as a watchdog, this function
needs to be extremely stable.
"""
# TODO: remove dependence on Redis
start = time.monotonic()
result = SimpleJobResult()
ctx = DocProcessingContext(
tenant_id=tenant_id,
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
index_attempt_id=index_attempt_id,
)
log_builder = ConnectorIndexingLogBuilder(ctx)
task_logger.info(
log_builder.build(
"Indexing watchdog - starting",
mp_start_method=str(multiprocessing.get_start_method()),
)
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
task_logger.info(f"submitting docfetching_task with tenant_id={tenant_id}")
job = client.submit(
docfetching_task,
self.app,
index_attempt_id,
cc_pair_id,
search_settings_id,
global_version.is_ee_version(),
tenant_id,
)
if not job or not job.process:
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
return
# Ensure the process has moved out of the starting state
num_waits = 0
while True:
if num_waits > 15:
result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
job.release()
return
if job.process.is_alive() or job.process.exitcode is not None:
break
sleep(1)
num_waits += 1
task_logger.info(
log_builder.build(
"Indexing watchdog - spawn succeeded",
pid=str(job.process.pid),
)
)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session,
index_attempt_id=index_attempt_id,
eager_load_cc_pair=True,
)
if not index_attempt:
raise RuntimeError("Index attempt not found")
result.connector_source = (
index_attempt.connector_credential_pair.connector.source.value
)
while True:
sleep(5)
time.monotonic()
# if the job is done, clean up and break
if job.done():
try:
result = process_job_result(
job, result.connector_source, index_attempt_id, log_builder
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - spawned task exceptioned"
)
)
finally:
job.release()
break
# log the memory usage for tracking down memory leaks / connector-specific memory issues
pid = job.process.pid
if pid is not None:
# Only emit memory info once per minute (60 seconds)
current_time = time.monotonic()
if current_time - last_memory_emit_time >= 60.0:
emit_process_memory(
pid,
"indexing_worker",
{
"cc_pair_id": cc_pair_id,
"search_settings_id": search_settings_id,
"index_attempt_id": index_attempt_id,
},
)
last_memory_emit_time = current_time
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception looking up index attempt"
)
)
continue
except Exception as e:
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
if isinstance(e, ConnectorValidationError):
# No need to expose full stack trace for validation errors
result.exception_str = str(e)
else:
result.exception_str = traceback.format_exc()
# handle exit and reporting
elapsed = time.monotonic() - start
if result.exception_str is not None:
# print with exception
try:
with get_session_with_current_tenant() as db_session:
attempt = get_index_attempt(db_session, ctx.index_attempt_id)
# only mark failures if not already terminal,
# otherwise we're overwriting potential real stack traces
if attempt and not attempt.status.is_terminal():
failure_reason = (
f"Spawned task exceptioned: exit_code={result.exit_code}"
)
mark_attempt_failed(
ctx.index_attempt_id,
db_session,
failure_reason=failure_reason,
full_exception_trace=result.exception_str,
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
normalized_exception_str = "None"
if result.exception_str:
normalized_exception_str = result.exception_str.replace(
"\n", "\\n"
).replace('"', '\\"')
task_logger.warning(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=result.status.value,
exit_code=str(result.exit_code),
exception=f'"{normalized_exception_str}"',
elapsed=f"{elapsed:.2f}s",
)
)
raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")
# print without exception
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
try:
with get_session_with_current_tenant() as db_session:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to termination signal"
)
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as canceled"
)
)
job.cancel()
elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
except Exception:
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
else:
pass
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=str(result.status.value),
exit_code=str(result.exit_code),
elapsed=f"{elapsed:.2f}s",
)
)
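At its core the proxy above is a watchdog loop: spawn the work in a separate process, poll it every few seconds, and emit memory stats at most once a minute until the process finishes. A simplified sketch of that loop (the job object is assumed to expose done(), release(), and process like SimpleJob; the print is a stand-in for emit_process_memory):

    import time

    def watch_spawned_job(job, poll_interval: float = 5.0, memory_interval: float = 60.0) -> int | None:
        last_memory_emit = 0.0
        while True:
            time.sleep(poll_interval)
            if job.done():
                exit_code = job.process.exitcode
                job.release()
                return exit_code
            now = time.monotonic()
            if job.process.pid is not None and now - last_memory_emit >= memory_interval:
                print(f"watchdog: pid={job.process.pid} still running")  # stand-in for emit_process_memory
                last_memory_emit = now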

View File

@@ -0,0 +1,36 @@
import threading
from sqlalchemy import update
from onyx.configs.constants import INDEXING_WORKER_HEARTBEAT_INTERVAL
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import IndexAttempt
def start_heartbeat(index_attempt_id: int) -> tuple[threading.Thread, threading.Event]:
"""Start a heartbeat thread for the given index attempt"""
stop_event = threading.Event()
def heartbeat_loop() -> None:
while not stop_event.wait(INDEXING_WORKER_HEARTBEAT_INTERVAL):
try:
with get_session_with_current_tenant() as db_session:
db_session.execute(
update(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.values(heartbeat_counter=IndexAttempt.heartbeat_counter + 1)
)
db_session.commit()
except Exception:
# Silently continue if heartbeat fails
pass
thread = threading.Thread(target=heartbeat_loop, daemon=True)
thread.start()
return thread, stop_event
def stop_heartbeat(thread: threading.Thread, stop_event: threading.Event) -> None:
"""Stop the heartbeat thread"""
stop_event.set()
thread.join(timeout=5) # Wait up to 5 seconds for clean shutdown
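Usage mirrors what docfetching_task does above: start the heartbeat before the long-running work and always stop it in a finally block, so the monitor only sees the counter stop advancing when the work truly ends. A minimal sketch (the placeholder body stands in for the real work):

    from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
    from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat

    def run_with_heartbeat(index_attempt_id: int) -> None:
        heartbeat_thread, stop_event = start_heartbeat(index_attempt_id)
        try:
            ...  # long-running docfetching / docprocessing work goes here
        finally:
            stop_heartbeat(heartbeat_thread, stop_event)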

File diff suppressed because it is too large

View File

@@ -1,10 +1,8 @@
import time
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.exceptions import LockError
@@ -12,8 +10,6 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
@@ -21,27 +17,19 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import create_index_attempt
from onyx.db.index_attempt import delete_index_attempt
from onyx.db.index_attempt import get_all_index_attempts_by_status
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import setup_logger
@@ -50,54 +38,6 @@ logger = setup_logger()
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE = 5
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
class IndexingCallbackBase(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
@@ -123,10 +63,9 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
self.last_parent_check = time.monotonic()
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
return True
return False
# Check if the associated indexing attempt has been cancelled
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
return bool(self.redis_connector.stop.fenced)
def progress(self, tag: str, amount: int) -> None:
"""Amount isn't used yet."""
@@ -171,186 +110,28 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
raise
class IndexingCallback(IndexingCallbackBase):
# NOTE: we're in the process of removing all fences from indexing; this will
# eventually no longer be used. For now, it is used only for connector pausing.
class IndexingCallback(IndexingHeartbeatInterface):
def __init__(
self,
parent_pid: int,
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
redis_connector_index: RedisConnectorIndex,
):
super().__init__(parent_pid, redis_connector, redis_lock, redis_client)
self.redis_connector = redis_connector
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
def should_stop(self) -> bool:
# Check if the associated indexing attempt has been cancelled
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
return bool(self.redis_connector.stop.fenced)
# included to satisfy old interface
def progress(self, tag: str, amount: int) -> None:
self.redis_connector_index.set_active()
self.redis_connector_index.set_connector_active()
super().progress(tag, amount)
self.redis_client.incrby(
self.redis_connector_index.generator_progress_key, amount
)
pass
def validate_indexing_fence(
tenant_id: str,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# check to see if the fence/payload exists
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
if redis_connector_index.active():
return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
if found:
# the celery task exists in the redis queue
redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector_index.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
f"index_attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"fence={fence_key}"
)
if payload.index_attempt_id:
try:
mark_attempt_failed(
payload.index_attempt_id,
db_session,
"validate_indexing_fence - Canceling index attempt due to missing celery tasks: "
f"index_attempt={payload.index_attempt_id}",
)
except Exception:
logger.exception(
"validate_indexing_fence - Exception while marking index attempt as failed: "
f"index_attempt={payload.index_attempt_id}",
)
redis_connector_index.reset()
return
def validate_indexing_fences(
tenant_id: str,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
"""Validates all indexing fences for this tenant ... aka makes sure
indexing tasks sent to celery are still in flight.
"""
reserved_indexing_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
# Use replica for this because the worst thing that happens
# is that we don't run the validation on this pass
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
continue
with get_session_with_current_tenant() as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
reserved_indexing_tasks,
r_celery,
db_session,
)
lock_beat.reacquire()
return
# NOTE: The validate_indexing_fence and validate_indexing_fences functions have been removed
# as they are no longer needed with database-based coordination. The new validation is
# handled by validate_active_indexing_attempts in the main indexing tasks module.
def is_in_repeated_error_state(
@@ -414,10 +195,12 @@ def should_index(
)
# uncomment for debugging
# task_logger.info(f"_should_index: "
# f"cc_pair={cc_pair.id} "
# f"connector={cc_pair.connector_id} "
# f"refresh_freq={connector.refresh_freq}")
task_logger.info(
f"_should_index: "
f"cc_pair={cc_pair.id} "
f"connector={cc_pair.connector_id} "
f"refresh_freq={connector.refresh_freq}"
)
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
@@ -517,7 +300,7 @@ def should_index(
return True
def try_creating_indexing_task(
def try_creating_docfetching_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
@@ -531,10 +314,11 @@ def try_creating_indexing_task(
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
Now uses database-based coordination instead of Redis fencing.
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
@@ -547,61 +331,42 @@ def try_creating_indexing_task(
if not acquired:
return None
redis_connector_index: RedisConnectorIndex
index_attempt_id = None
try:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
# skip if already indexing
if redis_connector_index.fenced:
return None
# skip indexing if the cc_pair is deleting
if redis_connector.delete.fenced:
return None
# Basic status checks
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
redis_connector_index.generator_clear()
# Generate custom task ID for tracking
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
# set a basic fence to start
payload = RedisConnectorIndexPayload(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
redis_connector_index.set_active()
redis_connector_index.set_fence(payload)
# create the index attempt for tracking purposes
# code elsewhere checks for index attempts without an associated redis key
# and cleans them up
# therefore we must create the attempt and the task after the fence goes up
index_attempt_id = create_index_attempt(
cc_pair.id,
search_settings.id,
from_beginning=reindex,
# Try to create a new index attempt using database coordination
# This replaces the Redis fencing mechanism
index_attempt_id = IndexingCoordination.try_create_index_attempt(
db_session=db_session,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
celery_task_id=custom_task_id,
from_beginning=reindex,
)
custom_task_id = redis_connector_index.generate_generator_task_id()
if index_attempt_id is None:
# Another indexing attempt is already running
return None
# Determine which queue to use based on whether this is a user file
# TODO: at the moment the indexing pipeline is
# shared between user files and connectors
queue = (
OnyxCeleryQueues.USER_FILES_INDEXING
if cc_pair.is_user_file
else OnyxCeleryQueues.CONNECTOR_INDEXING
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
)
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
# Send the task to Celery
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
@@ -613,14 +378,18 @@ def try_creating_indexing_task(
priority=OnyxCeleryPriority.MEDIUM,
)
if not result:
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
# now fill out the fence with the rest of the data
redis_connector_index.set_active()
task_logger.info(
f"Created docfetching task: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id} "
f"attempt_id={index_attempt_id} "
f"celery_task_id={custom_task_id}"
)
return index_attempt_id
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
@@ -628,9 +397,10 @@ def try_creating_indexing_task(
f"search_settings={search_settings.id}"
)
# Clean up on failure
if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
mark_attempt_failed(index_attempt_id, db_session)
return None
finally:
if lock.owned():

File diff suppressed because it is too large.

View File

@@ -0,0 +1,110 @@
from enum import Enum
from pydantic import BaseModel
class DocProcessingContext(BaseModel):
tenant_id: str
cc_pair_id: int
search_settings_id: int
index_attempt_id: int
class IndexingWatchdogTerminalStatus(str, Enum):
"""The different statuses the watchdog can finish with.
TODO: create broader success/failure/abort categories
"""
UNDEFINED = "undefined"
SUCCEEDED = "succeeded"
SPAWN_FAILED = "spawn_failed" # connector spawn failed
SPAWN_NOT_ALIVE = (
"spawn_not_alive" # spawn succeeded but process did not come alive
)
BLOCKED_BY_DELETION = "blocked_by_deletion"
BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
FENCE_NOT_FOUND = "fence_not_found" # fence does not exist
FENCE_READINESS_TIMEOUT = (
"fence_readiness_timeout" # fence exists but wasn't ready within the timeout
)
FENCE_MISMATCH = "fence_mismatch" # task and fence metadata mismatch
TASK_ALREADY_RUNNING = "task_already_running" # task appears to be running already
INDEX_ATTEMPT_MISMATCH = (
"index_attempt_mismatch" # expected index attempt metadata not found in db
)
CONNECTOR_VALIDATION_ERROR = (
"connector_validation_error" # the connector validation failed
)
CONNECTOR_EXCEPTIONED = "connector_exceptioned" # the connector itself exceptioned
WATCHDOG_EXCEPTIONED = "watchdog_exceptioned" # the watchdog exceptioned
# the watchdog received a termination signal
TERMINATED_BY_SIGNAL = "terminated_by_signal"
# the watchdog terminated the task due to no activity
TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"
# NOTE: this may actually be the same as SIGKILL, but parsed differently by python
# consolidate once we know more
OUT_OF_MEMORY = "out_of_memory"
PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
@property
def code(self) -> int:
_ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
}
return _ENUM_TO_CODE[self]
@classmethod
def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
_CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
-9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
}
if code in _CODE_TO_ENUM:
return _CODE_TO_ENUM[code]
return IndexingWatchdogTerminalStatus.UNDEFINED
class SimpleJobResult:
"""The data we want to have when the watchdog finishes"""
def __init__(self) -> None:
self.status = IndexingWatchdogTerminalStatus.UNDEFINED
self.connector_source = None
self.exit_code = None
self.exception_str = None
status: IndexingWatchdogTerminalStatus
connector_source: str | None
exit_code: int | None
exception_str: str | None
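For reference, a short sketch of how the exit-code mapping in IndexingWatchdogTerminalStatus round-trips; the values come directly from the tables in code() and from_code() above:
status = IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR
assert status.code == 247
assert IndexingWatchdogTerminalStatus.from_code(247) is status
# Unmapped exit codes fall back to UNDEFINED instead of raising.
assert IndexingWatchdogTerminalStatus.from_code(1) is IndexingWatchdogTerminalStatus.UNDEFINED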

View File

@@ -147,7 +147,7 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
metrics = []
queue_mappings = {
"celery_queue_length": "celery",
"indexing_queue_length": "indexing",
"docprocessing_queue_length": "docprocessing",
"sync_queue_length": "sync",
"deletion_queue_length": "deletion",
"pruning_queue_length": "pruning",
@@ -882,7 +882,13 @@ def monitor_celery_queues_helper(
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery)
n_user_files_indexing = celery_get_queue_length(
OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
)
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
@@ -896,14 +902,20 @@ def monitor_celery_queues_helper(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
n_indexing_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
n_docfetching_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
n_docprocessing_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.DOCPROCESSING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"indexing_prefetched={len(n_indexing_prefetched)} "
f"docfetching={n_docfetching} "
f"docfetching_prefetched={len(n_docfetching_prefetched)} "
f"docprocessing={n_docprocessing} "
f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
f"user_files_indexing={n_user_files_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "

View File

@@ -22,7 +22,7 @@ from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.indexing.utils import IndexingCallbackBase
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallbackBase
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -47,7 +47,6 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.tag import delete_orphan_tags__no_commit
@@ -70,9 +69,9 @@ logger = setup_logger()
def _get_pruning_block_expiration() -> int:
"""
Compute the expiration time for the pruning block signal.
Base expiration is 3600 seconds (1 hour), multiplied by the beat multiplier only in MULTI_TENANT mode.
Base expiration is 60 seconds (1 minute), multiplied by the beat multiplier only in MULTI_TENANT mode.
"""
base_expiration = 3600 # seconds
base_expiration = 60 # seconds
if not MULTI_TENANT:
return base_expiration
@@ -145,10 +144,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
last_pruned = cc_pair.connector.time_created
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False
return True
return datetime.now(timezone.utc) >= next_prune
@shared_task(
@@ -280,6 +276,9 @@ def try_creating_prune_generator_task(
if not ALLOW_SIMULTANEOUS_PRUNING:
count = redis_connector.prune.get_active_task_count()
if count > 0:
logger.info(
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} no simultaneous pruning allowed"
)
return None
LOCK_TIMEOUT = 30
@@ -293,6 +292,9 @@ def try_creating_prune_generator_task(
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
logger.info(
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} lock not acquired"
)
return None
try:
@@ -464,7 +466,7 @@ def connector_pruning_generator_task(
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.cc_pair_id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
thread_local=False,
)
@@ -516,9 +518,6 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
search_settings = get_current_search_settings(db_session)
redis_connector.new_index(search_settings.id)
callback = PruneCallback(
0,
redis_connector,

View File

@@ -0,0 +1,178 @@
import time
from typing import cast
from uuid import uuid4
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.document import construct_document_id_select_by_needs_sync
from onyx.db.document import count_documents_by_needs_sync
from onyx.utils.logger import setup_logger
# Redis keys for document sync tracking
DOCUMENT_SYNC_PREFIX = "documentsync"
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
logger = setup_logger()
def is_document_sync_fenced(r: Redis) -> bool:
"""Check if document sync tasks are currently in progress."""
return bool(r.exists(DOCUMENT_SYNC_FENCE_KEY))
def get_document_sync_payload(r: Redis) -> int | None:
"""Get the initial number of tasks that were created."""
bytes_result = r.get(DOCUMENT_SYNC_FENCE_KEY)
if bytes_result is None:
return None
return int(cast(int, bytes_result))
def get_document_sync_remaining(r: Redis) -> int:
"""Get the number of tasks still pending completion."""
return cast(int, r.scard(DOCUMENT_SYNC_TASKSET_KEY))
def set_document_sync_fence(r: Redis, payload: int | None) -> None:
"""Set up the fence and register with active fences."""
if payload is None:
r.srem(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
r.delete(DOCUMENT_SYNC_FENCE_KEY)
return
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
def delete_document_sync_taskset(r: Redis) -> None:
"""Clear the document sync taskset."""
r.delete(DOCUMENT_SYNC_TASKSET_KEY)
def reset_document_sync(r: Redis) -> None:
"""Reset all document sync tracking data."""
r.srem(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
r.delete(DOCUMENT_SYNC_TASKSET_KEY)
r.delete(DOCUMENT_SYNC_FENCE_KEY)
def generate_document_sync_tasks(
r: Redis,
max_tasks: int,
celery_app: Celery,
db_session: Session,
lock: RedisLock,
tenant_id: str,
) -> tuple[int, int]:
"""Generate sync tasks for all documents that need syncing.
Args:
r: Redis client
max_tasks: Maximum number of tasks to generate
celery_app: Celery application instance
db_session: Database session
lock: Redis lock for coordination
tenant_id: Tenant identifier
Returns:
tuple[int, int]: (tasks_generated, total_docs_found)
"""
last_lock_time = time.monotonic()
num_tasks_sent = 0
num_docs = 0
# Get all documents that need syncing
stmt = construct_document_id_select_by_needs_sync()
for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
doc_id = cast(str, doc_id)
current_time = time.monotonic()
# Reacquire lock periodically to prevent timeout
if current_time - last_lock_time >= (CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4):
lock.reacquire()
last_lock_time = current_time
num_docs += 1
# Create a unique task ID
custom_task_id = f"{DOCUMENT_SYNC_PREFIX}_{uuid4()}"
# Add to the tracking taskset in Redis BEFORE creating the celery task
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
# Create the Celery task
celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
ignore_result=True,
)
num_tasks_sent += 1
if num_tasks_sent >= max_tasks:
break
return num_tasks_sent, num_docs
def try_generate_stale_document_sync_tasks(
celery_app: Celery,
max_tasks: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
) -> int | None:
# the fence is up, do nothing
if is_document_sync_fenced(r):
return None
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
logger.info("No stale documents found. Skipping sync tasks generation.")
return None
logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks in one batch."
)
logger.info("generate_document_sync_tasks starting for all documents.")
# Generate all tasks in one pass
result = generate_document_sync_tasks(
r, max_tasks, celery_app, db_session, lock_beat, tenant_id
)
if result is None:
return None
tasks_generated, total_docs = result
if tasks_generated >= max_tasks:
logger.info(
f"generate_document_sync_tasks reached the task generation limit: "
f"tasks_generated={tasks_generated} max_tasks={max_tasks}"
)
else:
logger.info(
f"generate_document_sync_tasks finished for all documents. "
f"tasks_generated={tasks_generated} total_docs_found={total_docs}"
)
set_document_sync_fence(r, tasks_generated)
return tasks_generated
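A hedged sketch of a call site for the helper above; celery_app, r, lock_beat, and tenant_id are assumed to come from the surrounding beat task (see check_for_vespa_sync_task in the next file), and VESPA_SYNC_MAX_TASKS from onyx.configs.app_configs:
with get_session_with_current_tenant() as db_session:
    tasks_generated = try_generate_stale_document_sync_tasks(
        celery_app=celery_app,
        max_tasks=VESPA_SYNC_MAX_TASKS,
        db_session=db_session,
        r=r,
        lock_beat=lock_beat,
        tenant_id=tenant_id,
    )
# None means the fence was already up or nothing needed syncing; otherwise the
# fence payload records how many tasks were queued, and the monitor resets the
# fence once get_document_sync_remaining(r) drops to zero.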

View File

@@ -20,14 +20,19 @@ from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocument
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_FENCE_KEY
from onyx.background.celery.tasks.vespa.document_sync import get_document_sync_payload
from onyx.background.celery.tasks.vespa.document_sync import get_document_sync_remaining
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
from onyx.background.celery.tasks.vespa.document_sync import (
try_generate_stale_document_sync_tasks,
)
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import count_documents_by_needs_sync
from onyx.db.document import get_document
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import delete_document_set
@@ -47,10 +52,6 @@ from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_credential_pair import (
RedisGlobalConnectorCredentialPair,
)
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
@@ -166,8 +167,11 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
continue
key_str = key_bytes.decode("utf-8")
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
monitor_connector_taskset(r)
# NOTE: we're removing the "Redis*" classes; prefer plain functions for these
# things going forward. In short, new code should generally look like the doc
# sync task rather than the others.
if key_str == DOCUMENT_SYNC_FENCE_KEY:
monitor_document_sync_taskset(r)
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
with get_session_with_current_tenant() as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
@@ -203,82 +207,6 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
return True
def try_generate_stale_document_sync_tasks(
celery_app: Celery,
max_tasks: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
) -> int | None:
# the fence is up, do nothing
redis_global_ccpair = RedisGlobalConnectorCredentialPair(r)
if redis_global_ccpair.fenced:
return None
redis_global_ccpair.delete_taskset()
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
return None
task_logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info(
"RedisConnector.generate_tasks starting by cc_pair. "
"Documents spanning multiple cc_pairs will only be synced once."
)
docs_to_skip: set[str] = set()
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
tasks_remaining = max_tasks
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
lock_beat.reacquire()
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
rc.set_skip_docs(docs_to_skip)
result = rc.generate_tasks(
tasks_remaining, celery_app, db_session, r, lock_beat, tenant_id
)
if result is None:
continue
if result[1] == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
)
total_tasks_generated += result[0]
tasks_remaining -= result[0]
if tasks_remaining <= 0:
break
if tasks_remaining <= 0:
task_logger.info(
f"RedisConnector.generate_tasks reached the task generation limit: "
f"total_tasks_generated={total_tasks_generated} max_tasks={max_tasks}"
)
else:
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
)
redis_global_ccpair.set_fence(total_tasks_generated)
return total_tasks_generated
def try_generate_document_set_sync_tasks(
celery_app: Celery,
document_set_id: int,
@@ -433,19 +361,18 @@ def try_generate_user_group_sync_tasks(
return tasks_generated
def monitor_connector_taskset(r: Redis) -> None:
redis_global_ccpair = RedisGlobalConnectorCredentialPair(r)
initial_count = redis_global_ccpair.payload
def monitor_document_sync_taskset(r: Redis) -> None:
initial_count = get_document_sync_payload(r)
if initial_count is None:
return
remaining = redis_global_ccpair.get_remaining()
remaining = get_document_sync_remaining(r)
task_logger.info(
f"Stale document sync progress: remaining={remaining} initial={initial_count}"
f"Document sync progress: remaining={remaining} initial={initial_count}"
)
if remaining == 0:
redis_global_ccpair.reset()
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
reset_document_sync(r)
task_logger.info(f"Successfully synced all documents. count={initial_count}")
def monitor_document_set_taskset(

View File

@@ -10,7 +10,7 @@ set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.indexing import celery_app
from onyx.background.celery.apps.docfetching import celery_app
return celery_app

View File

@@ -0,0 +1,18 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.docprocessing import celery_app
return celery_app
app = get_app()
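For illustration, how a stub like this is normally handed to the celery CLI; the module path below is an assumption based on where the other worker stubs appear to live, not something confirmed by this diff:
# Illustration only (module path is hypothetical):
#   celery -A onyx.background.celery.versioned_apps.docprocessing worker --loglevel=INFO
# The -A target resolves to the module-level `app` assigned by get_app() above.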

View File

@@ -33,7 +33,7 @@ def save_checkpoint(
"""Save a checkpoint for a given index attempt to the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_store.save_file(
content=BytesIO(checkpoint.model_dump_json().encode()),
display_name=checkpoint_pointer,
@@ -52,11 +52,11 @@ def save_checkpoint(
def load_checkpoint(
db_session: Session, index_attempt_id: int, connector: BaseConnector
index_attempt_id: int, connector: BaseConnector
) -> ConnectorCheckpoint:
"""Load a checkpoint for a given index attempt from the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
if isinstance(connector, CheckpointedConnector):
@@ -71,7 +71,7 @@ def get_latest_valid_checkpoint(
window_start: datetime,
window_end: datetime,
connector: BaseConnector,
) -> ConnectorCheckpoint:
) -> tuple[ConnectorCheckpoint, bool]:
"""Get the latest valid checkpoint for a given connector credential pair"""
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
@@ -83,7 +83,7 @@ def get_latest_valid_checkpoint(
# don't keep using checkpoints if we've had a bunch of failed attempts in a row
# where we make no progress. Only do this if we have had at least
# _NUM_RECENT_ATTEMPTS_TO_CONSIDER completed attempts.
if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
if len(checkpoint_candidates) >= _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
had_any_progress = False
for candidate in checkpoint_candidates:
if (
@@ -99,7 +99,7 @@ def get_latest_valid_checkpoint(
f"found for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
"from scratch."
)
return connector.build_dummy_checkpoint()
return connector.build_dummy_checkpoint(), False
# filter out any candidates that don't meet the criteria
checkpoint_candidates = [
@@ -140,11 +140,10 @@ def get_latest_valid_checkpoint(
logger.info(
f"No valid checkpoint found for cc_pair={cc_pair_id}. Starting from scratch."
)
return checkpoint
return checkpoint, False
try:
previous_checkpoint = load_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
connector=connector,
)
@@ -153,14 +152,14 @@ def get_latest_valid_checkpoint(
f"Failed to load checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}. Falling back to default checkpoint."
)
return checkpoint
return checkpoint, False
logger.info(
f"Using checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
f"{previous_checkpoint}"
)
return previous_checkpoint
return previous_checkpoint, True
def get_index_attempts_with_old_checkpoints(
@@ -201,7 +200,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
if not index_attempt.checkpoint_pointer:
return None
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_store.delete_file(index_attempt.checkpoint_pointer)
index_attempt.checkpoint_pointer = None
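A small sketch of the new return contract of get_latest_valid_checkpoint after this change; the variable names are illustrative, and the real call sites appear later in this diff:
checkpoint, resumed_from_checkpoint = get_latest_valid_checkpoint(
    db_session=db_session,
    cc_pair_id=cc_pair_id,
    search_settings_id=search_settings_id,
    window_start=window_start,
    window_end=window_end,
    connector=connector,
)
# resumed_from_checkpoint is True only when a checkpoint from a previous failed
# attempt was actually loaded; fresh/dummy checkpoints come back with False.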

View File

@@ -153,10 +153,9 @@ class SimpleJob:
if self._exception is None and self.queue and not self.queue.empty():
self._exception = self.queue.get() # Get exception from queue
if self._exception:
return self._exception
return f"Job with ID '{self.id}' did not report an exception."
return (
self._exception or f"Job with ID '{self.id}' did not report an exception."
)
class SimpleJobClient:

View File

@@ -1,3 +1,4 @@
import sys
import time
import traceback
from collections import defaultdict
@@ -5,7 +6,7 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from pydantic import BaseModel
from celery import Celery
from sqlalchemy.orm import Session
from onyx.access.access import source_should_fetch_permissions_during_indexing
@@ -18,18 +19,25 @@ from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import DocExtractionContext
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
from onyx.db.connector_credential_pair import update_connector_credential_pair
@@ -49,26 +57,29 @@ from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
logger = setup_logger()
logger = setup_logger(propagate=False)
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
@@ -146,6 +157,10 @@ def _get_connector_runner(
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
cleaned_batch = []
for doc in doc_batch:
if sys.getsizeof(doc) > MAX_FILE_SIZE_BYTES:
logger.warning(
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
)
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
@@ -180,25 +195,11 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
return cleaned_batch
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
class RunIndexingContext(BaseModel):
index_name: str
cc_pair_id: int
connector_id: int
credential_id: int
source: DocumentSource
earliest_index_time: float
from_beginning: bool
is_primary: bool
should_fetch_permissions_during_indexing: bool
search_settings_status: IndexModelStatus
def _check_connector_and_attempt_status(
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
db_session_temp: Session,
cc_pair_id: int,
search_settings_status: IndexModelStatus,
index_attempt_id: int,
) -> None:
"""
Checks the status of the connector credential pair and index attempt.
@@ -206,27 +207,38 @@ def _check_connector_and_attempt_status(
"""
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
ctx.cc_pair_id,
cc_pair_id,
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
raise RuntimeError(f"CC pair {cc_pair_id} not found in DB.")
if (
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and ctx.search_settings_status != IndexModelStatus.FUTURE
and search_settings_status != IndexModelStatus.FUTURE
) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
raise RuntimeError("Connector was disabled mid run")
raise ConnectorStopSignal(f"Connector {cc_pair_loop.status.value.lower()}")
index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_loop:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
if index_attempt_loop.status == IndexingStatus.CANCELED:
raise ConnectorStopSignal(f"Index attempt {index_attempt_id} was canceled")
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
error_str = ""
if index_attempt_loop.error_msg:
error_str = f" Original error: {index_attempt_loop.error_msg}"
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
f"Index Attempt is not running, status is {index_attempt_loop.status}.{error_str}"
)
if index_attempt_loop.celery_task_id is None:
raise RuntimeError(f"Index attempt {index_attempt_id} has no celery task id")
# TODO: delete from here if this ends up unused
def _check_failure_threshold(
total_failures: int,
document_count: int,
@@ -257,6 +269,9 @@ def _check_failure_threshold(
)
# NOTE: this is the old run_indexing function that the new decoupled approach
# is based on. Leaving this for comparison purposes, but if you see this comment
# has been here for >2 months, please delete this function.
def _run_indexing(
db_session: Session,
index_attempt_id: int,
@@ -271,7 +286,12 @@ def _run_indexing(
start_time = time.monotonic()  # just used for logging
with get_session_with_current_tenant() as db_session_temp:
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
index_attempt_start = get_index_attempt(
db_session_temp,
index_attempt_id,
eager_load_cc_pair=True,
eager_load_search_settings=True,
)
if not index_attempt_start:
raise ValueError(
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
@@ -292,7 +312,7 @@ def _run_indexing(
index_attempt_start.connector_credential_pair.last_successful_index_time
is not None
)
ctx = RunIndexingContext(
ctx = DocExtractionContext(
index_name=index_attempt_start.search_settings.index_name,
cc_pair_id=index_attempt_start.connector_credential_pair.id,
connector_id=db_connector.id,
@@ -317,6 +337,7 @@ def _run_indexing(
and (from_beginning or not has_successful_attempt)
),
search_settings_status=index_attempt_start.search_settings.status,
doc_extraction_complete_batch_num=None,
)
last_successful_index_poll_range_end = (
@@ -384,19 +405,6 @@ def _run_indexing(
httpx_client=HttpxPool.get("vespa"),
)
indexing_pipeline = build_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
@@ -416,7 +424,9 @@ def _run_indexing(
index_attempt: IndexAttempt | None = None
try:
with get_session_with_current_tenant() as db_session_temp:
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
index_attempt = get_index_attempt(
db_session_temp, index_attempt_id, eager_load_cc_pair=True
)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
@@ -439,7 +449,7 @@ def _run_indexing(
):
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
checkpoint = get_latest_valid_checkpoint(
checkpoint, _ = get_latest_valid_checkpoint(
db_session=db_session_temp,
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
@@ -496,7 +506,10 @@ def _run_indexing(
with get_session_with_current_tenant() as db_session_temp:
# will exception if the connector/index attempt is marked as paused/failed
_check_connector_and_attempt_status(
db_session_temp, ctx, index_attempt_id
db_session_temp,
ctx.cc_pair_id,
ctx.search_settings_status,
index_attempt_id,
)
# save record of any failures at the connector level
@@ -554,7 +567,16 @@ def _run_indexing(
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
index_pipeline_result = indexing_pipeline(
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
)
@@ -814,7 +836,8 @@ def _run_indexing(
)
def run_indexing_entrypoint(
def run_docfetching_entrypoint(
app: Celery,
index_attempt_id: int,
tenant_id: str,
connector_credential_pair_id: int,
@@ -828,11 +851,10 @@ def run_indexing_entrypoint(
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
token = INDEX_ATTEMPT_INFO_CONTEXTVAR.set(
(connector_credential_pair_id, index_attempt_id)
)
with get_session_with_current_tenant() as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
@@ -846,18 +868,519 @@ def run_indexing_entrypoint(
credential_id = attempt.connector_credential_pair.credential_id
logger.info(
f"Indexing starting{tenant_str}: "
f"Docfetching starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
with get_session_with_current_tenant() as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
connector_document_extraction(
app,
index_attempt_id,
attempt.connector_credential_pair_id,
attempt.search_settings_id,
tenant_id,
callback,
)
logger.info(
f"Indexing finished{tenant_str}: "
f"Docfetching finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
INDEX_ATTEMPT_INFO_CONTEXTVAR.reset(token)
def connector_document_extraction(
app: Celery,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""Extract documents from connector and queue them for indexing pipeline processing.
This is the first part of the split indexing process that runs the connector
and extracts documents, storing them in the filestore for later processing.
"""
start_time = time.monotonic()
logger.info(
f"Document extraction starting: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"tenant={tenant_id}"
)
# Get batch storage (transition to IN_PROGRESS is handled by run_indexing_entrypoint)
batch_storage = get_document_batch_storage(cc_pair_id, index_attempt_id)
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
memory_tracer.start()
index_attempt = None
last_batch_num = 0 # used to continue from checkpointing
# comes from _run_indexing
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session,
index_attempt_id,
eager_load_cc_pair=True,
eager_load_search_settings=True,
)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found")
if index_attempt.search_settings is None:
raise ValueError("Search settings must be set for indexing")
# Clear the indexing trigger if it was set, to prevent duplicate indexing attempts
if index_attempt.connector_credential_pair.indexing_trigger is not None:
logger.info(
"Clearing indexing trigger: "
f"cc_pair={index_attempt.connector_credential_pair.id} "
f"trigger={index_attempt.connector_credential_pair.indexing_trigger}"
)
mark_ccpair_with_indexing_trigger(
index_attempt.connector_credential_pair.id, None, db_session
)
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
from_beginning = index_attempt.from_beginning
has_successful_attempt = (
index_attempt.connector_credential_pair.last_successful_index_time
is not None
)
earliest_index_time = (
db_connector.indexing_start.timestamp()
if db_connector.indexing_start
else 0
)
should_fetch_permissions_during_indexing = (
index_attempt.connector_credential_pair.access_type == AccessType.SYNC
and source_should_fetch_permissions_during_indexing(db_connector.source)
and is_primary
# if we've already successfully indexed, let the doc_sync job
# take care of doc-level permissions
and (from_beginning or not has_successful_attempt)
)
# Set up time windows for polling
last_successful_index_poll_range_end = (
earliest_index_time
if from_beginning
else get_last_successful_attempt_poll_range_end(
cc_pair_id=cc_pair_id,
earliest_index=earliest_index_time,
search_settings=index_attempt.search_settings,
db_session=db_session,
)
)
if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
window_start = datetime.fromtimestamp(
last_successful_index_poll_range_end, tz=timezone.utc
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
else:
# don't go into "negative" time if we've never indexed before
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
most_recent_attempt = next(
iter(
get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
db_session=db_session,
limit=1,
)
),
None,
)
# if the last attempt failed, try and use the same window. This is necessary
# to ensure correctness with checkpointing. If we don't do this, things like
# new slack channels could be missed (since existing slack channels are
# cached as part of the checkpoint).
if (
most_recent_attempt
and most_recent_attempt.poll_range_end
and (
most_recent_attempt.status == IndexingStatus.FAILED
or most_recent_attempt.status == IndexingStatus.CANCELED
)
):
window_end = most_recent_attempt.poll_range_end
else:
window_end = datetime.now(tz=timezone.utc)
# set time range in db
index_attempt.poll_range_start = window_start
index_attempt.poll_range_end = window_end
db_session.commit()
# TODO: maybe memory tracer here
# Set up connector runner
connector_runner = _get_connector_runner(
db_session=db_session,
attempt=index_attempt,
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
include_permissions=should_fetch_permissions_during_indexing,
)
# don't use a checkpoint if we're explicitly indexing from
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling
# OR
# if the last attempt was successful
if index_attempt.from_beginning or (
most_recent_attempt and most_recent_attempt.status.is_successful()
):
logger.info(
f"Cleaning up all old batches for index attempt {index_attempt_id} before starting new run"
)
batch_storage.cleanup_all_batches()
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
logger.info(
f"Getting latest valid checkpoint for index attempt {index_attempt_id}"
)
checkpoint, resuming_from_checkpoint = get_latest_valid_checkpoint(
db_session=db_session,
cc_pair_id=cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
connector=connector_runner.connector,
)
# checkpoint resumption OR the connector already finished.
if (
isinstance(connector_runner.connector, CheckpointedConnector)
and resuming_from_checkpoint
) or (
most_recent_attempt
and most_recent_attempt.total_batches is not None
and not checkpoint.has_more
):
reissued_batch_count, completed_batches = reissue_old_batches(
batch_storage,
index_attempt_id,
cc_pair_id,
tenant_id,
app,
most_recent_attempt,
)
last_batch_num = reissued_batch_count + completed_batches
index_attempt.completed_batches = completed_batches
db_session.commit()
else:
logger.info(
f"Cleaning up all batches for index attempt {index_attempt_id} before starting new run"
)
# for non-checkpointed connectors, throw out batches from previous unsuccessful attempts
# because we'll be getting those documents again anyways.
batch_storage.cleanup_all_batches()
# Save initial checkpoint
save_checkpoint(
db_session=db_session,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
try:
batch_num = last_batch_num # starts at 0 if no last batch
total_doc_batches_queued = 0
total_failures = 0
document_count = 0
# Main extraction loop
while checkpoint.has_more:
logger.info(
f"Running '{db_connector.source.value}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
):
# Check if the connector was disabled mid-run and stop if so, unless it's the secondary
# index being built. We want to populate it even for paused connectors.
# Often paused connectors are sources that aren't updated frequently, but the
# contents still need to be initially pulled.
if callback and callback.should_stop():
raise ConnectorStopSignal("Connector stop signal detected")
# will exception if the connector/index attempt is marked as paused/failed
with get_session_with_current_tenant() as db_session_tmp:
_check_connector_and_attempt_status(
db_session_tmp,
cc_pair_id,
index_attempt.search_settings.status,
index_attempt_id,
)
# save record of any failures at the connector level
if failure is not None:
total_failures += 1
with get_session_with_current_tenant() as db_session:
create_index_attempt_error(
index_attempt_id,
cc_pair_id,
failure,
db_session,
)
_check_failure_threshold(
total_failures, document_count, batch_num, failure
)
# Save checkpoint if provided
if next_checkpoint:
checkpoint = next_checkpoint
# below is all document processing task, so if no batch we can just continue
if not document_batch:
continue
# Clean documents and create batch
doc_batch_cleaned = strip_null_characters(document_batch)
batch_description = []
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
doc_size = 0
for section in doc.sections:
if (
isinstance(section, TextSection)
and section.text is not None
):
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(
f"Document size: doc='{doc.to_short_descriptor()}' "
f"size={doc_size} "
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
)
logger.debug(f"Indexing batch of documents: {batch_description}")
memory_tracer.increment_and_maybe_trace()
# Store documents in storage
batch_storage.store_batch(batch_num, doc_batch_cleaned)
# Create processing task data
processing_batch_data = {
"index_attempt_id": index_attempt_id,
"cc_pair_id": cc_pair_id,
"tenant_id": tenant_id,
"batch_num": batch_num, # 0-indexed
}
# Queue document processing task
app.send_task(
OnyxCeleryTask.DOCPROCESSING_TASK,
kwargs=processing_batch_data,
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
)
batch_num += 1
total_doc_batches_queued += 1
logger.info(
f"Queued document processing batch: "
f"batch_num={batch_num} "
f"docs={len(doc_batch_cleaned)} "
f"attempt={index_attempt_id}"
)
# Check checkpoint size periodically
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
check_checkpoint_size(checkpoint)
# Save latest checkpoint
# NOTE: checkpointing is used to track which batches have
# been sent to the filestore, NOT which batches have been fully indexed
# as it used to be.
with get_session_with_current_tenant() as db_session:
save_checkpoint(
db_session=db_session,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
elapsed_time = time.monotonic() - start_time
logger.info(
f"Document extraction completed: "
f"attempt={index_attempt_id} "
f"batches_queued={total_doc_batches_queued} "
f"elapsed={elapsed_time:.2f}s"
)
# Set total batches in database to signal extraction completion.
# Used by check_for_indexing to determine if the index attempt is complete.
with get_session_with_current_tenant() as db_session:
IndexingCoordination.set_total_batches(
db_session=db_session,
index_attempt_id=index_attempt_id,
total_batches=batch_num,
)
except Exception as e:
logger.exception(
f"Document extraction failed: "
f"attempt={index_attempt_id} "
f"error={str(e)}"
)
# Do NOT clean up batches on failure; future runs will re-use those batches,
# and docfetching will continue from the saved checkpoint if one exists
if isinstance(e, ConnectorValidationError):
# On validation errors during indexing, we want to cancel the indexing attempt
# and mark the CCPair as invalid. This prevents the connector from being
# used in the future until the credentials are updated.
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to validation error."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
)
if is_primary:
if not index_attempt:
# should always be set by now
raise RuntimeError("Should never happen.")
VALIDATION_ERROR_THRESHOLD = 5
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
limit=VALIDATION_ERROR_THRESHOLD,
db_session=db_session_temp,
)
num_validation_errors = len(
[
index_attempt
for index_attempt in recent_index_attempts
if index_attempt.error_msg
and index_attempt.error_msg.startswith(
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
)
]
)
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
logger.warning(
f"Connector {db_connector.id} has {num_validation_errors} consecutive validation"
f" errors. Marking the CC Pair as invalid."
)
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=db_connector.id,
credential_id=db_credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise e
elif isinstance(e, ConnectorStopSignal):
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
else:
with get_session_with_current_tenant() as db_session_temp:
# don't overwrite attempts that are already failed/canceled for another reason
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
if index_attempt and index_attempt.status in [
IndexingStatus.CANCELED,
IndexingStatus.FAILED,
]:
logger.info(
f"Attempt {index_attempt_id} is already failed/canceled, skipping marking as failed."
)
raise e
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
raise e
finally:
memory_tracer.stop()
def reissue_old_batches(
batch_storage: DocumentBatchStorage,
index_attempt_id: int,
cc_pair_id: int,
tenant_id: str,
app: Celery,
most_recent_attempt: IndexAttempt | None,
) -> tuple[int, int]:
# When loading from a checkpoint, we need to start new docprocessing tasks
# tied to the new index attempt for any batches left over in the file store
old_batches = batch_storage.get_all_batches_for_cc_pair()
batch_storage.update_old_batches_to_new_index_attempt(old_batches)
for batch_id in old_batches:
logger.info(
f"Re-issuing docprocessing task for batch {batch_id} for index attempt {index_attempt_id}"
)
path_info = batch_storage.extract_path_info(batch_id)
if path_info is None:
logger.warning(
f"Could not extract path info from batch {batch_id}, skipping"
)
continue
if path_info.cc_pair_id != cc_pair_id:
raise RuntimeError(f"Batch {batch_id} is not for cc pair {cc_pair_id}")
app.send_task(
OnyxCeleryTask.DOCPROCESSING_TASK,
kwargs={
"index_attempt_id": index_attempt_id,
"cc_pair_id": cc_pair_id,
"tenant_id": tenant_id,
"batch_num": path_info.batch_num, # use same batch num as previously
},
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
)
recent_batches = most_recent_attempt.completed_batches if most_recent_attempt else 0
# resume from the batch num of the last attempt. This should be one more
# than the last batch created by docfetching regardless of whether the batch
# is still in the filestore waiting for processing or not.
last_batch_num = len(old_batches) + recent_batches
logger.info(
f"Starting from batch {last_batch_num} due to "
f"re-issued batches: {old_batches}, completed batches: {recent_batches}"
)
return len(old_batches), recent_batches
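To make the resumption arithmetic above concrete, here is a minimal, hypothetical sketch (not part of the diff) of how the starting batch number is derived from the previous attempt's completed batches plus any re-issued batches; the helper name and the numbers are invented for illustration.

```py
# Illustrative sketch only: new batches must be numbered after both the batches the
# previous attempt fully processed and the batches it left behind in the filestore.
def _resume_batch_num(completed_batches: int, reissued_batches: int) -> int:
    # Batches 0..completed_batches-1 finished last time; re-issued batches keep
    # their old numbers, so new docfetching batches start after both groups.
    return completed_batches + reissued_batches

# e.g. 7 batches finished last time and 3 were still waiting in the filestore:
assert _resume_batch_num(completed_batches=7, reissued_batches=3) == 10
```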

View File

@@ -725,9 +725,7 @@ def stream_chat_message_objects(
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
)
files = load_all_chat_files(history_msgs, new_msg_req.file_descriptors)
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
user_file_ids = new_msg_req.user_file_ids or []
@@ -1012,6 +1010,7 @@ def stream_chat_message_objects(
tools=tools,
db_session=db_session,
use_agentic_search=new_msg_req.use_agentic_search,
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
)
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(

View File

@@ -35,6 +35,9 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
) # 1 day
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
# Controls whether users can use User Knowledge (personal documents) in assistants
DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() == "true"
# Controls whether to allow admin query history reports with:
# 1. associated user emails
# 2. anonymized user emails
@@ -118,6 +121,8 @@ OAUTH_CLIENT_SECRET = (
os.environ.get("OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET"))
or ""
)
# OpenID Connect configuration URL for Okta Profile Tool and other OIDC integrations
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL") or ""
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
@@ -308,25 +313,40 @@ except ValueError:
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
)
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 3
CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT = 6
try:
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
env_value = os.environ.get("CELERY_WORKER_DOCPROCESSING_CONCURRENCY")
if not env_value:
env_value = os.environ.get("NUM_INDEXING_WORKERS")
if not env_value:
env_value = str(CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT)
CELERY_WORKER_INDEXING_CONCURRENCY = int(env_value)
env_value = str(CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT)
CELERY_WORKER_DOCPROCESSING_CONCURRENCY = int(env_value)
except ValueError:
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
CELERY_WORKER_DOCPROCESSING_CONCURRENCY = (
CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT
)
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT = 1
try:
env_value = os.environ.get("CELERY_WORKER_DOCFETCHING_CONCURRENCY")
if not env_value:
env_value = os.environ.get("NUM_DOCFETCHING_WORKERS")
if not env_value:
env_value = str(CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT)
CELERY_WORKER_DOCFETCHING_CONCURRENCY = int(env_value)
except ValueError:
CELERY_WORKER_DOCFETCHING_CONCURRENCY = (
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
)
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
)
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
VESPA_SYNC_MAX_TASKS = 1024
VESPA_SYNC_MAX_TASKS = 8192
DB_YIELD_PER_DEFAULT = 64
@@ -341,6 +361,12 @@ POLL_CONNECTOR_OFFSET = 30 # Minutes overlap between poll windows
# only very select connectors are enabled and admins cannot add other connector types
ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
# If set to true, curators can only access and edit assistants that they created
CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS = (
os.environ.get("CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS", "").lower()
== "true"
)
# Some calls to get information on expert users are quite costly especially with rate limiting
# Since experts are not used in the actual user experience, currently it is turned off
# for some connectors
@@ -450,6 +476,11 @@ GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
# Default size threshold for SharePoint files (20MB)
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -588,6 +619,17 @@ AVERAGE_SUMMARY_EMBEDDINGS = (
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
#####
# Tool Configs
#####
OKTA_PROFILE_TOOL_ENABLED = (
os.environ.get("OKTA_PROFILE_TOOL_ENABLED", "").lower() == "true"
)
# API token for SSWS auth to Okta Admin API. If set, Users API will be used to enrich profile.
OKTA_API_TOKEN = os.environ.get("OKTA_API_TOKEN") or ""
#####
# Miscellaneous
#####
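The worker-concurrency settings above all follow the same resolution pattern: try the new environment variable, fall back to the legacy variable, then to a hard-coded default, and tolerate non-integer values. A hedged sketch of that pattern follows; the helper name is invented for illustration and is not part of the codebase.

```py
import os


def _int_env_with_fallback(primary: str, legacy: str | None, default: int) -> int:
    """Resolve an integer setting from `primary`, then `legacy`, then `default`."""
    raw = os.environ.get(primary) or (os.environ.get(legacy) if legacy else None)
    try:
        return int(raw) if raw else default
    except ValueError:
        return default


# Mirrors the fallback chain used above (values are the defaults from this file):
# CELERY_WORKER_DOCPROCESSING_CONCURRENCY = _int_env_with_fallback(
#     "CELERY_WORKER_DOCPROCESSING_CONCURRENCY", "NUM_INDEXING_WORKERS", 6
# )
```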

View File

@@ -93,6 +93,9 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
BING_API_KEY = os.environ.get("BING_API_KEY") or None
EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)
# Enable in-house model for detecting connector-based filtering in queries
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)

View File

@@ -65,7 +65,8 @@ POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
@@ -121,6 +122,8 @@ CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60 # 3 hours (in seconds)
# hard termination should always fire first if the connector is hung
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900
# Heartbeat interval for indexing worker liveness detection
INDEXING_WORKER_HEARTBEAT_INTERVAL = 30 # seconds
# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
@@ -331,9 +334,12 @@ class OnyxCeleryQueues:
CSV_GENERATION = "csv_generation"
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
USER_FILES_INDEXING = "user_files_indexing"
# Document processing pipeline queue
DOCPROCESSING = "docprocessing"
CONNECTOR_DOC_FETCHING = "connector_doc_fetching"
# Monitoring queue
MONITORING = "monitoring"
@@ -464,7 +470,11 @@ class OnyxCeleryTask:
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
"connector_external_group_sync_generator_task"
)
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
# New split indexing tasks
CONNECTOR_DOC_FETCHING_TASK = "connector_doc_fetching_task"
DOCPROCESSING_TASK = "docprocessing_task"
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"

View File

@@ -34,7 +34,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
@@ -281,30 +280,28 @@ class BlobStorageConnector(LoadConnector, PollConnector):
# TODO: Refactor to avoid direct DB access in connector
# This will require broader refactoring across the codebase
with get_session_with_current_tenant() as db_session:
image_section, _ = store_image_and_create_section(
db_session=db_session,
image_data=downloaded_file,
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
display_name=file_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
image_section, _ = store_image_and_create_section(
image_data=downloaded_file,
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
display_name=file_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
sections=[image_section],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
doc_updated_at=last_modified,
metadata={},
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
sections=[image_section],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
doc_updated_at=last_modified,
metadata={},
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) == self.batch_size:
yield batch
batch = []
except Exception:
logger.exception(f"Error processing image {key}")
continue

View File

@@ -1,3 +1,17 @@
"""
# README (notes on Confluence pagination):
We've noticed that the `search/users` and `users/memberof` endpoints for Confluence Cloud use offset-based pagination as
opposed to cursor-based. We also know that page-retrieval uses cursor-based pagination.
Our default pagination strategy right now for cloud is to assume cursor-based.
However, if you notice that a cloud API is not being properly paginated (i.e., if the `_links.next` is not appearing in the
returned payload), then you can force offset-based pagination.
# TODO (@raunakab)
We haven't explored all of the cloud APIs' pagination strategies. @raunakab take time to go through this and figure them out.
"""
import json
import time
from collections.abc import Callable
@@ -46,15 +60,13 @@ _REPLACEMENT_EXPANSIONS = "body.view.value"
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
_DEFAULT_PAGINATION_LIMIT = 1000
class ConfluenceRateLimitError(Exception):
pass
_DEFAULT_PAGINATION_LIMIT = 1000
class OnyxConfluence:
"""
This is a custom Confluence class that:
@@ -463,6 +475,7 @@ class OnyxConfluence:
limit: int | None = None,
# Called with the next url to use to get the next page
next_page_callback: Callable[[str], None] | None = None,
force_offset_pagination: bool = False,
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
@@ -548,14 +561,32 @@ class OnyxConfluence:
)
raise e
# yield the results individually
# Yield the results individually.
results = cast(list[dict[str, Any]], next_response.get("results", []))
# make sure we don't update the start by more than the amount
# Note 1:
# Make sure we don't update the start by more than the amount
# of results we were able to retrieve. The Confluence API has a
# weird behavior where if you pass in a limit that is too large for
# the configured server, it will artificially limit the amount of
# results returned BUT will not apply this to the start parameter.
# This will cause us to miss results.
#
# Note 2:
# We specifically perform manual yielding (i.e., `for x in xs: yield x`) as opposed to using a `yield from xs`
# because we *have to call the `next_page_callback`* prior to yielding the last element!
#
# If we did:
#
# ```py
# yield from results
# if next_page_callback:
# next_page_callback(url_suffix)
# ```
#
# then the logic would fail since the iterator would finish (and the calling scope would exit out of its driving
# loop) prior to the callback being called.
old_url_suffix = url_suffix
updated_start = get_start_param_from_url(old_url_suffix)
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
@@ -571,6 +602,12 @@ class OnyxConfluence:
)
# notify the caller of the new url
next_page_callback(url_suffix)
elif force_offset_pagination and i == len(results) - 1:
url_suffix = update_param_in_path(
old_url_suffix, "start", str(updated_start)
)
yield result
# we've observed that Confluence sometimes returns a next link despite giving
@@ -668,7 +705,9 @@ class OnyxConfluence:
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
for user_result in self._paginate_url(url, limit):
for user_result in self._paginate_url(
url, limit, force_offset_pagination=True
):
# Example response:
# {
# 'user': {
@@ -758,7 +797,7 @@ class OnyxConfluence:
user_query = f"{user_field}={quote(user_value)}"
url = f"rest/api/user/memberof?{user_query}"
yield from self._paginate_url(url, limit)
yield from self._paginate_url(url, limit, force_offset_pagination=True)
def paginated_groups_retrieval(
self,
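A minimal standalone sketch of the manual-yield pattern described in Note 2 above: the page callback has to fire before the final element is yielded, otherwise the caller's loop can exit before the callback ever runs. Names here are illustrative and are not the connector's actual API.

```py
from collections.abc import Callable, Iterator
from typing import Any


def yield_page(
    results: list[dict[str, Any]],
    next_url: str,
    next_page_callback: Callable[[str], None] | None,
) -> Iterator[dict[str, Any]]:
    for i, result in enumerate(results):
        if i == len(results) - 1 and next_page_callback:
            # Must run before the last yield: `yield from results` followed by the
            # callback would let the consumer break out before we ever got here.
            next_page_callback(next_url)
        yield result
```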

View File

@@ -23,7 +23,6 @@ from onyx.configs.app_configs import (
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
@@ -224,19 +223,17 @@ def _process_image_attachment(
"""Process an image attachment by saving it without generating a summary."""
try:
# Use the standardized image storage and section creation
with get_session_with_current_tenant() as db_session:
section, file_name = store_image_and_create_section(
db_session=db_session,
image_data=raw_bytes,
file_id=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
section, file_name = store_image_and_create_section(
image_data=raw_bytes,
file_id=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
except Exception as e:
msg = f"Image storage failed for {attachment['title']}: {e}"
logger.error(msg, exc_info=e)

View File

@@ -25,9 +25,32 @@ TimeRange = tuple[datetime, datetime]
CT = TypeVar("CT", bound=ConnectorCheckpoint)
def batched_doc_ids(
checkpoint_connector_generator: CheckpointOutput[CT],
batch_size: int,
) -> Generator[set[str], None, None]:
batch: set[str] = set()
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None:
batch.add(document.id)
elif (
failure and failure.failed_document and failure.failed_document.document_id
):
batch.add(failure.failed_document.document_id)
if len(batch) >= batch_size:
yield batch
batch = set()
if len(batch) > 0:
yield batch
class CheckpointOutputWrapper(Generic[CT]):
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format.
Wraps a CheckpointOutput generator to give things back in a more digestible format,
specifically for Document outputs.
The connector format is easier for the connector implementor (e.g. it enforces exactly
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
formats.
@@ -131,7 +154,7 @@ class ConnectorRunner(Generic[CT]):
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None:
if document is not None and isinstance(document, Document):
self.doc_batch.append(document)
if failure is not None:

View File

@@ -5,8 +5,6 @@ from pathlib import Path
from typing import Any
from typing import IO
from sqlalchemy.orm import Session
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
@@ -18,7 +16,6 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
@@ -27,12 +24,12 @@ from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _create_image_section(
image_data: bytes,
db_session: Session,
parent_file_name: str,
display_name: str,
link: str | None = None,
@@ -58,7 +55,6 @@ def _create_image_section(
# Store the image and create a section
try:
section, stored_file_name = store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_id=file_id,
display_name=display_name,
@@ -77,7 +73,7 @@ def _process_file(
file: IO[Any],
metadata: dict[str, Any] | None,
pdf_pass: str | None,
db_session: Session,
file_type: str | None,
) -> list[Document]:
"""
Process a file and return a list of Documents.
@@ -125,7 +121,6 @@ def _process_file(
try:
section, _ = _create_image_section(
image_data=image_data,
db_session=db_session,
parent_file_name=file_id,
display_name=title,
)
@@ -155,6 +150,7 @@ def _process_file(
file=file,
file_name=file_name,
pdf_pass=pdf_pass,
content_type=file_type,
)
# Each file may have file-specific ONYX_METADATA https://docs.onyx.app/connectors/file
@@ -196,7 +192,6 @@ def _process_file(
try:
image_section, stored_file_name = _create_image_section(
image_data=img_data,
db_session=db_session,
parent_file_name=file_id,
display_name=f"{title} - image {idx}",
idx=idx,
@@ -230,18 +225,25 @@ class LocalFileConnector(LoadConnector):
"""
Connector that reads files from Postgres and yields Documents, including
embedded image extraction without summarization.
file_locations are S3/Filestore UUIDs
file_names are the names of the files
"""
# Note: file_names is part of the connector config going forward, but its absence should not break
# backwards compatibility: if the add_file_names migration has not been run, old file connector configs
# will not have file_names. file_names is only used for display purposes in the UI; file_locations is
# used as the fallback.
def __init__(
self,
file_locations: list[Path | str],
zip_metadata: dict[str, Any],
file_names: list[str] | None = None,
zip_metadata: dict[str, Any] | None = None,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.file_locations = [str(loc) for loc in file_locations]
self.batch_size = batch_size
self.pdf_pass: str | None = None
self.zip_metadata = zip_metadata
self.zip_metadata = zip_metadata or {}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.pdf_pass = credentials.get("pdf_password")
@@ -260,41 +262,40 @@ class LocalFileConnector(LoadConnector):
"""
documents: list[Document] = []
with get_session_with_current_tenant() as db_session:
for file_id in self.file_locations:
file_store = get_default_file_store(db_session)
file_record = file_store.read_file_record(file_id=file_id)
if not file_record:
# typically an unsupported extension
logger.warning(
f"No file record found for '{file_id}' in PG; skipping."
)
continue
for file_id in self.file_locations:
file_store = get_default_file_store()
file_record = file_store.read_file_record(file_id=file_id)
if not file_record:
# typically an unsupported extension
logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
continue
metadata = self._get_file_metadata(file_id)
file_io = file_store.read_file(file_id=file_id, mode="b")
new_docs = _process_file(
file_id=file_id,
file_name=file_record.display_name,
file=file_io,
metadata=metadata,
pdf_pass=self.pdf_pass,
db_session=db_session,
)
documents.extend(new_docs)
metadata = self._get_file_metadata(file_record.display_name)
file_io = file_store.read_file(file_id=file_id, mode="b")
new_docs = _process_file(
file_id=file_id,
file_name=file_record.display_name,
file=file_io,
metadata=metadata,
pdf_pass=self.pdf_pass,
file_type=file_record.file_type,
)
documents.extend(new_docs)
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
yield documents
if __name__ == "__main__":
connector = LocalFileConnector(
file_locations=[os.environ["TEST_FILE"]], zip_metadata={}
file_locations=[os.environ["TEST_FILE"]],
file_names=[os.environ["TEST_FILE"]],
zip_metadata={},
)
connector.load_credentials({"pdf_password": os.environ.get("PDF_PASSWORD")})
doc_batches = connector.load_from_state()

View File

@@ -1,5 +1,4 @@
import copy
import time
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
@@ -17,17 +16,22 @@ from github.Issue import Issue
from github.NamedUser import NamedUser
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from github.Requester import Requester
from pydantic import BaseModel
from typing_extensions import override
from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.constants import DocumentSource
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.github.models import SerializedRepository
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
from onyx.connectors.github.utils import deserialize_repository
from onyx.connectors.github.utils import get_external_access_permission
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import ConnectorFailure
@@ -46,17 +50,7 @@ CURSOR_LOG_FREQUENCY = 50
_MAX_NUM_RATE_LIMIT_RETRIES = 5
ONE_DAY = timedelta(days=1)
def _sleep_after_rate_limit_exception(github_client: Github) -> None:
sleep_time = github_client.get_rate_limit().core.reset.replace(
tzinfo=timezone.utc
) - datetime.now(tz=timezone.utc)
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
time.sleep(sleep_time.seconds)
SLIM_BATCH_SIZE = 100
# Cases
# X (from start) standard run, no fallback to cursor-based pagination
# X (from start) standard run errors, fallback to cursor-based pagination
@@ -72,6 +66,10 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
# checkpoint progress (no infinite loop)
class DocMetadata(BaseModel):
repo: str
def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str:
if "_PaginatedList__nextUrl" in pag_list.__dict__:
return "_PaginatedList__nextUrl"
@@ -190,7 +188,7 @@ def _get_batch_rate_limited(
getattr(obj, "raw_data")
yield from objs
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
sleep_after_rate_limit_exception(github_client)
yield from _get_batch_rate_limited(
git_objs,
page_num,
@@ -232,12 +230,17 @@ def _get_userinfo(user: NamedUser) -> dict[str, str]:
}
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
def _convert_pr_to_document(
pull_request: PullRequest, repo_external_access: ExternalAccess | None
) -> Document:
repo_name = pull_request.base.repo.full_name if pull_request.base else ""
doc_metadata = DocMetadata(repo=repo_name)
return Document(
id=pull_request.html_url,
sections=[
TextSection(link=pull_request.html_url, text=pull_request.body or "")
],
external_access=repo_external_access,
source=DocumentSource.GITHUB,
semantic_identifier=f"{pull_request.number}: {pull_request.title}",
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -248,6 +251,8 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
if pull_request.updated_at
else None
),
# this metadata is used in perm sync
doc_metadata=doc_metadata.model_dump(),
metadata={
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
for k, v in {
@@ -301,14 +306,21 @@ def _fetch_issue_comments(issue: Issue) -> str:
return "\nComment: ".join(comment.body for comment in comments)
def _convert_issue_to_document(issue: Issue) -> Document:
def _convert_issue_to_document(
issue: Issue, repo_external_access: ExternalAccess | None
) -> Document:
repo_name = issue.repository.full_name if issue.repository else ""
doc_metadata = DocMetadata(repo=repo_name)
return Document(
id=issue.html_url,
sections=[TextSection(link=issue.html_url, text=issue.body or "")],
source=DocumentSource.GITHUB,
external_access=repo_external_access,
semantic_identifier=f"{issue.number}: {issue.title}",
# updated_at is UTC time but is timezone unaware
doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
# this metadata is used in perm sync
doc_metadata=doc_metadata.model_dump(),
metadata={
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
for k, v in {
@@ -343,18 +355,6 @@ def _convert_issue_to_document(issue: Issue) -> Document:
)
class SerializedRepository(BaseModel):
# id is part of the raw_data as well, just pulled out for convenience
id: int
headers: dict[str, str | int]
raw_data: dict[str, Any]
def to_Repository(self, requester: Requester) -> Repository.Repository:
return Repository.Repository(
requester, self.headers, self.raw_data, completed=True
)
class GithubConnectorStage(Enum):
START = "start"
PRS = "prs"
@@ -394,7 +394,7 @@ def make_cursor_url_callback(
return cursor_url_callback
class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint]):
def __init__(
self,
repo_owner: str,
@@ -423,7 +423,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
)
return None
def _get_github_repo(
def get_github_repo(
self, github_client: Github, attempt_num: int = 0
) -> Repository.Repository:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
@@ -434,10 +434,10 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
try:
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repo(github_client, attempt_num + 1)
sleep_after_rate_limit_exception(github_client)
return self.get_github_repo(github_client, attempt_num + 1)
def _get_github_repos(
def get_github_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
"""Get specific repositories based on comma-separated repo_name string."""
@@ -465,10 +465,10 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
return repos
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repos(github_client, attempt_num + 1)
sleep_after_rate_limit_exception(github_client)
return self.get_github_repos(github_client, attempt_num + 1)
def _get_all_repos(
def get_all_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
@@ -487,8 +487,8 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
user = github_client.get_user(self.repo_owner)
return list(user.get_repos())
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_all_repos(github_client, attempt_num + 1)
sleep_after_rate_limit_exception(github_client)
return self.get_all_repos(github_client, attempt_num + 1)
def _pull_requests_func(
self, repo: Repository.Repository
@@ -509,6 +509,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
checkpoint: GithubConnectorCheckpoint,
start: datetime | None = None,
end: datetime | None = None,
include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
@@ -521,13 +522,13 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
repos = self.get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
repos = [self.get_github_repo(self.github_client)]
else:
# All repositories
repos = self._get_all_repos(self.github_client)
repos = self.get_all_repos(self.github_client)
if not repos:
checkpoint.has_more = False
return checkpoint
@@ -547,28 +548,15 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
if checkpoint.cached_repo is None:
raise ValueError("No repo saved in checkpoint")
# Try to access the requester - different PyGithub versions may use different attribute names
try:
# Try direct access to a known attribute name first
if hasattr(self.github_client, "_requester"):
requester = self.github_client._requester
elif hasattr(self.github_client, "_Github__requester"):
requester = self.github_client._Github__requester
else:
# If we can't find the requester attribute, we need to fall back to recreating the repo
raise AttributeError("Could not find requester attribute")
repo = checkpoint.cached_repo.to_Repository(requester)
except Exception as e:
# If all else fails, re-fetch the repo directly
logger.warning(
f"Failed to deserialize repository: {e}. Attempting to re-fetch."
)
repo_id = checkpoint.cached_repo.id
repo = self.github_client.get_repo(repo_id)
# Deserialize the repository from the checkpoint
repo = deserialize_repository(checkpoint.cached_repo, self.github_client)
cursor_url_callback = make_cursor_url_callback(checkpoint)
repo_external_access: ExternalAccess | None = None
if include_permissions:
repo_external_access = get_external_access_permission(
repo, self.github_client
)
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
logger.info(f"Fetching PRs for repo: {repo.name}")
@@ -603,7 +591,9 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
):
continue
try:
yield _convert_pr_to_document(cast(PullRequest, pr))
yield _convert_pr_to_document(
cast(PullRequest, pr), repo_external_access
)
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
@@ -653,6 +643,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
self.github_client,
)
)
logger.info(f"Fetched {len(issue_batch)} issues for repo: {repo.name}")
checkpoint.curr_page += 1
done_with_issues = False
num_issues = 0
@@ -678,7 +669,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
continue
try:
yield _convert_issue_to_document(issue)
yield _convert_issue_to_document(issue, repo_external_access)
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.exception(error_msg)
@@ -715,12 +706,16 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.reset()
logger.info(f"{len(checkpoint.cached_repo_ids)} repos remaining")
if checkpoint.cached_repo_ids:
logger.info(
f"{len(checkpoint.cached_repo_ids)} repos remaining (IDs: {checkpoint.cached_repo_ids})"
)
else:
logger.info("No more repos remaining")
return checkpoint
@override
def load_from_checkpoint(
def _load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
@@ -741,7 +736,32 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
adjusted_start_datetime = epoch
return self._fetch_from_github(
checkpoint, start=adjusted_start_datetime, end=end_datetime
checkpoint,
start=adjusted_start_datetime,
end=end_datetime,
include_permissions=include_permissions,
)
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=False
)
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=True
)
def validate_connector_settings(self) -> None:
@@ -775,6 +795,9 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{repo_name}"
)
logger.info(
f"Successfully accessed repository: {self.repo_owner}/{repo_name}"
)
test_repo.get_contents("")
valid_repos = True
# If at least one repo is valid, we can proceed
@@ -882,7 +905,6 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
if __name__ == "__main__":
import os
from onyx.connectors.connector_runner import ConnectorRunner
# Initialize the connector
connector = GithubConnector(
@@ -893,6 +915,12 @@ if __name__ == "__main__":
{"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
)
if connector.github_client:
get_external_access_permission(
connector.get_github_repos(connector.github_client).pop(),
connector.github_client,
)
# Create a time range from epoch to now
end_time = datetime.now(timezone.utc)
start_time = datetime.fromtimestamp(0, tz=timezone.utc)
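The refactor above keeps a single internal `_load_from_checkpoint` and exposes two thin entry points that differ only in `include_permissions`. A generic sketch of that shape, with placeholder types, in case it is easier to see outside the diff:

```py
from collections.abc import Iterator


class PermAwareSource:
    """Generic sketch: one internal loader, two public wrappers."""

    def _load(self, start: float, end: float, include_permissions: bool) -> Iterator[str]:
        # a real implementation would yield documents, optionally enriched with ACLs
        yield f"docs from {start} to {end} (permissions={include_permissions})"

    def load(self, start: float, end: float) -> Iterator[str]:
        return self._load(start, end, include_permissions=False)

    def load_with_perm_sync(self, start: float, end: float) -> Iterator[str]:
        return self._load(start, end, include_permissions=True)
```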

View File

@@ -0,0 +1,17 @@
from typing import Any
from github import Repository
from github.Requester import Requester
from pydantic import BaseModel
class SerializedRepository(BaseModel):
# id is part of the raw_data as well, just pulled out for convenience
id: int
headers: dict[str, str | int]
raw_data: dict[str, Any]
def to_Repository(self, requester: Requester) -> Repository.Repository:
return Repository.Repository(
requester, self.headers, self.raw_data, completed=True
)

View File

@@ -0,0 +1,25 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from github import Github
from onyx.utils.logger import setup_logger
logger = setup_logger()
def sleep_after_rate_limit_exception(github_client: Github) -> None:
"""
Sleep until the GitHub rate limit resets.
Args:
github_client: The GitHub client that hit the rate limit
"""
sleep_time = github_client.get_rate_limit().core.reset.replace(
tzinfo=timezone.utc
) - datetime.now(tz=timezone.utc)
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
time.sleep(sleep_time.total_seconds())
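A hedged usage sketch for the helper above, roughly mirroring how the connector code retries after a rate-limit hit; the retry wrapper itself is invented for illustration.

```py
from collections.abc import Callable
from typing import TypeVar

from github import Github, RateLimitExceededException

T = TypeVar("T")


def fetch_with_rate_limit_retry(
    github_client: Github, fetch: Callable[[], T], max_retries: int = 5
) -> T:
    """Call `fetch()`, sleeping out GitHub rate limits between retries."""
    for _ in range(max_retries):
        try:
            return fetch()
        except RateLimitExceededException:
            # helper defined above in rate_limit_utils
            sleep_after_rate_limit_exception(github_client)
    raise RuntimeError("Exceeded maximum rate-limit retries")
```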

View File

@@ -0,0 +1,63 @@
from collections.abc import Callable
from typing import cast
from github import Github
from github.Repository import Repository
from onyx.access.models import ExternalAccess
from onyx.connectors.github.models import SerializedRepository
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
logger = setup_logger()
def get_external_access_permission(
repo: Repository, github_client: Github
) -> ExternalAccess:
"""
Get the external access permission for a repository.
This functionality requires Enterprise Edition.
"""
# Check if EE is enabled
if not global_version.is_ee_version():
# For the MIT version, return an empty ExternalAccess (private document)
return ExternalAccess.empty()
# Fetch the EE implementation
ee_get_external_access_permission = cast(
Callable[[Repository, Github, bool], ExternalAccess],
fetch_versioned_implementation(
"onyx.external_permissions.github.utils",
"get_external_access_permission",
),
)
return ee_get_external_access_permission(repo, github_client, True)
def deserialize_repository(
cached_repo: SerializedRepository, github_client: Github
) -> Repository:
"""
Deserialize a SerializedRepository back into a Repository object.
"""
# Try to access the requester - different PyGithub versions may use different attribute names
try:
# Try to get the requester using getattr to avoid linter errors
requester = getattr(github_client, "_requester", None)
if requester is None:
requester = getattr(github_client, "_Github__requester", None)
if requester is None:
# If we can't find the requester attribute, we need to fall back to recreating the repo
raise AttributeError("Could not find requester attribute")
return cached_repo.to_Repository(requester)
except Exception as e:
# If all else fails, re-fetch the repo directly
logger.warning(
f"Failed to deserialize repository: {e}. Attempting to re-fetch."
)
repo_id = cached_repo.id
return github_client.get_repo(repo_id)

View File

@@ -1,6 +1,10 @@
import copy
import json
import os
import sys
import threading
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
@@ -1374,3 +1378,139 @@ class GoogleDriveConnector(
@override
def validate_checkpoint_json(self, checkpoint_json: str) -> GoogleDriveCheckpoint:
return GoogleDriveCheckpoint.model_validate_json(checkpoint_json)
def get_credentials_from_env(email: str, oauth: bool) -> dict:
if oauth:
raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"]
else:
raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
refried_credential_string = json.dumps(json.loads(raw_credential_string))
# This is the Oauth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
cred_key = (
DB_CREDENTIALS_DICT_TOKEN_KEY
if oauth
else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
)
return {
cred_key: refried_credential_string,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
}
class CheckpointOutputWrapper:
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format.
The connector format is easier for the connector implementor (e.g. it enforces exactly
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
formats.
"""
def __init__(self) -> None:
self.next_checkpoint: GoogleDriveCheckpoint | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
) -> Generator[
tuple[Document | None, ConnectorFailure | None, GoogleDriveCheckpoint | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
) -> CheckpointOutput[GoogleDriveCheckpoint]:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
if isinstance(document_or_failure, Document):
yield document_or_failure, None, None
elif isinstance(document_or_failure, ConnectorFailure):
yield None, document_or_failure, None
else:
raise ValueError(
f"Invalid document_or_failure type: {type(document_or_failure)}"
)
if self.next_checkpoint is None:
raise RuntimeError(
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
)
yield None, None, self.next_checkpoint
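Both this wrapper and the generic `CheckpointOutputWrapper` earlier in the diff rely on a generator idiom worth calling out: `x = yield from gen` re-yields everything from `gen` and captures its return value, which is how the final checkpoint is recovered. A tiny self-contained illustration of that idiom, independent of the Onyx classes:

```py
from collections.abc import Generator


def producer() -> Generator[int, None, str]:
    yield 1
    yield 2
    return "final-checkpoint"  # the generator's return value


def wrapper() -> Generator[int, None, None]:
    final = yield from producer()  # re-yields 1, 2 and captures the return value
    print(f"captured: {final}")


list(wrapper())  # prints "captured: final-checkpoint"
```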
def yield_all_docs_from_checkpoint_connector(
connector: GoogleDriveConnector,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> Iterator[Document | ConnectorFailure]:
num_iterations = 0
checkpoint = connector.build_dummy_checkpoint()
while checkpoint.has_more:
doc_batch_generator = CheckpointOutputWrapper()(
connector.load_from_checkpoint(start, end, checkpoint)
)
for document, failure, next_checkpoint in doc_batch_generator:
if failure is not None:
yield failure
if document is not None:
yield document
if next_checkpoint is not None:
checkpoint = next_checkpoint
num_iterations += 1
if num_iterations > 100_000:
raise RuntimeError("Too many iterations. Infinite loop?")
if __name__ == "__main__":
import time
creds = get_credentials_from_env(
os.environ["GOOGLE_DRIVE_PRIMARY_ADMIN_EMAIL"], False
)
connector = GoogleDriveConnector(
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=True,
specific_user_emails=None,
)
connector.load_credentials(creds)
max_fsize = 0
biggest_fsize = 0
num_errors = 0
start_time = time.time()
with open("stats.txt", "w") as f:
for num, doc_or_failure in enumerate(
yield_all_docs_from_checkpoint_connector(connector, 0, time.time())
):
if num % 200 == 0:
f.write(f"Processed {num} files\n")
f.write(f"Max file size: {max_fsize/1000_000:.2f} MB\n")
f.write(f"Time so far: {time.time() - start_time:.2f} seconds\n")
f.write(f"Docs per minute: {num/(time.time() - start_time)*60:.2f}\n")
biggest_fsize = max(biggest_fsize, max_fsize)
max_fsize = 0
if isinstance(doc_or_failure, Document):
max_fsize = max(max_fsize, sys.getsizeof(doc_or_failure))
elif isinstance(doc_or_failure, ConnectorFailure):
num_errors += 1
print(f"Num errors: {num_errors}")
print(f"Biggest file size: {biggest_fsize/1000_000:.2f} MB")
print(f"Time taken: {time.time() - start_time:.2f} seconds")

View File

@@ -29,7 +29,6 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text
@@ -143,17 +142,15 @@ def _download_and_extract_sections_basic(
# Store images for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response_call(),
file_id=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
section, embedded_id = store_image_and_create_section(
image_data=response_call(),
file_id=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
@@ -216,16 +213,14 @@ def _download_and_extract_sections_basic(
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_id=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
image_data=img_data,
file_id=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections

View File

@@ -12,7 +12,6 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_processing.html_utils import web_html_cleanup
@@ -68,10 +67,7 @@ class GoogleSitesConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
with get_session_with_current_tenant() as db_session:
file_content_io = get_default_file_store(db_session).read_file(
self.zip_path, mode="b"
)
file_content_io = get_default_file_store().read_file(self.zip_path, mode="b")
# load the HTML files
files = load_files_from_zip(file_content_io)

View File

@@ -11,6 +11,7 @@ from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.db.enums import IndexModelStatus
from onyx.utils.text_processing import make_url_compatible
@@ -182,6 +183,7 @@ class DocumentBase(BaseModel):
# only filled in EE for connectors w/ permission sync enabled
external_access: ExternalAccess | None = None
doc_metadata: dict[str, Any] | None = None
def get_title_for_document_index(
self,
@@ -363,6 +365,10 @@ class ConnectorFailure(BaseModel):
return values
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
class OnyxMetadata(BaseModel):
# Note that doc_id cannot be overriden here as it may cause issues
# with the display functionalities in the UI. Ask @chris if clarification is needed.
@@ -373,3 +379,24 @@ class OnyxMetadata(BaseModel):
secondary_owners: list[BasicExpertInfo] | None = None
doc_updated_at: datetime | None = None
title: str | None = None
class DocExtractionContext(BaseModel):
index_name: str
cc_pair_id: int
connector_id: int
credential_id: int
source: DocumentSource
earliest_index_time: float
from_beginning: bool
is_primary: bool
should_fetch_permissions_during_indexing: bool
search_settings_status: IndexModelStatus
doc_extraction_complete_batch_num: int | None
class DocIndexingContext(BaseModel):
batches_done: int
total_failures: int
net_doc_change: int
total_chunks: int
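A hedged construction sketch for the two context models above, assuming the class definitions just shown; all values are hypothetical and chosen only to show the expected field types at a glance.

```py
from onyx.configs.constants import DocumentSource
from onyx.db.enums import IndexModelStatus

extraction_ctx = DocExtractionContext(
    index_name="danswer_chunk_example",  # hypothetical index name
    cc_pair_id=1,
    connector_id=1,
    credential_id=1,
    source=DocumentSource.GITHUB,
    earliest_index_time=0.0,
    from_beginning=False,
    is_primary=True,
    should_fetch_permissions_during_indexing=False,
    search_settings_status=IndexModelStatus.PRESENT,
    doc_extraction_complete_batch_num=None,
)

indexing_ctx = DocIndexingContext(
    batches_done=0, total_failures=0, net_doc_change=0, total_chunks=0
)
```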

View File

@@ -267,7 +267,7 @@ class NotionConnector(LoadConnector, PollConnector):
result = ""
for prop_name, prop in properties.items():
if not prop:
if not prop or not isinstance(prop, dict):
continue
try:

View File

@@ -1,5 +1,6 @@
import csv
import gc
import json
import os
import sys
import tempfile
@@ -28,8 +29,12 @@ from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -38,27 +43,27 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
"Opportunity": {
"Account": "account",
ACCOUNT_OBJECT_TYPE: "account",
"FiscalQuarter": "fiscal_quarter",
"FiscalYear": "fiscal_year",
"IsClosed": "is_closed",
"Name": "name",
NAME_FIELD: "name",
"StageName": "stage_name",
"Type": "type",
"Amount": "amount",
"CloseDate": "close_date",
"Probability": "probability",
"CreatedDate": "created_date",
"LastModifiedDate": "last_modified_date",
MODIFIED_FIELD: "last_modified_date",
},
"Contact": {
"Account": "account",
ACCOUNT_OBJECT_TYPE: "account",
"CreatedDate": "created_date",
"LastModifiedDate": "last_modified_date",
MODIFIED_FIELD: "last_modified_date",
},
}
@@ -74,19 +79,77 @@ class SalesforceConnectorContext:
parent_to_child_types: dict[str, set[str]] = {} # map from parent to child types
child_to_parent_types: dict[str, set[str]] = {} # map from child to parent types
parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = {}
type_to_queryable_fields: dict[str, list[str]] = {}
type_to_queryable_fields: dict[str, set[str]] = {}
prefix_to_type: dict[str, str] = {} # infer the object type of an id immediately
parent_to_child_relationships: dict[str, set[str]] = (
{}
) # map from parent to child relationships
parent_to_relationship_queryable_fields: dict[str, dict[str, list[str]]] = (
parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = (
{}
) # map from relationship to queryable fields
parent_child_names_to_relationships: dict[str, str] = {}
def _extract_fields_and_associations_from_config(
config: dict[str, Any], object_type: str
) -> tuple[list[str] | None, dict[str, list[str]]]:
"""
Extract fields and associations for a specific object type from custom config.
Returns:
tuple of (fields_list, associations_dict)
- fields_list: List of fields to query, or None if not specified (use all)
- associations_dict: Dict mapping association names to their config
"""
if object_type not in config:
return None, {}
obj_config = config[object_type]
fields = obj_config.get("fields")
associations = obj_config.get("associations", {})
return fields, associations
def _validate_custom_query_config(config: dict[str, Any]) -> None:
"""
Validate the structure of the custom query configuration.
"""
for object_type, obj_config in config.items():
if not isinstance(obj_config, dict):
raise ValueError(
f"top level object {object_type} must be mapped to a dictionary"
)
# Check if fields is a list when present
if "fields" in obj_config:
if not isinstance(obj_config["fields"], list):
raise ValueError("if fields key exists, value must be a list")
for v in obj_config["fields"]:
if not isinstance(v, str):
raise ValueError(f"if fields list value {v} is not a string")
# Check if associations is a dict when present
if "associations" in obj_config:
if not isinstance(obj_config["associations"], dict):
raise ValueError(
"if associations key exists, value must be a dictionary"
)
for assoc_name, assoc_fields in obj_config["associations"].items():
if not isinstance(assoc_fields, list):
raise ValueError(
f"associations list value {assoc_fields} for key {assoc_name} is not a list"
)
for v in assoc_fields:
if not isinstance(v, str):
raise ValueError(
f"if associations list value {v} is not a string"
)
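For reference, a custom query config of the shape below would satisfy _validate_custom_query_config as written above. The object and field names are purely illustrative and not taken from this diff; whichever fields are listed, NAME_FIELD and MODIFIED_FIELD are unioned back in later during indexing.

example_custom_query_config = {
    "Opportunity": {
        # optional: restrict the queried fields for this object type
        "fields": ["StageName", "Amount", "CloseDate"],
        # optional: map child association (object) names to the fields to pull for them
        "associations": {
            "Contact": ["Email", "Title"],
        },
    },
    "Contact": {
        "fields": ["Email", "Title"],
    },
}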
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
"""Approach outline
@@ -134,14 +197,25 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
self,
batch_size: int = INDEX_BATCH_SIZE,
requested_objects: list[str] = [],
custom_query_config: str | None = None,
) -> None:
self.batch_size = batch_size
self._sf_client: OnyxSalesforce | None = None
self.parent_object_list = (
[obj.capitalize() for obj in requested_objects]
if requested_objects
else _DEFAULT_PARENT_OBJECT_TYPES
)
# Validate and store custom query config
if custom_query_config:
config_json = json.loads(custom_query_config)
self.custom_query_config: dict[str, Any] | None = config_json
# If custom query config is provided, use the object types from it
self.parent_object_list = list(config_json.keys())
else:
self.custom_query_config = None
# Use the traditional requested_objects approach
self.parent_object_list = (
[obj.strip().capitalize() for obj in requested_objects]
if requested_objects
else _DEFAULT_PARENT_OBJECT_TYPES
)
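A minimal sketch of the new constructor path, reusing the hypothetical config above (credential handling is unchanged and omitted here):

import json

connector = SalesforceConnector(
    custom_query_config=json.dumps(example_custom_query_config)
)
# expected: the config keys drive the parent object list
# connector.parent_object_list == ["Opportunity", "Contact"]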
def load_credentials(
self,
@@ -187,7 +261,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
@staticmethod
def _download_object_csvs(
all_types_to_filter: dict[str, bool],
queryable_fields_by_type: dict[str, list[str]],
queryable_fields_by_type: dict[str, set[str]],
directory: str,
sf_client: OnyxSalesforce,
start: SecondsSinceUnixEpoch | None = None,
@@ -325,9 +399,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# all_types.update(child_types.keys())
# # Always want to make sure user is grabbed for permissioning purposes
# all_types.add("User")
# all_types.add(USER_OBJECT_TYPE)
# # Always want to make sure account is grabbed for reference purposes
# all_types.add("Account")
# all_types.add(ACCOUNT_OBJECT_TYPE)
# logger.info(f"All object types: num={len(all_types)} list={all_types}")
@@ -351,7 +425,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# all_types.update(child_types)
# # Always want to make sure user is grabbed for permissioning purposes
# all_types.add("User")
# all_types.add(USER_OBJECT_TYPE)
# logger.info(f"All object types: num={len(all_types)} list={all_types}")
@@ -364,7 +438,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
) -> GenerateDocumentsOutput:
type_to_processed: dict[str, int] = {}
logger.info("_fetch_from_salesforce starting.")
logger.info("_fetch_from_salesforce starting (full sync).")
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
@@ -548,7 +622,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
) -> GenerateDocumentsOutput:
type_to_processed: dict[str, int] = {}
logger.info("_fetch_from_salesforce starting.")
logger.info("_fetch_from_salesforce starting (delta sync).")
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
@@ -677,7 +751,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
try:
last_modified_by_id = record["LastModifiedById"]
user_record = self.sf_client.query_object(
"User", last_modified_by_id, ctx.type_to_queryable_fields
USER_OBJECT_TYPE,
last_modified_by_id,
ctx.type_to_queryable_fields,
)
if user_record:
primary_owner = BasicExpertInfo.from_dict(user_record)
@@ -792,7 +868,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = (
{}
) # for a given object, the fields reference parent objects
type_to_queryable_fields: dict[str, list[str]] = {}
type_to_queryable_fields: dict[str, set[str]] = {}
prefix_to_type: dict[str, str] = {}
parent_to_child_relationships: dict[str, set[str]] = (
@@ -802,15 +878,13 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# relationship keys are formatted as "parent__relationship"
# we have to do this because relationship names are not unique!
# values are a dict of relationship names to a list of queryable fields
parent_to_relationship_queryable_fields: dict[str, dict[str, list[str]]] = {}
parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = {}
parent_child_names_to_relationships: dict[str, str] = {}
full_sync = False
if start is None and end is None:
full_sync = True
full_sync = start is None and end is None
# Step 1 - make a list of all the types to download (parent + direct child + "User")
# Step 1 - make a list of all the types to download (parent + direct child + USER_OBJECT_TYPE)
# prefixes = {}
global_description = sf_client.describe()
@@ -831,16 +905,63 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")
for parent_type in parent_types:
# parent_onyx_sf_type = OnyxSalesforceType(parent_type, sf_client)
type_to_queryable_fields[parent_type] = (
sf_client.get_queryable_fields_by_type(parent_type)
)
child_types_working = sf_client.get_children_of_sf_type(parent_type)
logger.debug(
f"Found {len(child_types_working)} child types for {parent_type}"
)
custom_fields: list[str] | None = []
associations_config: dict[str, list[str]] | None = None
# parent_to_child_relationships[parent_type] = child_types_working
# Set queryable fields for parent type
if self.custom_query_config:
custom_fields, associations_config = (
_extract_fields_and_associations_from_config(
self.custom_query_config, parent_type
)
)
custom_fields = custom_fields or []
# Get custom fields for parent type
field_set = set(custom_fields)
# these are expected and used during doc conversion
field_set.add(NAME_FIELD)
field_set.add(MODIFIED_FIELD)
# Use only the specified fields
type_to_queryable_fields[parent_type] = field_set
logger.info(f"Using custom fields for {parent_type}: {field_set}")
else:
# Use all queryable fields
type_to_queryable_fields[parent_type] = (
sf_client.get_queryable_fields_by_type(parent_type)
)
logger.info(f"Using all fields for {parent_type}")
child_types_all = sf_client.get_children_of_sf_type(parent_type)
logger.debug(f"Found {len(child_types_all)} child types for {parent_type}")
logger.debug(f"child types: {child_types_all}")
child_types_working = child_types_all.copy()
if associations_config is not None:
child_types_working = {
k: v for k, v in child_types_all.items() if k in associations_config
}
any_not_found = False
for k in associations_config:
if k not in child_types_working:
any_not_found = True
logger.warning(f"Association {k} not found in {parent_type}")
if any_not_found:
queryable_fields = sf_client.get_queryable_fields_by_type(
parent_type
)
raise RuntimeError(
f"Associations {associations_config} not found in {parent_type} "
"make sure your parent-child associations are in the right order"
# f"with child objects {child_types_all}"
# f" and fields {queryable_fields}"
)
parent_to_child_relationships[parent_type] = set()
parent_to_child_types[parent_type] = set()
parent_to_relationship_queryable_fields[parent_type] = {}
for child_type, child_relationship in child_types_working.items():
child_type = cast(str, child_type)
@@ -848,8 +969,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# onyx_sf_type = OnyxSalesforceType(child_type, sf_client)
# map parent name to child name
if parent_type not in parent_to_child_types:
parent_to_child_types[parent_type] = set()
parent_to_child_types[parent_type].add(child_type)
# reverse map child name to parent name
@@ -858,19 +977,25 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
child_to_parent_types[child_type].add(parent_type)
# map parent name to child relationship
if parent_type not in parent_to_child_relationships:
parent_to_child_relationships[parent_type] = set()
parent_to_child_relationships[parent_type].add(child_relationship)
# map relationship to queryable fields of the target table
queryable_fields = sf_client.get_queryable_fields_by_type(child_type)
if config_fields := (
associations_config and associations_config.get(child_type)
):
field_set = set(config_fields)
# these are expected and used during doc conversion
field_set.add(NAME_FIELD)
field_set.add(MODIFIED_FIELD)
queryable_fields = field_set
else:
queryable_fields = sf_client.get_queryable_fields_by_type(
child_type
)
if child_relationship in parent_to_relationship_queryable_fields:
raise RuntimeError(f"{child_relationship=} already exists")
if parent_type not in parent_to_relationship_queryable_fields:
parent_to_relationship_queryable_fields[parent_type] = {}
parent_to_relationship_queryable_fields[parent_type][
child_relationship
] = queryable_fields
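To make the custom-field handling above concrete: whichever fields the config lists, NAME_FIELD and MODIFIED_FIELD are added back because doc conversion reads them. A small sketch, assuming NAME_FIELD == "Name" and MODIFIED_FIELD == "LastModifiedDate" as the doc_conversion changes further down imply:

custom_fields = ["StageName", "Amount"]   # from the user's config
field_set = set(custom_fields)
field_set.add("Name")                     # NAME_FIELD
field_set.add("LastModifiedDate")         # MODIFIED_FIELD
# field_set == {"StageName", "Amount", "Name", "LastModifiedDate"}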
@@ -894,14 +1019,22 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
all_types.update(child_types)
# NOTE(rkuo): should this be an implicit parent type?
all_types.add("User") # Always add User for permissioning purposes
all_types.add("Account") # Always add Account for reference purposes
all_types.add(USER_OBJECT_TYPE) # Always add User for permissioning purposes
all_types.add(ACCOUNT_OBJECT_TYPE) # Always add Account for reference purposes
logger.info(f"All object types: num={len(all_types)} list={all_types}")
# Ensure User and Account have queryable fields if they weren't already processed
essential_types = [USER_OBJECT_TYPE, ACCOUNT_OBJECT_TYPE]
for essential_type in essential_types:
if essential_type not in type_to_queryable_fields:
type_to_queryable_fields[essential_type] = (
sf_client.get_queryable_fields_by_type(essential_type)
)
# 1.1 - Detect all fields in child types which reference a parent type.
# build dicts to detect relationships between parent and child
for child_type in child_types:
for child_type in child_types.union(essential_types):
# onyx_sf_type = OnyxSalesforceType(child_type, sf_client)
parent_reference_fields = sf_client.get_parent_reference_fields(
child_type, parent_types
@@ -992,7 +1125,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
doc_metadata_list: list[SlimDocument] = []
for parent_object_type in self.parent_object_list:
query = f"SELECT Id FROM {parent_object_type}"
query_result = self.sf_client.query_all(query)
query_result = self.sf_client.safe_query_all(query)
doc_metadata_list.extend(
SlimDocument(
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
@@ -1003,6 +1136,32 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
yield doc_metadata_list
def validate_connector_settings(self) -> None:
"""
Validate that the Salesforce credentials and connector settings are correct.
Specifically checks that we can make an authenticated request to Salesforce.
"""
try:
# Issue a lightweight describe() metadata call to verify that the credentials can authenticate
self.sf_client.describe()
except Exception as e:
raise ConnectorMissingCredentialError(
"Failed to validate Salesforce credentials. Please check your"
f"credentials and try again. Error: {e}"
)
if self.custom_query_config:
try:
_validate_custom_query_config(self.custom_query_config)
except Exception as e:
raise ConnectorMissingCredentialError(
"Failed to validate Salesforce custom query config. Please check your"
f"config and try again. Error: {e}"
)
logger.info("Salesforce credentials validated successfully.")
# @override
# def load_from_checkpoint(
# self,
@@ -1032,7 +1191,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
if __name__ == "__main__":
connector = SalesforceConnector(requested_objects=["Account"])
connector = SalesforceConnector(requested_objects=[ACCOUNT_OBJECT_TYPE])
connector.load_credentials(
{

View File

@@ -10,6 +10,8 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.utils.logger import setup_logger
@@ -140,7 +142,7 @@ def _extract_primary_owner(
first_name=user_data.get("FirstName"),
last_name=user_data.get("LastName"),
email=user_data.get("Email"),
display_name=user_data.get("Name"),
display_name=user_data.get(NAME_FIELD),
)
# Check if all fields are None
@@ -166,8 +168,8 @@ def convert_sf_query_result_to_doc(
"""Generates a yieldable Document from query results"""
base_url = f"https://{sf_client.sf_instance}"
extracted_doc_updated_at = time_str_to_utc(record["LastModifiedDate"])
extracted_semantic_identifier = record.get("Name", "Unknown Object")
extracted_doc_updated_at = time_str_to_utc(record[MODIFIED_FIELD])
extracted_semantic_identifier = record.get(NAME_FIELD, "Unknown Object")
sections = [_extract_section(record, f"{base_url}/{record_id}")]
for child_record_key, child_record in child_records.items():
@@ -205,8 +207,8 @@ def convert_sf_object_to_doc(
salesforce_id = object_dict["Id"]
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
base_url = f"https://{sf_instance}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
extracted_doc_updated_at = time_str_to_utc(object_dict[MODIFIED_FIELD])
extracted_semantic_identifier = object_dict.get(NAME_FIELD, "Unknown Object")
sections = [_extract_section(sf_object.data, f"{base_url}/{sf_object.id}")]
for id in sf_db.get_child_ids(sf_object.id):

View File

@@ -1,18 +1,31 @@
import time
from typing import Any
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from simple_salesforce.exceptions import SalesforceRefusedRequest
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_OBJECTS
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_PREFIXES
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_SUFFIXES
from onyx.connectors.salesforce.salesforce_calls import get_object_by_id_query
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
def is_salesforce_rate_limit_error(exception: Exception) -> bool:
"""Check if an exception is a Salesforce rate limit error."""
return isinstance(
exception, SalesforceRefusedRequest
) and "REQUEST_LIMIT_EXCEEDED" in str(exception)
class OnyxSalesforce(Salesforce):
SOQL_MAX_SUBQUERIES = 20
@@ -47,17 +60,59 @@ class OnyxSalesforce(Salesforce):
return True
for suffix in SALESFORCE_BLACKLISTED_SUFFIXES:
if object_type_lower.endswith(prefix):
if object_type_lower.endswith(suffix):
return True
return False
@retry_builder(
tries=5,
delay=20,
backoff=1.5,
max_delay=60,
exceptions=(SalesforceRefusedRequest,),
)
@rate_limit_builder(max_calls=50, period=60)
def safe_query(self, query: str, **kwargs: Any) -> dict[str, Any]:
"""Wrapper around the original query method with retry logic and rate limiting."""
try:
return super().query(query, **kwargs)
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for query: {query[:100]}..."
)
# Add additional delay for rate limit errors
time.sleep(5)
raise
@retry_builder(
tries=5,
delay=20,
backoff=1.5,
max_delay=60,
exceptions=(SalesforceRefusedRequest,),
)
@rate_limit_builder(max_calls=50, period=60)
def safe_query_all(self, query: str, **kwargs: Any) -> dict[str, Any]:
"""Wrapper around the original query_all method with retry logic and rate limiting."""
try:
return super().query_all(query, **kwargs)
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for query_all: {query[:100]}..."
)
# Add additional delay for rate limit errors
time.sleep(5)
raise
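The point of these wrappers, as the rest of this diff shows, is that call sites swap the raw simple_salesforce query/query_all calls for the retried, rate-limited variants; a minimal usage sketch (query text illustrative):

result = sf_client.safe_query_all("SELECT Id FROM Account")  # was: sf_client.query_all(...)
for row in result["records"]:
    print(row["Id"])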
@staticmethod
def _make_child_objects_by_id_query(
object_id: str,
sf_type: str,
child_relationships: list[str],
relationships_to_fields: dict[str, list[str]],
relationships_to_fields: dict[str, set[str]],
) -> str:
"""Returns a SOQL query given the object id, type and child relationships.
@@ -93,13 +148,13 @@ class OnyxSalesforce(Salesforce):
self,
object_type: str,
object_id: str,
type_to_queryable_fields: dict[str, list[str]],
type_to_queryable_fields: dict[str, set[str]],
) -> dict[str, Any] | None:
record: dict[str, Any] = {}
queryable_fields = type_to_queryable_fields[object_type]
query = get_object_by_id_query(object_id, object_type, queryable_fields)
result = self.query(query)
result = self.safe_query(query)
if not result:
return None
@@ -117,7 +172,7 @@ class OnyxSalesforce(Salesforce):
object_id: str,
sf_type: str,
child_relationships: list[str],
relationships_to_fields: dict[str, list[str]],
relationships_to_fields: dict[str, set[str]],
) -> dict[str, dict[str, Any]]:
"""There's a limit on the number of subqueries we can put in a single query."""
child_records: dict[str, dict[str, Any]] = {}
@@ -151,7 +206,7 @@ class OnyxSalesforce(Salesforce):
)
try:
result = self.query(query)
result = self.safe_query(query)
except Exception:
logger.exception(f"Query failed: {query=}")
else:
@@ -189,15 +244,30 @@ class OnyxSalesforce(Salesforce):
return child_records
@retry_builder(
tries=3,
delay=1,
backoff=2,
exceptions=(SalesforceRefusedRequest,),
)
def describe_type(self, name: str) -> Any:
sf_object = SFType(name, self.session_id, self.sf_instance)
result = sf_object.describe()
return result
try:
result = sf_object.describe()
return result
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for describe_type: {name}"
)
# Add additional delay for rate limit errors
time.sleep(3)
raise
def get_queryable_fields_by_type(self, name: str) -> list[str]:
def get_queryable_fields_by_type(self, name: str) -> set[str]:
object_description = self.describe_type(name)
if object_description is None:
return []
return set()
fields: list[dict[str, Any]] = object_description["fields"]
valid_fields: set[str] = set()
@@ -216,7 +286,7 @@ class OnyxSalesforce(Salesforce):
if field_name:
valid_fields.add(field_name)
return list(valid_fields - field_names_to_remove)
return valid_fields - field_names_to_remove
def get_children_of_sf_type(self, sf_type: str) -> dict[str, str]:
"""Returns a dict of child object names to relationship names.

View File

@@ -1,5 +1,6 @@
import gc
import os
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
@@ -7,13 +8,26 @@ from pytz import UTC
from simple_salesforce import Salesforce
from simple_salesforce.bulk2 import SFBulk2Handler
from simple_salesforce.bulk2 import SFBulk2Type
from simple_salesforce.exceptions import SalesforceRefusedRequest
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
def is_salesforce_rate_limit_error(exception: Exception) -> bool:
"""Check if an exception is a Salesforce rate limit error."""
return isinstance(
exception, SalesforceRefusedRequest
) and "REQUEST_LIMIT_EXCEEDED" in str(exception)
def _build_last_modified_time_filter_for_salesforce(
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
@@ -41,12 +55,12 @@ def _build_created_date_time_filter_for_salesforce(
def _make_time_filter_for_sf_type(
queryable_fields: list[str],
queryable_fields: set[str],
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> str | None:
if "LastModifiedDate" in queryable_fields:
if MODIFIED_FIELD in queryable_fields:
return _build_last_modified_time_filter_for_salesforce(start, end)
if "CreatedDate" in queryable_fields:
@@ -56,14 +70,14 @@ def _make_time_filter_for_sf_type(
def _make_time_filtered_query(
queryable_fields: list[str], sf_type: str, time_filter: str
queryable_fields: set[str], sf_type: str, time_filter: str
) -> str:
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
return query
def get_object_by_id_query(
object_id: str, sf_type: str, queryable_fields: list[str]
object_id: str, sf_type: str, queryable_fields: set[str]
) -> str:
query = (
f"SELECT {', '.join(queryable_fields)} FROM {sf_type} WHERE Id = '{object_id}'"
@@ -71,6 +85,14 @@ def get_object_by_id_query(
return query
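For illustration, given a queryable field set, the two helpers above produce SOQL along these lines. The exact WHERE clause emitted by the time-filter builders is not shown in this hunk, so a placeholder filter string is assumed:

fields = {"Name", "LastModifiedDate", "StageName"}
time_filter = " WHERE LastModifiedDate > 2024-01-01T00:00:00Z"  # placeholder shape only
_make_time_filtered_query(fields, "Opportunity", time_filter)
# -> "SELECT Name, LastModifiedDate, StageName FROM Opportunity WHERE LastModifiedDate > ..."
get_object_by_id_query("006XXXXXXXXXXXXXXX", "Opportunity", fields)
# -> "SELECT Name, LastModifiedDate, StageName FROM Opportunity WHERE Id = '006XXXXXXXXXXXXXXX'"
# note: since fields is a set, the joined field order is not deterministic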
@retry_builder(
tries=5,
delay=2,
backoff=2,
max_delay=60,
exceptions=(SalesforceRefusedRequest,),
)
@rate_limit_builder(max_calls=50, period=60)
def _object_type_has_api_data(
sf_client: Salesforce, sf_type: str, time_filter: str
) -> bool:
@@ -82,6 +104,15 @@ def _object_type_has_api_data(
result = sf_client.query(query)
if result["totalSize"] == 0:
return False
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for object type check: {sf_type}"
)
# Add additional delay for rate limit errors
time.sleep(3)
raise
except Exception as e:
if "OPERATION_TOO_LARGE" not in str(e):
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
@@ -163,7 +194,7 @@ def _bulk_retrieve_from_salesforce(
def fetch_all_csvs_in_parallel(
sf_client: Salesforce,
all_types_to_filter: dict[str, bool],
queryable_fields_by_type: dict[str, list[str]],
queryable_fields_by_type: dict[str, set[str]],
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
target_dir: str,

View File

@@ -8,11 +8,15 @@ from pathlib import Path
from typing import Any
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
from shared_configs.utils import batch_list
logger = setup_logger()
@@ -567,7 +571,7 @@ class OnyxSalesforceSQLite:
uncommitted_rows = 0
# If we're updating User objects, update the email map
if object_type == "User":
if object_type == USER_OBJECT_TYPE:
OnyxSalesforceSQLite._update_user_email_map(cursor)
return updated_ids
@@ -619,7 +623,7 @@ class OnyxSalesforceSQLite:
with self._conn:
cursor = self._conn.cursor()
# Get the object data and account data
if object_type == "Account" or isChild:
if object_type == ACCOUNT_OBJECT_TYPE or isChild:
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
)
@@ -638,7 +642,7 @@ class OnyxSalesforceSQLite:
data = json.loads(result[0][0])
if object_type != "Account":
if object_type != ACCOUNT_OBJECT_TYPE:
# convert any account ids of the relationships back into data fields, with name
for row in result:
@@ -647,14 +651,14 @@ class OnyxSalesforceSQLite:
if len(row) < 3:
continue
if row[1] and row[2] and row[2] == "Account":
if row[1] and row[2] and row[2] == ACCOUNT_OBJECT_TYPE:
data["AccountId"] = row[1]
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?",
(row[1],),
)
account_data = json.loads(cursor.fetchone()[0])
data["Account"] = account_data.get("Name", "")
data[ACCOUNT_OBJECT_TYPE] = account_data.get(NAME_FIELD, "")
return SalesforceObject(id=object_id, type=object_type, data=data)
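To make the Account back-reference handling above concrete, here is a hypothetical end state for a non-Account object whose relationship row points at an Account, assuming ACCOUNT_OBJECT_TYPE == "Account" and NAME_FIELD == "Name" as the surrounding substitutions imply (ids and names are made up):

# after the loop finds a matching relationship row
data["AccountId"] = "001XXXXXXXXXXXXXXX"   # the referenced Account's id
data["Account"] = "Acme Corp"              # account_data.get("Name", "") for that id
# this enriched dict is what the SalesforceObject returned above carries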

Some files were not shown because too many files have changed in this diff.