Compare commits

..

78 Commits

Author SHA1 Message Date
SubashMohan
5c4f44d258 fix: sharepoint lg files issue (#5065)
* add SharePoint file size threshold check

* Implement retry logic for SharePoint queries to handle rate limiting and server errors

* mypy fix

* add content none check

* remove unreachable code from retry logic in sharepoint connector
2025-07-24 14:26:01 +00:00
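A rough sketch of the retry-with-backoff and file-size-threshold behavior described in the bullets above; the threshold constant, delays, and function names are illustrative assumptions, not the connector's actual code.

```python
import random
import time
from typing import Callable, TypeVar

import requests

T = TypeVar("T")

# Hypothetical threshold: skip files larger than this instead of downloading them.
MAX_SHAREPOINT_FILE_SIZE_BYTES = 50 * 1024 * 1024


def call_with_retries(fn: Callable[[], T], max_attempts: int = 5) -> T:
    """Retry a SharePoint query on rate limiting (429) and server errors (5xx)."""
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except requests.HTTPError as e:
            status = e.response.status_code if e.response is not None else None
            retryable = status == 429 or (status is not None and status >= 500)
            if not retryable or attempt == max_attempts:
                raise
            # Exponential backoff with jitter; honor Retry-After when it is a plain number.
            retry_after = e.response.headers.get("Retry-After") if e.response is not None else None
            delay = float(retry_after) if retry_after and retry_after.isdigit() else (2 ** attempt) + random.random()
            time.sleep(delay)
    raise RuntimeError("unreachable")


def should_skip(file_size_bytes: int) -> bool:
    """File-size threshold check: oversized files are skipped rather than fetched."""
    return file_size_bytes > MAX_SHAREPOINT_FILE_SIZE_BYTES
```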
Evan Lohn
19652ad60e attempt fix for broken excel files (#5071) 2025-07-24 01:21:13 +00:00
Evan Lohn
70c96b6ab3 fix: remove locks from indexing callback (#5070) 2025-07-23 23:05:35 +00:00
Raunak Bhagat
65076b916f refactor: Update location of sidebar (#5067)
* Use props instead of inline type def

* Add new AppProvider

* Remove unused component file

* Move `sessionSidebar` to be inside of `components` instead of `app/chat`

* Change name of `sessionSidebar` to `sidebar`

* Remove `AppModeProvider`

* Fix bug in how the cookies were set
2025-07-23 21:59:34 +00:00
PaulHLiatrio
06bc0e51db fix: adjust template variable from .Chart.AppVersion to .Values.global.version to match versioning pattern. (#5069) 2025-07-23 14:54:32 -07:00
Devin
508b456b40 fix: explicit api_server dependency on minio in docker compose files (#5066) 2025-07-23 13:37:42 -07:00
Evan Lohn
bf1e2a2661 feat: avoid full rerun (#5063)
* fix: remove extra group sync

* second extra task

* minor improvement for non-checkpointed connectors
2025-07-23 18:01:23 +00:00
Evan Lohn
991d5e4203 fix: regen api key (#5064) 2025-07-23 03:36:51 +00:00
Evan Lohn
d21f012b04 fix: remove extra group sync (#5061)
* fix: remove extra group sync

* second extra task
2025-07-22 23:24:42 +00:00
Wenxi
86b7beab01 fix: too many internet chunks (#5060)
* minor internet search env vars

* add limit to internet search chunks

* note

* nits
2025-07-22 23:11:10 +00:00
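The "add limit to internet search chunks" change can be pictured as a cap read from an environment variable; the variable name and default below are assumptions, not necessarily the ones introduced in the PR.

```python
import os

# Assumed env var name and default; the PR's actual names may differ.
MAX_INTERNET_SEARCH_CHUNKS = int(os.environ.get("MAX_INTERNET_SEARCH_CHUNKS", "25"))


def limit_internet_chunks(chunks: list[dict]) -> list[dict]:
    """Truncate internet search results so downstream prompts stay bounded."""
    return chunks[:MAX_INTERNET_SEARCH_CHUNKS]
```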
Evan Lohn
b4eaa81d8b handle empty doc batches (#5058) 2025-07-22 22:35:59 +00:00
Evan Lohn
ff2a4c8723 fix: time discrepancy (#5056)
* fix time discrepancy

* remove log

* remove log
2025-07-22 22:19:02 +00:00
Raunak Bhagat
51027fd259 fix: Make pr-labeler run on edits too 2025-07-22 15:04:37 -07:00
Raunak Bhagat
7e3fd2b12a refactor: Update the error message that is logged when PR title fails Conventional Commits regex (#5062) 2025-07-22 14:46:22 -07:00
Chris Weaver
d2fef6f0b7 Tiny launch.json template improvement (#5055) 2025-07-22 11:15:44 -07:00
Evan Lohn
bd06147d26 feat: connector indexing decoupling (#4893)
* WIP

* renamed and moved tasks (WIP)

* minio migration

* bug fixes and finally add document batch storage

* WIP: can succeed but status is error

* WIP

* import fixes

* working v1 of decoupled

* catastrophe handling

* refactor

* remove unused db session in prep for new approach

* renaming and docstrings (untested)

* renames

* WIP with no more indexing fences

* robustness improvements

* clean up rebase

* migration and salesforce rate limits

* minor tweaks

* test fix

* connector pausing behavior

* correct checkpoint resumption logic

* cleanups in docfetching

* add heartbeat file

* update template jsonc

* deployment fixes

* fix vespa httpx pool

* error handling

* cosmetic fixes

* dumb

* logging improvements and non checkpointed connector fixes

* didn't save

* misc fixes

* fix import

* fix deletion of old files

* add in attempt prefix

* fix attempt prefix

* tiny log improvement

* minor changes

* fixed resumption behavior

* passing int tests

* fix unit test

* fixed unit tests

* trying timeout bump to see if int tests pass

* trying timeout bump to see if int tests pass

* fix autodiscovery

* helm chart fixes

* helm and logging
2025-07-22 03:33:25 +00:00
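The decoupling splits connector work into a docfetching stage that stores document batches and a docprocessing stage that indexes them, coordinated through the columns added by the `2f95e36923e6` migration later in this diff (heartbeat counter, completed/total batches). A minimal sketch of that coordination idea; the helper functions are illustrative, not the actual implementation.

```python
import time
from dataclasses import dataclass, field


@dataclass
class IndexAttempt:
    # Loosely mirrors the coordination columns added in the 2f95e36923e6 migration.
    total_batches: int | None = None
    completed_batches: int = 0
    heartbeat_counter: int = 0
    last_heartbeat_value: int = 0
    last_heartbeat_time: float = field(default_factory=time.monotonic)


def beat(attempt: IndexAttempt) -> None:
    """Docfetching/docprocessing workers bump the counter to signal liveness."""
    attempt.heartbeat_counter += 1


def is_stalled(attempt: IndexAttempt, timeout_s: float = 600.0) -> bool:
    """A watchdog considers the attempt stalled if the counter stops moving."""
    now = time.monotonic()
    if attempt.heartbeat_counter != attempt.last_heartbeat_value:
        attempt.last_heartbeat_value = attempt.heartbeat_counter
        attempt.last_heartbeat_time = now
        return False
    return now - attempt.last_heartbeat_time > timeout_s


def is_done(attempt: IndexAttempt) -> bool:
    """Docprocessing finishes once every batch produced by docfetching is handled."""
    return attempt.total_batches is not None and attempt.completed_batches >= attempt.total_batches
```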
Raunak Bhagat
1f3cc9ed6e Make from_.user optional (use "Unknown User") if not found (#5051) 2025-07-21 17:50:28 -07:00
Raunak Bhagat
6086d9e51a feat: Updated KG admin page (#5044)
* Update KG admin UI

* Styling changes

* More changes

* Make edits auto-save

* Add more stylings / transitions

* Fix opacity

* Separate out modal into new component

* Revert backend changes

* Update styling

* Add convenience / styling changes to date-picker

* More styling / functional updates to kg admin-page

* Avoid reducing opacity of active-toggle

* Update backend APIs for new KG admin page

* More updates of styling for kg-admin page

* Remove nullability

* Remove console log

* Remove unused imports

* Change type of `children` variable

* Update web/src/app/admin/kg/interfaces.ts

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

* Update web/src/components/CollapsibleCard.tsx

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

* Remove null

* Update web/src/components/CollapsibleCard.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Force non-null

* Fix failing test

---------

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-07-21 15:37:27 -07:00
Raunak Bhagat
e0de24f64e Remove empty tooltip (#5050) 2025-07-21 12:45:48 -07:00
Rei Meguro
08b6b1f8b3 feat: Search and Answer Quality Test Script (#4974)
* aefads

* search quality tests improvement

Co-authored-by: wenxi-onyx <wenxi@onyx.app>

* nits

* refactor: config refactor

* document context + skip genai fix

* feat: answer eval

* more error messages

* mypy ragas

* mypy

* small fixes

* feat: more metrics

* fix

* feat: grab content

* typing

* feat: lazy updates

* mypy

* all at front

* feat: answer correctness

* use api key so it works with auth enabled

* update readme

* feat: auto add path

* feat: rate limit

* fix: readme + remove rerank all

* fix: raise exception immediately

* docs: improved clarity

* feat: federated handling

* fix: mypy

* nits

---------

Co-authored-by: wenxi-onyx <wenxi@onyx.app>
2025-07-19 01:51:51 +00:00
joachim-danswer
afed1a4b37 feat: KG improvements (#5048)
* improvements

* drop views if SQL fails

* mypy fix
2025-07-18 16:15:11 -07:00
Chris Weaver
bca18cacdf fix: improve assistant fetching efficiency (#5047)
* Improve assistant fetching efficiency

* More fix

* Fix weird build stuff

* Improve
2025-07-18 14:16:10 -07:00
Chris Weaver
335db91803 fix: improve check for indexing status (#5042)
* Improve check_for_indexing + check_for_vespa_sync_task

* Remove unused

* Fix

* Simplify query

* Add more logging

* Address bot comments

* Increase # of tasks generated since we're not going cc-pair by cc-pair

* Only index 50 user files at a time
2025-07-17 23:52:51 -07:00
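The "only index 50 user files at a time" bullet amounts to capping how much user-file work each scheduling pass picks up; the constant and function below are assumed placeholders for illustration.

```python
from collections.abc import Sequence

# Assumed cap matching the commit message; the real value lives in the scheduler.
MAX_USER_FILES_PER_PASS = 50


def select_user_files_to_index(pending_user_file_ids: Sequence[int]) -> list[int]:
    """Take at most 50 pending user files per check_for_indexing pass to bound each cycle."""
    return list(pending_user_file_ids[:MAX_USER_FILES_PER_PASS])
```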
Chris Weaver
67c488ff1f Improve support for non-default postgres schemas (#5046) 2025-07-17 23:51:39 -07:00
Wenxi
deb7f13962 remove chat session necessity from send message simple api (#5040) 2025-07-17 23:23:46 +00:00
Raunak Bhagat
e2d3d65c60 fix: Move around group-sync tests (since they require docker services to be running) (#5041)
* Move around tests

* Add missing fixtures + change directory structure up some more

* Add env variables
2025-07-17 22:41:31 +00:00
Raunak Bhagat
b78a6834f5 fix: Have document show up before message starts streaming back (#5006)
* Have document show up before message starts streaming back

* Add docs
2025-07-17 10:17:57 -07:00
Raunak Bhagat
4abe90aa2c fix: Fix Confluence pagination (#5035)
* Re-implement pagination

* Add note

* Fix invalid integration test configs

* Fix other failing test

* Edit failing test

* Revert test

* Revert pagination size

* Add comment on yielding style

* Use fixture instead of manually initializing sql-engine

* Fix failing tests

* Move code back and copy-paste
2025-07-17 14:02:29 +00:00
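Confluence's REST API pages results with a `_links.next` cursor, and the re-implemented pagination presumably follows that cursor until it runs out while yielding lazily (the "comment on yielding style" bullet). A hedged sketch using the standard Confluence endpoint and fields; the connector's exact code may differ.

```python
from collections.abc import Iterator

import requests


def iterate_confluence_pages(
    base_url: str, space_key: str, session: requests.Session, page_size: int = 50
) -> Iterator[dict]:
    """Yield pages one at a time, following the cursor in _links.next until it is gone."""
    url = f"{base_url}/rest/api/content"
    params: dict[str, str | int] | None = {
        "spaceKey": space_key,
        "limit": page_size,
        "expand": "body.storage,version",
    }
    while url:
        resp = session.get(url, params=params)
        resp.raise_for_status()
        payload = resp.json()
        yield from payload.get("results", [])
        next_path = payload.get("_links", {}).get("next")
        # The next link is relative to the site base and already carries the cursor params.
        url = f"{base_url}{next_path}" if next_path else ""
        params = None  # avoid re-sending the original params alongside the cursor
```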
Raunak Bhagat
de9568844b Add PR labeller job (#4611) 2025-07-16 18:28:18 -07:00
Evan Lohn
34268f9806 fix bug in index swap (#5036) 2025-07-16 23:09:17 +00:00
Chris Weaver
ed75678837 Add suggested helm resource limits (#5032)
* Add resource suggestions for helm

* Adjust README

* fix

* fix lint
2025-07-15 15:52:16 -07:00
Chris Weaver
3bb58a3dd3 Persona simplification r2 (#5031)
* Revert "Revert "Reduce amount of stuff we fetch on `/persona` (#4988)" (#5024)"

This reverts commit f7ed7cd3cd.

* Enhancements / fix re-render

* re-arrange

* greptile
2025-07-15 14:51:40 -07:00
Chris Weaver
4b02feef31 Add option to disable my documents (#5020)
* Add option to disable my documents

* cleanup
2025-07-14 23:16:14 -07:00
Chris Weaver
6a4d49f02e More pruning logging (#5027) 2025-07-14 12:55:12 -07:00
Chris Weaver
d1736187d3 Fix full tenant sync (#5026) 2025-07-14 10:56:40 -07:00
Wenxi
0e79b96091 Feature/revised internet search (#4994)
* remove unused pruning config

* add env vars

* internet search date time toggle

* revised internet search supporting multiple providers

* env var

* simplify retries and fix mypy issues

* greptile nits

* more mypy

* please mypy

* mypy final straw

* cursor vs. mypy

* simplify fields from provider results

* type-safe prompt, enum nit, provider enums, indexing doc processing change

---------

Co-authored-by: Wenxi Onyx <wenxi-onyx@Wenxis-MacBook-Pro.local>
2025-07-14 10:24:03 -07:00
Raunak Bhagat
ae302d473d Fix imap tests 2025-07-14 09:50:33 -07:00
Raunak Bhagat
feca4fda78 feat: Add frontend for email connector (#5008)
* Add basic structure for frontend email connector

* Update names of credentials-json keys

* Fix up configurations workflow

* Edit logic on how `mail_client` is used

- imaplib.IMAP4_SSL is supposed to be treated as an ephemeral object

* Edit helper name and add docs

* Fix invalid mailbox selection error

* Implement greptile suggestions

* Make recipients optional and add sender to primary-owners

* Add sender to external-access too; perform dedupe-ing of emails

* Simplify logic
2025-07-14 09:43:36 -07:00
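The note that `imaplib.IMAP4_SSL` should be treated as ephemeral roughly translates to opening a fresh connection per fetch pass instead of caching one on the connector. A sketch of that usage with Python's standard library; host and credential names are placeholders.

```python
import imaplib
from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def imap_session(host: str, username: str, password: str, port: int = 993) -> Iterator[imaplib.IMAP4_SSL]:
    """Create a short-lived IMAP connection and log out once the work is done."""
    client = imaplib.IMAP4_SSL(host, port)
    try:
        client.login(username, password)
        yield client
    finally:
        try:
            client.logout()
        except Exception:
            pass  # the connection is disposable; it may already be gone


def list_mailboxes(host: str, username: str, password: str) -> list[bytes]:
    """Example usage: open a session just long enough to list mailboxes."""
    with imap_session(host, username, password) as client:
        status, mailboxes = client.list()
        return mailboxes if status == "OK" else []
```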
Chris Weaver
f7ed7cd3cd Revert "Reduce amount of stuff we fetch on /persona (#4988)" (#5024)
This reverts commit adf48de652.
2025-07-14 09:20:50 -07:00
Chris Weaver
8377ab3ef2 Send over less data for document sets (#5018)
* Send over less data for document sets

* Fix type errors

* Fix tests

* Fixes

* Don't change packages
2025-07-13 22:47:05 +00:00
Chris Weaver
95c23bf870 Add full sync endpoint (#5019)
* Add full sync endpoint

* Update backend/ee/onyx/server/tenants/billing_api.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Update backend/ee/onyx/server/tenants/billing_api.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* fix

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-07-13 13:59:19 -07:00
Chris Weaver
e49fb8f56d Fix pruning (#5017)
* Use better last_pruned time for never pruned connectors

* improved pruning req / refresh freq selections

* Small tweak

* Update web/src/lib/connectors/connectors.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-07-12 14:46:00 -07:00
Chris Weaver
adf48de652 Reduce amount of stuff we fetch on /persona (#4988)
* claude stuff

* Send over less Assistant data

* more

* Fix build

* Fix mypy

* fix

* small tweak

* Address EL comments

* fix

* Fix build
2025-07-12 14:15:31 -07:00
Wenxi Onyx
bca2500438 personas no longer overwrite when same name 2025-07-12 10:57:01 -07:00
Raunak Bhagat
89f925662f feat: Add ability to specify vertex-ai model location (#4955)
* Make constant a global

* Add ability to specify vertex location

* Add period

* Add a hardcoding path to the frontend

* Add docs

* Add default value to `CustomConfigKey`

* Consume default value from custom-config-key on frontend

* Use markdown renderer instead

* Update description
2025-07-11 16:16:12 -07:00
Chris Weaver
b64c6d5d40 Skip federated connectors when document sets are specified (#5015) 2025-07-11 15:49:13 -07:00
Raunak Bhagat
36c63950a6 fix: More small IMAP backend fixes (#5014)
* Make recipients an optional header and add IMAP to recognized connectors

* Add sender to external-access; perform dedupe-ing of emails
2025-07-11 20:06:28 +00:00
Raunak Bhagat
3f31340e6f feat: Add support for Confluence Macros (#5001)
* Remove macro stylings from HTML tree

* Add params

* Handle multiple cases of `ac:structured-macro` being found.

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-07-11 00:34:49 +00:00
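Stripping macro styling from the HTML tree while keeping the macro's useful text could look roughly like the BeautifulSoup sketch below; which elements Onyx actually keeps or drops is not specified here, so treat the selection logic as an assumption.

```python
from bs4 import BeautifulSoup  # pip install beautifulsoup4


def extract_text_with_macros(storage_format_html: str) -> str:
    """Unwrap ac:structured-macro elements so their inner text survives extraction."""
    soup = BeautifulSoup(storage_format_html, "html.parser")
    for macro in soup.find_all("ac:structured-macro"):
        # Keep the macro body's text but drop the macro wrapper/styling itself.
        macro.unwrap()
    return soup.get_text(separator=" ", strip=True)
```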
Raunak Bhagat
6ac2258c2e Fixes for imap backend (#5011) 2025-07-10 23:58:28 +00:00
Weves
b4d3b43e8a Add more error handling for drive group sync 2025-07-10 18:33:43 -07:00
Rei Meguro
ca281b71e3 add missing slack scope 2025-07-10 17:37:57 -07:00
Wenxi
9bd5a1de7a check file size first and clarify processing logic (#4985)
* check file size first and clarify processing logic

* basic gdrive extraction clarity

* typo

---------

Co-authored-by: Wenxi Onyx <wenxi-onyx@Wenxis-MacBook-Pro.local>
2025-07-10 00:48:36 +00:00
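Checking the file size before doing any extraction work is simple to picture; the threshold and function names below are assumptions for illustration.

```python
import os

# Assumed cutoff; the real limit comes from configuration.
MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024


def process_file(path: str) -> str | None:
    """Check the size first so oversized files are skipped before any parsing happens."""
    if os.path.getsize(path) > MAX_FILE_SIZE_BYTES:
        return None  # skipped: too large to process
    with open(path, "rb") as f:
        raw = f.read()
    return raw.decode("utf-8", errors="replace")
```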
Wenxi Onyx
d3c5a4fba0 add docx fallback 2025-07-09 17:46:15 -07:00
Chris Weaver
f50006ee63 Stop fetching channel info to make pages load faster (#5005) 2025-07-09 16:45:45 -07:00
Evan Lohn
e0092024af add minIO to README 2025-07-09 16:36:18 -07:00
Evan Lohn
675ef524b0 add minio to contributing instructions 2025-07-09 16:33:56 -07:00
Evan Lohn
240367c775 skip id migration based on env var 2025-07-09 13:37:54 -07:00
Chris Weaver
f0ed063860 Allow curators to create public connectors / document sets (#4972)
* Allow curators to create public connectors / document sets

* Address EL comments
2025-07-09 11:38:56 -07:00
Rei Meguro
bcf0ef0c87 feat: original query + better slack expansion 2025-07-09 09:23:03 -07:00
Rei Meguro
0c7a245a46 Revert "feat: original query + better slack expansion"
This reverts commit 583d82433a.
2025-07-09 20:15:15 +09:00
Rei Meguro
583d82433a feat: original query + better slack expansion 2025-07-09 20:11:24 +09:00
Chris Weaver
391e710b6e Slack federated search ux (#4969)
* slack_search.py

* rem

* fix: get elements

* feat: better stack message processing

* fix: mypy

* fix: url parsing

* refactor: federated search

* feat: proper chunking + source filters

* highlighting + source check

* feat: forced section insertion

* feat: multi slack api queries

* slack query expansion

* feat: max slack queries env

* limit slack search to avoid overloading the search

* Initial draft

* more

* simplify

* Improve modal

* Fix oauth flow

* Fully working version

* More niceties

* Improved cascade delete

* document set for fed connector UI

* Fix source filters + improve document set selection

* Improve callback modal

* fix: logging error + showing connectors in admin page user settings

* better log

* Fix mypy

* small rei comment

* Fix pydantic

* Improvements to modals

* feat: distributed pruning

* random fix

* greptile

* Encrypt token

* respect source filters + revert llm pruning ordering

* greptile nit

* feat: thread as context in slack search

* feat: slack doc ordering

* small improvements

* rebase

* Small improvements

* Fix web build

* try fix build

* Move to separate model file

* Use default_factory

* remove unused model

---------

Co-authored-by: Rei Meguro <36625832+Orbital-Web@users.noreply.github.com>
2025-07-08 21:35:51 -07:00
Raunak Bhagat
004e56a91b feat: IMAP connector (#4987)
* Implement fetching; still need to work on document parsing

* Add basic skeleton of parsing email bodies

* Add id field

* Add email body parsing

* Implement checkpointed imap-connector

* Add testing logic for basic iteration

* Add logic to get different header if "to" isn't present

- possible in mailing-list workflows

* Add ability to index specific mailboxes

* Add breaking when indexing has been fully exhausted

* Sanitize all mailbox names + add space between stripped strings after parsing

* Add multi-recipient parsing

* Change around semantic-identifier and title

* Add imap tests

* Add recipients and content assertions to tests

* Add envvars to github actions workflow file

* Remove encoding header

* Update logic to not immediately establish connection upon init of `ImapConnector`

* Add start and end datetime filtering + edit when connection is established / how login is done

* Remove content-type header

* Add note about guards

* Change default parameters to be `None` instead of `[]`

* Address comment on PR

* Implement more PR suggestions

* More PR suggestions

* Implement more PR suggestions

* Change up login/logout flow (PR suggestion)

* Move port number to be envvar

* Make globals variants in enum instead (PR suggestion)

* Fix more documentation related suggestions on PR

* Have the imap connector implement `CheckpointedConnectorWithPermSync` instead

* Add helper for loading all docs with permission syncing
2025-07-08 23:58:22 +00:00
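The start/end datetime filtering mentioned above maps naturally onto IMAP's SINCE/BEFORE search criteria; a hedged sketch of how a checkpointed fetch might narrow its search window (the function names and flow are illustrative, not the connector's code).

```python
import imaplib
from datetime import datetime


def _imap_date(dt: datetime) -> str:
    """IMAP SEARCH expects dates like 08-Jul-2025."""
    return dt.strftime("%d-%b-%Y")


def search_message_ids(
    client: imaplib.IMAP4_SSL, mailbox: str, start: datetime, end: datetime
) -> list[bytes]:
    """Select a mailbox and return message sequence numbers within the date window."""
    status, _ = client.select(mailbox, readonly=True)
    if status != "OK":
        return []
    # SINCE is inclusive of the start date, BEFORE is exclusive of the end date.
    status, data = client.search(None, "SINCE", _imap_date(start), "BEFORE", _imap_date(end))
    if status != "OK" or not data or not data[0]:
        return []
    return data[0].split()


# Example (with an already logged-in client):
#     ids = search_message_ids(client, "INBOX", datetime(2025, 7, 1), datetime(2025, 7, 9))
```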
Evan Lohn
103300798f Bugfix/drive doc ids3 (#4998)
* fix migration

* fix migration2

* cursor based pages

* correct vespa URL

* fix visit api index name

* use correct endpoint and query
2025-07-07 18:23:00 +00:00
Evan Lohn
8349d6f0ea Bugfix/drive doc ids (#4990)
* fixed id extraction in drive connector

* WIP migration

* full migration script

* migration works single tenant without duplicates

* tested single tenant with duplicate docs

* migrations and frontend

* tested mutlitenant

* fix connector tests

* make tests pass
2025-07-06 01:59:12 +00:00
Emerson Gomes
cd63bf6da9 Re-adding .epub file support (#4989)
.epub files apparently were forgotten and were not allowed for upload in the frontend.
2025-07-05 07:48:55 -07:00
Rei Meguro
5f03e85195 fireflies metadata update (#4993)
* fireflies metadata

* str
2025-07-04 18:39:41 -07:00
Raunak Bhagat
cbdbfcab5e fix: Fix bug with incorrect model icon being shown (#4986)
* Fix bug with incorrect model icon being shown

* Update web/src/app/chat/input/LLMPopover.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Update web/src/app/chat/input/LLMPopover.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Update web/src/app/chat/input/LLMPopover.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Update web/src/app/chat/input/LLMPopover.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Add visibility to filtering

* Update the model names which are shown in the popup

* Fix incorrect llm updating bug

* Fix bug in which the provider name would be used instead

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-07-04 01:39:50 +00:00
SubashMohan
6918611287 remove check for folder assistant before uploading (#4975)
Co-authored-by: Subash <subash@onyx.app>
2025-07-03 09:25:03 -07:00
Chris Weaver
b0639add8f Fix migration (#4982) 2025-07-03 09:18:57 -07:00
Evan Lohn
7af10308d7 drive service account shared fixes (#4977)
* drive service account shared fixes

* oops

* ily greptile

* scrollable index attempt errors

* tentatively correct index errors page, needs testing

* mypy

* black

* better bounds in practice

* remove random failures

* remove console log

* CW
2025-07-02 16:56:32 -07:00
Rei Meguro
5e14f23507 mypy fix 2025-07-01 23:00:33 -07:00
Raunak Bhagat
0bf3a5c609 Add type ignore for dynamic sqlalchemy class (#4979) 2025-07-01 18:14:35 -07:00
Emerson Gomes
82724826ce Remove hardcoded image extraction flag for PDFs
PDFs currently always have their images extracted.
This will make use of the "Enable Image Extraction and Analysis" workspace configuration instead.
2025-07-01 13:57:36 -07:00
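Respecting the workspace setting instead of a hardcoded flag amounts to threading a boolean through the PDF path; a minimal sketch, with the extraction helpers as assumed placeholders rather than Onyx's actual API.

```python
def extract_pdf(pdf_bytes: bytes, image_analysis_enabled: bool) -> tuple[str, list[bytes]]:
    """Extract text always; extract embedded images only when the workspace setting allows it."""
    text = _extract_text(pdf_bytes)
    images = _extract_images(pdf_bytes) if image_analysis_enabled else []
    return text, images


def _extract_text(pdf_bytes: bytes) -> str:
    # Placeholder for the real text extraction (e.g. via a PDF library).
    return ""


def _extract_images(pdf_bytes: bytes) -> list[bytes]:
    # Placeholder for the real embedded-image extraction.
    return []
```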
Wenxi
f9e061926a account for category prefix added by user (#4976)
Co-authored-by: Wenxi Onyx <wenxi-onyx@Wenxis-MacBook-Pro.local>
2025-07-01 10:39:46 -07:00
Chris Weaver
8afd07ff7a Small gdrive perm sync enhancement (#4973)
* Small gdrive perm sync enhancement

* Small enhancement
2025-07-01 09:33:45 -07:00
Evan Lohn
6523a38255 search speedup (#4971) 2025-07-01 01:41:27 +00:00
Yuhong Sun
264878a1c9 Onyx Metadata Header for File Connector (#4968) 2025-06-29 16:09:06 -07:00
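One plausible shape for a metadata header in the file connector is a JSON payload on the first line; the header token and parsing below are assumptions for illustration, not the confirmed format.

```python
import json

# Assumed header token; the actual prefix used by the file connector may differ.
METADATA_PREFIX = "#ONYX_METADATA="


def split_metadata_header(text: str) -> tuple[dict, str]:
    """If the first line carries a metadata header, parse it and return the remaining body."""
    first_line, _, rest = text.partition("\n")
    if first_line.startswith(METADATA_PREFIX):
        try:
            metadata = json.loads(first_line[len(METADATA_PREFIX):])
        except json.JSONDecodeError:
            return {}, text  # malformed header: treat the whole file as content
        return metadata, rest
    return {}, text
```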
365 changed files with 17328 additions and 6649 deletions


@@ -13,6 +13,14 @@ env:
# MinIO
S3_ENDPOINT_URL: "http://localhost:9004"
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
discover-test-dirs:
runs-on: ubuntu-latest

.github/workflows/pr-labeler.yml (new file)

@@ -0,0 +1,38 @@
name: PR Labeler
on:
pull_request_target:
branches:
- main
types:
- opened
- reopened
- synchronize
- edited
permissions:
contents: read
pull-requests: write
jobs:
validate_pr_title:
runs-on: ubuntu-latest
steps:
- name: Check PR title for Conventional Commits
env:
PR_TITLE: ${{ github.event.pull_request.title }}
run: |
echo "PR Title: $PR_TITLE"
if [[ ! "$PR_TITLE" =~ ^(feat|fix|docs|test|ci|refactor|perf|chore|revert|build)(\(.+\))?:\ .+ ]]; then
echo "::error::❌ Your PR title does not follow the Conventional Commits format.
This check ensures that all pull requests use clear, consistent titles that help automate changelogs and improve project history.
Please update your PR title to follow the Conventional Commits style.
Here is a link to a blog explaining the reason why we've included the Conventional Commits style into our PR titles: https://xfuture-blog.com/working-with-conventional-commits
**Here are some examples of valid PR titles:**
- feat: add user authentication
- fix(login): handle null password error
- docs(readme): update installation instructions"
exit 1
fi
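For reference, the same Conventional Commits title check expressed as a small Python sketch; the regex mirrors the one in the workflow step above.

```python
import re

# Mirrors the pattern used in the workflow step above.
CONVENTIONAL_COMMITS_RE = re.compile(
    r"^(feat|fix|docs|test|ci|refactor|perf|chore|revert|build)(\(.+\))?: .+"
)


def is_valid_pr_title(title: str) -> bool:
    """Return True when the PR title follows the Conventional Commits format."""
    return CONVENTIONAL_COMMITS_RE.match(title) is not None


assert is_valid_pr_title("fix(login): handle null password error")
assert not is_valid_pr_title("update stuff")
```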


@@ -47,7 +47,7 @@ jobs:
-i /local/openapi.json \
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client
--package-name onyx_openapi_client \
- name: Run MyPy
run: |


@@ -16,8 +16,8 @@ env:
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -53,6 +53,12 @@ env:
# Hubspot
HUBSPOT_ACCESS_TOKEN: ${{ secrets.HUBSPOT_ACCESS_TOKEN }}
# IMAP
IMAP_HOST: ${{ secrets.IMAP_HOST }}
IMAP_USERNAME: ${{ secrets.IMAP_USERNAME }}
IMAP_PASSWORD: ${{ secrets.IMAP_PASSWORD }}
IMAP_MAILBOXES: ${{ secrets.IMAP_MAILBOXES }}
# Airtable
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}


@@ -45,8 +45,9 @@ PYTHONPATH=../backend
PYTHONUNBUFFERED=1
# Internet Search
# Internet Search
BING_API_KEY=<REPLACE THIS>
EXA_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features


@@ -24,8 +24,8 @@
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery user files indexing",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring"
],
@@ -46,8 +46,8 @@
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery user files indexing",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring"
],
@@ -226,35 +226,66 @@
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery indexing",
"name": "Celery docfetching",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing"
"-A",
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docfetching@%n",
"-Q",
"connector_doc_fetching,user_files_indexing"
],
"presentation": {
"group": "2"
"group": "2"
},
"consoleTitle": "Celery indexing Console"
},
"consoleTitle": "Celery docfetching Console",
"justMyCode": false
},
{
"name": "Celery docprocessing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.docprocessing",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docprocessing@%n",
"-Q",
"docprocessing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docprocessing Console",
"justMyCode": false
},
{
"name": "Celery monitoring",
"type": "debugpy",
@@ -303,35 +334,6 @@
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery user files indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_files_indexing@%n",
"-Q",
"user_files_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user files indexing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",
@@ -426,7 +428,7 @@
},
"args": [
"--filename",
"generated/openapi.json",
"generated/openapi.json"
]
},
{


@@ -59,6 +59,7 @@ Onyx being a fully functional app, relies on some external software, specificall
- [Postgres](https://www.postgresql.org/) (Relational DB)
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
- [Redis](https://redis.io/) (Cache)
- [MinIO](https://min.io/) (File Store)
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
> **Note:**
@@ -171,10 +172,10 @@ Otherwise, you can follow the instructions below to run the application for deve
You will need Docker installed to run these containers.
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
```bash
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache minio
```
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)


@@ -23,7 +23,7 @@ from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import (
MULTI_TENANT,
POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE,
POSTGRES_DEFAULT_SCHEMA,
TENANT_ID_PREFIX,
)
from onyx.db.models import Base
@@ -271,7 +271,7 @@ async def run_async_migrations() -> None:
) = get_schema_options()
if not schemas and not MULTI_TENANT:
schemas = [POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE]
schemas = [POSTGRES_DEFAULT_SCHEMA]
# without init_engine, subsequent engine calls fail hard intentionally
SqlEngine.init_engine(pool_size=20, max_overflow=5)


@@ -0,0 +1,72 @@
"""add federated connector tables
Revision ID: 0816326d83aa
Revises: 12635f6655b7
Create Date: 2025-06-29 14:09:45.109518
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "0816326d83aa"
down_revision = "12635f6655b7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create federated_connector table
op.create_table(
"federated_connector",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("source", sa.String(), nullable=False),
sa.Column("credentials", sa.LargeBinary(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# Create federated_connector_oauth_token table
op.create_table(
"federated_connector_oauth_token",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("federated_connector_id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("token", sa.LargeBinary(), nullable=False),
sa.Column("expires_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["federated_connector_id"], ["federated_connector.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
# Create federated_connector__document_set table
op.create_table(
"federated_connector__document_set",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("federated_connector_id", sa.Integer(), nullable=False),
sa.Column("document_set_id", sa.Integer(), nullable=False),
sa.Column("entities", postgresql.JSONB(), nullable=False),
sa.ForeignKeyConstraint(
["federated_connector_id"], ["federated_connector.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["document_set_id"], ["document_set.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"federated_connector_id",
"document_set_id",
name="uq_federated_connector_document_set",
),
)
def downgrade() -> None:
# Drop tables in reverse order due to foreign key dependencies
op.drop_table("federated_connector__document_set")
op.drop_table("federated_connector_oauth_token")
op.drop_table("federated_connector")


@@ -0,0 +1,596 @@
"""drive-canonical-ids
Revision ID: 12635f6655b7
Revises: 58c50ef19f08
Create Date: 2025-06-20 14:44:54.241159
"""
from alembic import op
import sqlalchemy as sa
from urllib.parse import urlparse, urlunparse
from httpx import HTTPStatusError
import httpx
from onyx.document_index.factory import get_default_document_index
from onyx.db.search_settings import SearchSettings
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.utils.logger import setup_logger
import os
logger = setup_logger()
# revision identifiers, used by Alembic.
revision = "12635f6655b7"
down_revision = "58c50ef19f08"
branch_labels = None
depends_on = None
SKIP_CANON_DRIVE_IDS = os.environ.get("SKIP_CANON_DRIVE_IDS", "true").lower() == "true"
def active_search_settings() -> tuple[SearchSettings, SearchSettings | None]:
result = op.get_bind().execute(
sa.text(
"""
SELECT * FROM search_settings WHERE status = 'PRESENT' ORDER BY id DESC LIMIT 1
"""
)
)
search_settings_fetch = result.fetchall()
search_settings = (
SearchSettings(**search_settings_fetch[0]._asdict())
if search_settings_fetch
else None
)
result2 = op.get_bind().execute(
sa.text(
"""
SELECT * FROM search_settings WHERE status = 'FUTURE' ORDER BY id DESC LIMIT 1
"""
)
)
search_settings_future_fetch = result2.fetchall()
search_settings_future = (
SearchSettings(**search_settings_future_fetch[0]._asdict())
if search_settings_future_fetch
else None
)
if not isinstance(search_settings, SearchSettings):
raise RuntimeError(
"current search settings is of type " + str(type(search_settings))
)
if (
not isinstance(search_settings_future, SearchSettings)
and search_settings_future is not None
):
raise RuntimeError(
"future search settings is of type " + str(type(search_settings_future))
)
return search_settings, search_settings_future
def normalize_google_drive_url(url: str) -> str:
"""Remove query parameters from Google Drive URLs to create canonical document IDs.
NOTE: copied from drive doc_conversion.py
"""
parsed_url = urlparse(url)
parsed_url = parsed_url._replace(query="")
spl_path = parsed_url.path.split("/")
if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
spl_path.pop()
parsed_url = parsed_url._replace(path="/".join(spl_path))
# Remove query parameters and reconstruct URL
return urlunparse(parsed_url)
def get_google_drive_documents_from_database() -> list[dict]:
"""Get all Google Drive documents from the database."""
bind = op.get_bind()
result = bind.execute(
sa.text(
"""
SELECT d.id
FROM document d
JOIN document_by_connector_credential_pair dcc ON d.id = dcc.id
JOIN connector_credential_pair cc ON dcc.connector_id = cc.connector_id
AND dcc.credential_id = cc.credential_id
JOIN connector c ON cc.connector_id = c.id
WHERE c.source = 'GOOGLE_DRIVE'
"""
)
)
documents = []
for row in result:
documents.append({"document_id": row.id})
return documents
def update_document_id_in_database(
old_doc_id: str, new_doc_id: str, index_name: str
) -> None:
"""Update document IDs in all relevant database tables using copy-and-swap approach."""
bind = op.get_bind()
# print(f"Updating database tables for document {old_doc_id} -> {new_doc_id}")
# Check if new document ID already exists
result = bind.execute(
sa.text("SELECT COUNT(*) FROM document WHERE id = :new_id"),
{"new_id": new_doc_id},
)
row = result.fetchone()
if row and row[0] > 0:
# print(f"Document with ID {new_doc_id} already exists, deleting old one")
delete_document_from_db(old_doc_id, index_name)
return
# Step 1: Create a new document row with the new ID (copy all fields from old row)
# Use a conservative approach to handle columns that might not exist in all installations
try:
bind.execute(
sa.text(
"""
INSERT INTO document (id, from_ingestion_api, boost, hidden, semantic_id,
link, doc_updated_at, primary_owners, secondary_owners,
external_user_emails, external_user_group_ids, is_public,
chunk_count, last_modified, last_synced, kg_stage, kg_processing_time)
SELECT :new_id, from_ingestion_api, boost, hidden, semantic_id,
link, doc_updated_at, primary_owners, secondary_owners,
external_user_emails, external_user_group_ids, is_public,
chunk_count, last_modified, last_synced, kg_stage, kg_processing_time
FROM document
WHERE id = :old_id
"""
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated database tables for document {old_doc_id} -> {new_doc_id}")
except Exception as e:
# If the full INSERT fails, try a more basic version with only core columns
logger.warning(f"Full INSERT failed, trying basic version: {e}")
bind.execute(
sa.text(
"""
INSERT INTO document (id, from_ingestion_api, boost, hidden, semantic_id,
link, doc_updated_at, primary_owners, secondary_owners)
SELECT :new_id, from_ingestion_api, boost, hidden, semantic_id,
link, doc_updated_at, primary_owners, secondary_owners
FROM document
WHERE id = :old_id
"""
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# Step 2: Update all foreign key references to point to the new ID
# Update document_by_connector_credential_pair table
bind.execute(
sa.text(
"UPDATE document_by_connector_credential_pair SET id = :new_id WHERE id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated document_by_connector_credential_pair table for document {old_doc_id} -> {new_doc_id}")
# Update search_doc table (stores search results for chat replay)
# This is critical for agent functionality
bind.execute(
sa.text(
"UPDATE search_doc SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated search_doc table for document {old_doc_id} -> {new_doc_id}")
# Update document_retrieval_feedback table (user feedback on documents)
bind.execute(
sa.text(
"UPDATE document_retrieval_feedback SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated document_retrieval_feedback table for document {old_doc_id} -> {new_doc_id}")
# Update document__tag table (document-tag relationships)
bind.execute(
sa.text(
"UPDATE document__tag SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated document__tag table for document {old_doc_id} -> {new_doc_id}")
# Update user_file table (user uploaded files linked to documents)
bind.execute(
sa.text(
"UPDATE user_file SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated user_file table for document {old_doc_id} -> {new_doc_id}")
# Update KG and chunk_stats tables (these may not exist in all installations)
try:
# Update kg_entity table
bind.execute(
sa.text(
"UPDATE kg_entity SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated kg_entity table for document {old_doc_id} -> {new_doc_id}")
# Update kg_entity_extraction_staging table
bind.execute(
sa.text(
"UPDATE kg_entity_extraction_staging SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated kg_entity_extraction_staging table for document {old_doc_id} -> {new_doc_id}")
# Update kg_relationship table
bind.execute(
sa.text(
"UPDATE kg_relationship SET source_document = :new_id WHERE source_document = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated kg_relationship table for document {old_doc_id} -> {new_doc_id}")
# Update kg_relationship_extraction_staging table
bind.execute(
sa.text(
"UPDATE kg_relationship_extraction_staging SET source_document = :new_id WHERE source_document = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated kg_relationship_extraction_staging table for document {old_doc_id} -> {new_doc_id}")
# Update chunk_stats table
bind.execute(
sa.text(
"UPDATE chunk_stats SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated chunk_stats table for document {old_doc_id} -> {new_doc_id}")
# Update chunk_stats ID field which includes document_id
bind.execute(
sa.text(
"""
UPDATE chunk_stats
SET id = REPLACE(id, :old_id, :new_id)
WHERE id LIKE :old_id_pattern
"""
),
{
"new_id": new_doc_id,
"old_id": old_doc_id,
"old_id_pattern": f"{old_doc_id}__%",
},
)
# print(f"Successfully updated chunk_stats ID field for document {old_doc_id} -> {new_doc_id}")
except Exception as e:
logger.warning(f"Some KG/chunk tables may not exist or failed to update: {e}")
# Step 3: Delete the old document row (this should now be safe since all FKs point to new row)
bind.execute(
sa.text("DELETE FROM document WHERE id = :old_id"), {"old_id": old_doc_id}
)
# print(f"Successfully deleted document {old_doc_id} from database")
def _visit_chunks(
*,
http_client: httpx.Client,
index_name: str,
selection: str,
continuation: str | None = None,
) -> tuple[list[dict], str | None]:
"""Helper that calls the /document/v1 visit API once and returns (docs, next_token)."""
# Use the same URL as the document API, but with visit-specific params
base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
params: dict[str, str] = {
"selection": selection,
"wantedDocumentCount": "1000",
}
if continuation:
params["continuation"] = continuation
# print(f"Visiting chunks for selection '{selection}' with params {params}")
resp = http_client.get(base_url, params=params, timeout=None)
# print(f"Visited chunks for document {selection}")
resp.raise_for_status()
payload = resp.json()
return payload.get("documents", []), payload.get("continuation")
def delete_document_chunks_from_vespa(index_name: str, doc_id: str) -> None:
"""Delete all chunks for *doc_id* from Vespa using continuation-token paging (no offset)."""
total_deleted = 0
# Use exact match instead of contains - Document Selector Language doesn't support contains
selection = f'{index_name}.document_id=="{doc_id}"'
with get_vespa_http_client() as http_client:
continuation: str | None = None
while True:
docs, continuation = _visit_chunks(
http_client=http_client,
index_name=index_name,
selection=selection,
continuation=continuation,
)
if not docs:
break
for doc in docs:
vespa_full_id = doc.get("id")
if not vespa_full_id:
continue
vespa_doc_uuid = vespa_full_id.split("::")[-1]
delete_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
try:
resp = http_client.delete(delete_url)
resp.raise_for_status()
total_deleted += 1
except Exception as e:
print(f"Failed to delete chunk {vespa_doc_uuid}: {e}")
if not continuation:
break
def update_document_id_in_vespa(
index_name: str, old_doc_id: str, new_doc_id: str
) -> None:
"""Update all chunks' document_id field from *old_doc_id* to *new_doc_id* using continuation paging."""
clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id)
# Use exact match instead of contains - Document Selector Language doesn't support contains
selection = f'{index_name}.document_id=="{old_doc_id}"'
with get_vespa_http_client() as http_client:
continuation: str | None = None
while True:
# print(f"Visiting chunks for document {old_doc_id} -> {new_doc_id}")
docs, continuation = _visit_chunks(
http_client=http_client,
index_name=index_name,
selection=selection,
continuation=continuation,
)
if not docs:
break
for doc in docs:
vespa_full_id = doc.get("id")
if not vespa_full_id:
continue
vespa_doc_uuid = vespa_full_id.split("::")[-1]
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
update_request = {
"fields": {"document_id": {"assign": clean_new_doc_id}}
}
try:
resp = http_client.put(vespa_url, json=update_request)
resp.raise_for_status()
except Exception as e:
print(f"Failed to update chunk {vespa_doc_uuid}: {e}")
raise
if not continuation:
break
def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
# Delete all foreign key references first, then delete the document
try:
bind = op.get_bind()
# Delete from agent-related tables first (order matters due to foreign keys)
# Delete from agent__sub_query__search_doc first since it references search_doc
bind.execute(
sa.text(
"""
DELETE FROM agent__sub_query__search_doc
WHERE search_doc_id IN (
SELECT id FROM search_doc WHERE document_id = :doc_id
)
"""
),
{"doc_id": current_doc_id},
)
# Delete from chat_message__search_doc
bind.execute(
sa.text(
"""
DELETE FROM chat_message__search_doc
WHERE search_doc_id IN (
SELECT id FROM search_doc WHERE document_id = :doc_id
)
"""
),
{"doc_id": current_doc_id},
)
# Now we can safely delete from search_doc
bind.execute(
sa.text("DELETE FROM search_doc WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
# Delete from document_by_connector_credential_pair
bind.execute(
sa.text(
"DELETE FROM document_by_connector_credential_pair WHERE id = :doc_id"
),
{"doc_id": current_doc_id},
)
# Delete from other tables that reference this document
bind.execute(
sa.text(
"DELETE FROM document_retrieval_feedback WHERE document_id = :doc_id"
),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM document__tag WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM user_file WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
# Delete from KG tables if they exist
try:
bind.execute(
sa.text("DELETE FROM kg_entity WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text(
"DELETE FROM kg_entity_extraction_staging WHERE document_id = :doc_id"
),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM kg_relationship WHERE source_document = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text(
"DELETE FROM kg_relationship_extraction_staging WHERE source_document = :doc_id"
),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM chunk_stats WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM chunk_stats WHERE id LIKE :doc_id_pattern"),
{"doc_id_pattern": f"{current_doc_id}__%"},
)
except Exception as e:
logger.warning(
f"Some KG/chunk tables may not exist or failed to delete from: {e}"
)
# Finally delete the document itself
bind.execute(
sa.text("DELETE FROM document WHERE id = :doc_id"),
{"doc_id": current_doc_id},
)
# Delete chunks from vespa
delete_document_chunks_from_vespa(index_name, current_doc_id)
except Exception as e:
print(f"Failed to delete duplicate document {current_doc_id}: {e}")
# Continue with other documents instead of failing the entire migration
def upgrade() -> None:
if SKIP_CANON_DRIVE_IDS:
return
current_search_settings, future_search_settings = active_search_settings()
document_index = get_default_document_index(
current_search_settings,
future_search_settings,
)
# Get the index name
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
else:
# Default index name if we can't get it from the document_index
index_name = "danswer_index"
# Get all Google Drive documents from the database (this is faster and more reliable)
gdrive_documents = get_google_drive_documents_from_database()
if not gdrive_documents:
return
# Track normalized document IDs to detect duplicates
all_normalized_doc_ids = set()
updated_count = 0
for doc_info in gdrive_documents:
current_doc_id = doc_info["document_id"]
normalized_doc_id = normalize_google_drive_url(current_doc_id)
print(f"Processing document {current_doc_id} -> {normalized_doc_id}")
# Check for duplicates
if normalized_doc_id in all_normalized_doc_ids:
# print(f"Deleting duplicate document {current_doc_id}")
delete_document_from_db(current_doc_id, index_name)
continue
all_normalized_doc_ids.add(normalized_doc_id)
# If the document ID already doesn't have query parameters, skip it
if current_doc_id == normalized_doc_id:
# print(f"Skipping document {current_doc_id} -> {normalized_doc_id} because it already has no query parameters")
continue
try:
# Update both database and Vespa in order
# Database first to ensure consistency
update_document_id_in_database(
current_doc_id, normalized_doc_id, index_name
)
# For Vespa, we can now use the original document IDs since we're using contains matching
update_document_id_in_vespa(index_name, current_doc_id, normalized_doc_id)
updated_count += 1
# print(f"Finished updating document {current_doc_id} -> {normalized_doc_id}")
except Exception as e:
print(f"Failed to update document {current_doc_id}: {e}")
if isinstance(e, HTTPStatusError):
print(f"HTTPStatusError: {e}")
print(f"Response: {e.response.text}")
print(f"Status: {e.response.status_code}")
print(f"Headers: {e.response.headers}")
print(f"Request: {e.request.url}")
print(f"Request headers: {e.request.headers}")
# Note: Rollback is complex with copy-and-swap approach since the old document is already deleted
# In case of failure, manual intervention may be required
# Continue with other documents instead of failing the entire migration
continue
logger.info(f"Migration complete. Updated {updated_count} Google Drive documents")
def downgrade() -> None:
# this is a one way migration, so no downgrade.
# It wouldn't make sense to store the extra query parameters
# and duplicate documents to allow a reversal.
pass


@@ -144,27 +144,34 @@ def upgrade() -> None:
def downgrade() -> None:
op.execute("TRUNCATE TABLE index_attempt")
op.add_column(
"index_attempt",
sa.Column("input_type", sa.VARCHAR(), autoincrement=False, nullable=False),
)
op.add_column(
"index_attempt",
sa.Column("source", sa.VARCHAR(), autoincrement=False, nullable=False),
)
op.add_column(
"index_attempt",
sa.Column(
"connector_specific_config",
postgresql.JSONB(astext_type=sa.Text()),
autoincrement=False,
nullable=False,
),
)
# Check if the constraint exists before dropping
conn = op.get_bind()
inspector = sa.inspect(conn)
existing_columns = {col["name"] for col in inspector.get_columns("index_attempt")}
if "input_type" not in existing_columns:
op.add_column(
"index_attempt",
sa.Column("input_type", sa.VARCHAR(), autoincrement=False, nullable=False),
)
if "source" not in existing_columns:
op.add_column(
"index_attempt",
sa.Column("source", sa.VARCHAR(), autoincrement=False, nullable=False),
)
if "connector_specific_config" not in existing_columns:
op.add_column(
"index_attempt",
sa.Column(
"connector_specific_config",
postgresql.JSONB(astext_type=sa.Text()),
autoincrement=False,
nullable=False,
),
)
# Check if the constraint exists before dropping
constraints = inspector.get_foreign_keys("index_attempt")
if any(
@@ -183,8 +190,12 @@ def downgrade() -> None:
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
)
op.drop_column("index_attempt", "credential_id")
op.drop_column("index_attempt", "connector_id")
op.drop_table("connector_credential_pair")
op.drop_table("credential")
op.drop_table("connector")
if "credential_id" in existing_columns:
op.drop_column("index_attempt", "credential_id")
if "connector_id" in existing_columns:
op.drop_column("index_attempt", "connector_id")
op.execute("DROP TABLE IF EXISTS connector_credential_pair CASCADE")
op.execute("DROP TABLE IF EXISTS credential CASCADE")
op.execute("DROP TABLE IF EXISTS connector CASCADE")


@@ -0,0 +1,115 @@
"""add_indexing_coordination
Revision ID: 2f95e36923e6
Revises: 0816326d83aa
Create Date: 2025-07-10 16:17:57.762182
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2f95e36923e6"
down_revision = "0816326d83aa"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add database-based coordination fields (replacing Redis fencing)
op.add_column(
"index_attempt", sa.Column("celery_task_id", sa.String(), nullable=True)
)
op.add_column(
"index_attempt",
sa.Column(
"cancellation_requested",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
# Add batch coordination fields (replacing FileStore state)
op.add_column(
"index_attempt", sa.Column("total_batches", sa.Integer(), nullable=True)
)
op.add_column(
"index_attempt",
sa.Column(
"completed_batches", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column(
"total_failures_batch_level",
sa.Integer(),
nullable=False,
server_default="0",
),
)
op.add_column(
"index_attempt",
sa.Column("total_chunks", sa.Integer(), nullable=False, server_default="0"),
)
# Progress tracking for stall detection
op.add_column(
"index_attempt",
sa.Column("last_progress_time", sa.DateTime(timezone=True), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column(
"last_batches_completed_count",
sa.Integer(),
nullable=False,
server_default="0",
),
)
# Heartbeat tracking for worker liveness detection
op.add_column(
"index_attempt",
sa.Column(
"heartbeat_counter", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column(
"last_heartbeat_value", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column("last_heartbeat_time", sa.DateTime(timezone=True), nullable=True),
)
# Add index for coordination queries
op.create_index(
"ix_index_attempt_active_coordination",
"index_attempt",
["connector_credential_pair_id", "search_settings_id", "status"],
)
def downgrade() -> None:
# Remove the new index
op.drop_index("ix_index_attempt_active_coordination", table_name="index_attempt")
# Remove the new columns
op.drop_column("index_attempt", "last_batches_completed_count")
op.drop_column("index_attempt", "last_progress_time")
op.drop_column("index_attempt", "last_heartbeat_time")
op.drop_column("index_attempt", "last_heartbeat_value")
op.drop_column("index_attempt", "heartbeat_counter")
op.drop_column("index_attempt", "total_chunks")
op.drop_column("index_attempt", "total_failures_batch_level")
op.drop_column("index_attempt", "completed_batches")
op.drop_column("index_attempt", "total_batches")
op.drop_column("index_attempt", "cancellation_requested")
op.drop_column("index_attempt", "celery_task_id")


@@ -9,7 +9,7 @@ Create Date: 2025-06-22 17:33:25.833733
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# revision identifiers, used by Alembic.
revision = "36e9220ab794"
@@ -66,7 +66,7 @@ def upgrade() -> None:
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -111,7 +111,7 @@ def upgrade() -> None:
UPDATE "{tenant_id}".kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;


@@ -15,7 +15,7 @@ from datetime import datetime, timedelta
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# revision identifiers, used by Alembic.
@@ -80,6 +80,7 @@ def upgrade() -> None:
)
)
op.execute("DROP TABLE IF EXISTS kg_config CASCADE")
op.create_table(
"kg_config",
sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
@@ -123,6 +124,7 @@ def upgrade() -> None:
],
)
op.execute("DROP TABLE IF EXISTS kg_entity_type CASCADE")
op.create_table(
"kg_entity_type",
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
@@ -156,6 +158,7 @@ def upgrade() -> None:
),
)
op.execute("DROP TABLE IF EXISTS kg_relationship_type CASCADE")
# Create KGRelationshipType table
op.create_table(
"kg_relationship_type",
@@ -194,6 +197,7 @@ def upgrade() -> None:
),
)
op.execute("DROP TABLE IF EXISTS kg_relationship_type_extraction_staging CASCADE")
# Create KGRelationshipTypeExtractionStaging table
op.create_table(
"kg_relationship_type_extraction_staging",
@@ -227,6 +231,8 @@ def upgrade() -> None:
),
)
op.execute("DROP TABLE IF EXISTS kg_entity CASCADE")
# Create KGEntity table
op.create_table(
"kg_entity",
@@ -281,6 +287,7 @@ def upgrade() -> None:
"ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
)
op.execute("DROP TABLE IF EXISTS kg_entity_extraction_staging CASCADE")
# Create KGEntityExtractionStaging table
op.create_table(
"kg_entity_extraction_staging",
@@ -330,6 +337,7 @@ def upgrade() -> None:
["name", "entity_type_id_name"],
)
op.execute("DROP TABLE IF EXISTS kg_relationship CASCADE")
# Create KGRelationship table
op.create_table(
"kg_relationship",
@@ -371,6 +379,7 @@ def upgrade() -> None:
"ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
)
op.execute("DROP TABLE IF EXISTS kg_relationship_extraction_staging CASCADE")
# Create KGRelationshipExtractionStaging table
op.create_table(
"kg_relationship_extraction_staging",
@@ -414,6 +423,7 @@ def upgrade() -> None:
["source_node", "target_node"],
)
op.execute("DROP TABLE IF EXISTS kg_term CASCADE")
# Create KGTerm table
op.create_table(
"kg_term",
@@ -468,7 +478,7 @@ def upgrade() -> None:
# Create GIN index for clustering and normalization
op.execute(
"CREATE INDEX IF NOT EXISTS idx_kg_entity_clustering_trigrams "
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.gin_trgm_ops)"
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA}.gin_trgm_ops)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams "
@@ -508,7 +518,7 @@ def upgrade() -> None:
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -553,7 +563,7 @@ def upgrade() -> None:
UPDATE kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;


@@ -159,7 +159,7 @@ def _migrate_files_to_postgres() -> None:
# only create external store if we have files to migrate. This line
# makes it so we need to have S3/MinIO configured to run this migration.
external_store = get_s3_file_store(db_session=session)
external_store = get_s3_file_store()
for i, file_id in enumerate(files_to_migrate, 1):
print(f"Migrating file {i}/{total_files}: {file_id}")
@@ -219,7 +219,7 @@ def _migrate_files_to_external_storage() -> None:
# Get database session
bind = op.get_bind()
session = Session(bind=bind)
external_store = get_s3_file_store(db_session=session)
external_store = get_s3_file_store()
# Find all files currently stored in PostgreSQL (lobj_oid is not null)
result = session.execute(
@@ -236,6 +236,9 @@ def _migrate_files_to_external_storage() -> None:
print("No files found in PostgreSQL storage to migrate.")
return
# might need to move this above the if statement when creating a new multi-tenant
# system. VERY extreme edge case.
external_store.initialize()
print(f"Found {total_files} files to migrate from PostgreSQL to external storage.")
_set_tenant_contextvar(session)
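
Across these migration hunks, the file store is now obtained without a db_session and must be initialized before first use. A minimal sketch of the calling pattern, assembled only from calls visible in this diff (import paths and full argument lists are not shown here, so treat this as a sketch rather than an API reference):

import io

# Sketch: get_s3_file_store and FileOrigin are the names used in the hunks of this
# diff; their import paths are not part of the diff and are intentionally omitted.
external_store = get_s3_file_store()        # no db_session argument anymore
external_store.initialize()                 # required before migrating files
external_store.save_file(
    content=io.BytesIO(b"example bytes"),   # illustrative payload
    display_name="example.csv",             # illustrative name
    file_origin=FileOrigin.QUERY_HISTORY_CSV,
)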

View File

@@ -18,11 +18,13 @@ depends_on: None = None
def upgrade() -> None:
op.execute("DROP TABLE IF EXISTS document CASCADE")
op.create_table(
"document",
sa.Column("id", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.execute("DROP TABLE IF EXISTS chunk CASCADE")
op.create_table(
"chunk",
sa.Column("id", sa.String(), nullable=False),
@@ -43,6 +45,7 @@ def upgrade() -> None:
),
sa.PrimaryKeyConstraint("id", "document_store_type"),
)
op.execute("DROP TABLE IF EXISTS deletion_attempt CASCADE")
op.create_table(
"deletion_attempt",
sa.Column("id", sa.Integer(), nullable=False),
@@ -84,6 +87,7 @@ def upgrade() -> None:
),
sa.PrimaryKeyConstraint("id"),
)
op.execute("DROP TABLE IF EXISTS document_by_connector_credential_pair CASCADE")
op.create_table(
"document_by_connector_credential_pair",
sa.Column("id", sa.String(), nullable=False),
@@ -106,7 +110,10 @@ def upgrade() -> None:
def downgrade() -> None:
# upstream tables first
op.drop_table("document_by_connector_credential_pair")
op.drop_table("deletion_attempt")
op.drop_table("chunk")
op.drop_table("document")
# Alembic op.drop_table() has no "cascade" flag, so issue raw SQL instead
op.execute("DROP TABLE IF EXISTS document CASCADE")

View File

@@ -91,7 +91,7 @@ def export_query_history_task(
with get_session_with_current_tenant() as db_session:
try:
stream.seek(0)
get_default_file_store(db_session).save_file(
get_default_file_store().save_file(
content=stream,
display_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,

View File

@@ -422,7 +422,7 @@ def connector_permission_sync_generator_task(
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
+ f"_{redis_connector.cc_pair_id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
thread_local=False,
)

View File

@@ -383,7 +383,7 @@ def connector_external_group_sync_generator_task(
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
+ f"_{redis_connector.cc_pair_id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)

View File

@@ -114,7 +114,6 @@ def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:
def get_usage_report_data(
db_session: Session,
report_display_name: str,
) -> IO:
"""
@@ -128,7 +127,7 @@ def get_usage_report_data(
Returns:
The usage report data.
"""
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
# usage report may be very large, so don't load it all into memory
return file_store.read_file(
file_id=report_display_name, mode="b", use_tempfile=True

View File

@@ -128,11 +128,14 @@ def validate_object_creation_for_user(
target_group_ids: list[int] | None = None,
object_is_public: bool | None = None,
object_is_perm_sync: bool | None = None,
object_is_owned_by_user: bool = False,
object_is_new: bool = False,
) -> None:
"""
All users can create/edit permission synced objects if they don't specify a group
All admin actions are allowed.
Prevents non-admins from creating/editing:
Curators and global curators can create public objects.
Prevents other non-admins from creating/editing:
- public objects
- objects with no groups
- objects that belong to a group they don't curate
@@ -143,13 +146,23 @@ def validate_object_creation_for_user(
if not user or user.role == UserRole.ADMIN:
return
if object_is_public:
detail = "User does not have permission to create public credentials"
# Allow curators and global curators to create public objects
# w/o associated groups IF the object is new/owned by them
if (
object_is_public
and user.role in [UserRole.CURATOR, UserRole.GLOBAL_CURATOR]
and (object_is_new or object_is_owned_by_user)
):
return
if object_is_public and user.role == UserRole.BASIC:
detail = "User does not have permission to create public objects"
logger.error(detail)
raise HTTPException(
status_code=400,
detail=detail,
)
if not target_group_ids:
detail = "Curators must specify 1+ groups"
logger.error(detail)
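
The added branch can be read as a small predicate: curators and global curators may create public objects only when the object is new or already owned by them. A self-contained restatement, for illustration only (plain strings stand in for the UserRole enum, and this is not the repo's actual helper):

def curator_may_create_public_object(
    role: str, object_is_new: bool, object_is_owned_by_user: bool
) -> bool:
    # Mirrors the early-return condition above; basic users still fall through
    # to the HTTP 400 path in the real function.
    return role in ("curator", "global_curator") and (
        object_is_new or object_is_owned_by_user
    )

assert curator_may_create_public_object("curator", True, False)
assert curator_may_create_public_object("global_curator", False, True)
assert not curator_may_create_public_object("basic", True, True)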

View File

@@ -40,8 +40,28 @@ def _get_slim_doc_generator(
)
def _merge_permissions_lists(
permission_lists: list[list[GoogleDrivePermission]],
) -> list[GoogleDrivePermission]:
"""
Merge a list of permission lists into a single list of permissions.
"""
seen_permission_ids: set[str] = set()
merged_permissions: list[GoogleDrivePermission] = []
for permission_list in permission_lists:
for permission in permission_list:
if permission.id not in seen_permission_ids:
merged_permissions.append(permission)
seen_permission_ids.add(permission.id)
return merged_permissions
def get_external_access_for_raw_gdrive_file(
file: GoogleDriveFileType, company_domain: str, drive_service: GoogleDriveService
file: GoogleDriveFileType,
company_domain: str,
retriever_drive_service: GoogleDriveService | None,
admin_drive_service: GoogleDriveService,
) -> ExternalAccess:
"""
Get the external access for a raw Google Drive file.
@@ -62,11 +82,28 @@ def get_external_access_for_raw_gdrive_file(
GoogleDrivePermission.from_drive_permission(p) for p in permissions
]
elif permission_ids:
permissions_list = get_permissions_by_ids(
drive_service=drive_service,
doc_id=doc_id,
permission_ids=permission_ids,
def _get_permissions(
drive_service: GoogleDriveService,
) -> list[GoogleDrivePermission]:
return get_permissions_by_ids(
drive_service=drive_service,
doc_id=doc_id,
permission_ids=permission_ids,
)
permissions_list = _get_permissions(
retriever_drive_service or admin_drive_service
)
if len(permissions_list) != len(permission_ids) and retriever_drive_service:
logger.warning(
f"Failed to get all permissions for file {doc_id} with retriever service, "
"trying admin service"
)
backup_permissions_list = _get_permissions(admin_drive_service)
permissions_list = _merge_permissions_lists(
[permissions_list, backup_permissions_list]
)
folder_ids_to_inherit_permissions_from: set[str] = set()
user_emails: set[str] = set()

View File

@@ -44,11 +44,17 @@ def _get_all_folders(
TODO: tweak things so we can fetch deltas.
"""
MAX_FAILED_PERCENTAGE = 0.5
all_folders: list[FolderInfo] = []
seen_folder_ids: set[str] = set()
user_emails = google_drive_connector._get_all_user_emails()
for user_email in user_emails:
def _get_all_folders_for_user(
google_drive_connector: GoogleDriveConnector,
skip_folders_without_permissions: bool,
user_email: str,
) -> None:
"""Helper to get folders for a specific user + update shared seen_folder_ids"""
drive_service = get_drive_service(
google_drive_connector.creds,
user_email,
@@ -98,6 +104,20 @@ def _get_all_folders(
)
)
failed_count = 0
user_emails = google_drive_connector._get_all_user_emails()
for user_email in user_emails:
try:
_get_all_folders_for_user(
google_drive_connector, skip_folders_without_permissions, user_email
)
except Exception:
logger.exception(f"Error getting folders for user {user_email}")
failed_count += 1
if failed_count > MAX_FAILED_PERCENTAGE * len(user_emails):
raise RuntimeError("Too many failed folder fetches during group sync")
return all_folders
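
The MAX_FAILED_PERCENTAGE guard tolerates per-user failures for up to half of the user list before aborting; a quick worked check of the condition used above:

# Worked example of the failure threshold: with 10 users, the sixth failure aborts.
MAX_FAILED_PERCENTAGE = 0.5
user_emails = [f"user{i}@example.com" for i in range(10)]  # illustrative addresses
failed_count = 6
assert failed_count > MAX_FAILED_PERCENTAGE * len(user_emails)  # 6 > 5.0 -> raise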

View File

@@ -134,15 +134,14 @@ def ee_fetch_settings() -> EnterpriseSettings:
def put_logo(
file: UploadFile,
is_logotype: bool = False,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> None:
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
upload_logo(file=file, is_logotype=is_logotype)
def fetch_logo_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
@@ -158,7 +157,7 @@ def fetch_logo_helper(db_session: Session) -> Response:
def fetch_logotype_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")

View File

@@ -6,7 +6,6 @@ from typing import IO
from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
@@ -99,9 +98,7 @@ def guess_file_type(filename: str) -> str:
return "application/octet-stream"
def upload_logo(
db_session: Session, file: UploadFile | str, is_logotype: bool = False
) -> bool:
def upload_logo(file: UploadFile | str, is_logotype: bool = False) -> bool:
content: IO[Any]
if isinstance(file, str):
@@ -129,7 +126,7 @@ def upload_logo(
display_name = file.filename
file_type = file.content_type or "image/jpeg"
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_store.save_file(
content=content,
display_name=display_name,

View File

@@ -1,5 +1,6 @@
import re
from typing import cast
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
@@ -73,6 +74,7 @@ def _get_final_context_doc_indices(
def _convert_packet_stream_to_response(
packets: ChatPacketStream,
chat_session_id: UUID,
) -> ChatBasicResponse:
response = ChatBasicResponse()
final_context_docs: list[LlmDoc] = []
@@ -216,6 +218,8 @@ def _convert_packet_stream_to_response(
if answer:
response.answer_citationless = remove_answer_citations(answer)
response.chat_session_id = chat_session_id
return response
@@ -237,13 +241,36 @@ def handle_simplified_chat_message(
if not chat_message_req.message:
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
# Handle chat session creation if chat_session_id is not provided
if chat_message_req.chat_session_id is None:
if chat_message_req.persona_id is None:
raise HTTPException(
status_code=400,
detail="Either chat_session_id or persona_id must be provided",
)
# Create a new chat session with the provided persona_id
try:
new_chat_session = create_chat_session(
db_session=db_session,
description="", # Leave empty for simple API
user_id=user.id if user else None,
persona_id=chat_message_req.persona_id,
)
chat_session_id = new_chat_session.id
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=400, detail="Invalid Persona provided.")
else:
chat_session_id = chat_message_req.chat_session_id
try:
parent_message, _ = create_chat_chain(
chat_session_id=chat_message_req.chat_session_id, db_session=db_session
chat_session_id=chat_session_id, db_session=db_session
)
except Exception:
parent_message = get_or_create_root_message(
chat_session_id=chat_message_req.chat_session_id, db_session=db_session
chat_session_id=chat_session_id, db_session=db_session
)
if (
@@ -258,7 +285,7 @@ def handle_simplified_chat_message(
retrieval_options = chat_message_req.retrieval_options
full_chat_msg_info = CreateChatMessageRequest(
chat_session_id=chat_message_req.chat_session_id,
chat_session_id=chat_session_id,
parent_message_id=parent_message.id,
message=chat_message_req.message,
file_descriptors=[],
@@ -283,7 +310,7 @@ def handle_simplified_chat_message(
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets)
return _convert_packet_stream_to_response(packets, chat_session_id)
@router.post("/send-message-simple-with-history")
@@ -403,4 +430,4 @@ def handle_send_message_simple_with_history(
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets)
return _convert_packet_stream_to_response(packets, chat_session.id)

View File

@@ -41,11 +41,13 @@ class DocumentSearchRequest(ChunkContext):
class BasicCreateChatMessageRequest(ChunkContext):
"""Before creating messages, be sure to create a chat_session and get an id
"""If a chat_session_id is not provided, a persona_id must be provided to automatically create a new chat session
Note, for simplicity this option only allows for a single linear chain of messages
"""
chat_session_id: UUID
chat_session_id: UUID | None = None
# Optional persona_id to create a new chat session if chat_session_id is not provided
persona_id: int | None = None
# New message contents
message: str
# Defaults to using retrieval with no additional filters
@@ -62,6 +64,12 @@ class BasicCreateChatMessageRequest(ChunkContext):
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
if self.chat_session_id is None and self.persona_id is None:
raise ValueError("Either chat_session_id or persona_id must be provided")
return self
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# Last element is the new query. All previous elements are historical context
@@ -171,6 +179,9 @@ class ChatBasicResponse(BaseModel):
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
# Chat session ID for tracking conversation continuity
chat_session_id: UUID | None = None
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits

View File

@@ -358,7 +358,7 @@ def get_query_history_export_status(
# If task is None, then it's possible that the task has already finished processing.
# Therefore, we should then check if the export file has already been stored inside of the file-store.
# If that *also* doesn't exist, then we can return a 404.
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
report_name = construct_query_history_report_name(request_id)
has_file = file_store.has_file(
@@ -385,7 +385,7 @@ def download_query_history_csv(
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
report_name = construct_query_history_report_name(request_id)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
has_file = file_store.has_file(
file_id=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,

View File

@@ -53,7 +53,7 @@ def read_usage_report(
db_session: Session = Depends(get_session),
) -> Response:
try:
file = get_usage_report_data(db_session, report_name)
file = get_usage_report_data(report_name)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -112,7 +112,7 @@ def create_new_usage_report(
period: tuple[datetime, datetime] | None,
) -> UsageReportMetadata:
report_id = str(uuid.uuid4())
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
messages_file_id = generate_chat_messages_report(
db_session, file_store, report_id, period

View File

@@ -200,10 +200,10 @@ def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
store_ee_settings(final_enterprise_settings)
def _seed_logo(db_session: Session, logo_path: str | None) -> None:
def _seed_logo(logo_path: str | None) -> None:
if logo_path:
logger.notice("Uploading logo")
upload_logo(db_session=db_session, file=logo_path)
upload_logo(file=logo_path)
def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
@@ -245,7 +245,7 @@ def seed_db() -> None:
if seed_config.custom_tools is not None:
_seed_custom_tools(db_session, seed_config.custom_tools)
_seed_logo(db_session, seed_config.seeded_logo_path)
_seed_logo(seed_config.seeded_logo_path)
_seed_enterprise_settings(seed_config)
_seed_analytics_script(seed_config)

View File

@@ -10,10 +10,12 @@ from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
@@ -47,6 +49,26 @@ def gate_product(
return ProductGatingResponse(updated=False, error=str(e))
@router.post("/product-gating/full-sync")
def gate_product_full_sync(
product_gating_request: ProductGatingFullSyncRequest,
_: None = Depends(control_plane_dep),
) -> ProductGatingResponse:
"""
Bulk operation to overwrite the entire gated tenant set.
This replaces all currently gated tenants with the provided list.
Gated tenants are not available to access the product and will be
directed to the billing page when their subscription has ended.
"""
try:
overwrite_full_gated_set(product_gating_request.gated_tenant_ids)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate products during full sync")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),

View File

@@ -19,6 +19,10 @@ class ProductGatingRequest(BaseModel):
application_status: ApplicationStatus
class ProductGatingFullSyncRequest(BaseModel):
gated_tenant_ids: list[str]
class SubscriptionStatusResponse(BaseModel):
subscribed: bool

View File

@@ -16,10 +16,6 @@ logger = setup_logger()
def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
# Store the full status
status_key = f"tenant:{tenant_id}:status"
redis_client.set(status_key, status.value)
# Maintain the GATED_ACCESS set
if status == ApplicationStatus.GATED_ACCESS:
redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
@@ -46,6 +42,25 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
raise
def overwrite_full_gated_set(tenant_ids: list[str]) -> None:
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
pipeline = redis_client.pipeline()
# using pipeline doesn't automatically add the tenant_id prefix
full_gated_set_key = f"{ONYX_CLOUD_TENANT_ID}:{GATED_TENANTS_KEY}"
# Clear the existing set
pipeline.delete(full_gated_set_key)
# Add all tenant IDs to the set and set their status
for tenant_id in tenant_ids:
pipeline.sadd(full_gated_set_key, tenant_id)
# Execute all commands at once
pipeline.execute()
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
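
For context, the overwrite pattern above boils down to clearing the set and re-adding every member inside one Redis pipeline. A minimal, self-contained sketch using the redis-py client (the key name, tenant ids, and connection details are illustrative assumptions):

import redis

r = redis.Redis(host="localhost", port=6379)      # assumed local Redis
full_gated_set_key = "cloud:gated_tenants"        # illustrative key, not the real constant
gated_tenant_ids = ["tenant_a", "tenant_b"]       # illustrative ids

pipeline = r.pipeline()
pipeline.delete(full_gated_set_key)               # drop the previous membership entirely
for tenant_id in gated_tenant_ids:
    pipeline.sadd(full_gated_set_key, tenant_id)  # queue one SADD per tenant
pipeline.execute()                                # send all queued commands at once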

View File

@@ -203,6 +203,8 @@ def generate_simple_sql(
if state.kg_entity_temp_view_name is None:
raise ValueError("kg_entity_temp_view_name is not set")
sql_statement_display: str | None = None
## STEP 3 - articulate goals
stream_write_step_activities(writer, _KG_STEP_NR)
@@ -381,7 +383,18 @@ def generate_simple_sql(
raise e
logger.debug(f"A3 - sql_statement after correction: {sql_statement}")
# display sql statement with view names replaced by general view names
sql_statement_display = sql_statement.replace(
state.kg_doc_temp_view_name, "<your_allowed_docs_view_name>"
)
sql_statement_display = sql_statement_display.replace(
state.kg_rel_temp_view_name, "<your_relationship_view_name>"
)
sql_statement_display = sql_statement_display.replace(
state.kg_entity_temp_view_name, "<your_entity_view_name>"
)
logger.debug(f"A3 - sql_statement after correction: {sql_statement_display}")
# Get SQL for source documents
@@ -409,7 +422,20 @@ def generate_simple_sql(
"relationship_table", rel_temp_view
)
logger.debug(f"A3 source_documents_sql: {source_documents_sql}")
if source_documents_sql:
source_documents_sql_display = source_documents_sql.replace(
state.kg_doc_temp_view_name, "<your_allowed_docs_view_name>"
)
source_documents_sql_display = source_documents_sql_display.replace(
state.kg_rel_temp_view_name, "<your_relationship_view_name>"
)
source_documents_sql_display = source_documents_sql_display.replace(
state.kg_entity_temp_view_name, "<your_entity_view_name>"
)
else:
source_documents_sql_display = "(No source documents SQL generated)"
logger.debug(f"A3 source_documents_sql: {source_documents_sql_display}")
scalar_result = None
query_results = None
@@ -435,7 +461,13 @@ def generate_simple_sql(
rows = result.fetchall()
query_results = [dict(row._mapping) for row in rows]
except Exception as e:
# TODO: raise error on frontend
logger.error(f"Error executing SQL query: {e}")
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
raise e
@@ -459,8 +491,14 @@ def generate_simple_sql(
for source_document_result in query_source_document_results
]
except Exception as e:
# No stopping here, the individualized SQL query is not mandatory
# TODO: raise error on frontend
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
logger.error(f"Error executing Individualized SQL query: {e}")
else:
@@ -493,11 +531,11 @@ def generate_simple_sql(
if reasoning:
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning)
if main_sql_statement:
if sql_statement_display:
stream_write_step_answer_explicit(
writer,
step_nr=_KG_STEP_NR,
answer=f" \n Generated SQL: {main_sql_statement}",
answer=f" \n Generated SQL: {sql_statement_display}",
)
stream_close_step_answer(writer, _KG_STEP_NR)
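
The display substitution above is plain string replacement of the temporary view names with stable placeholders before the SQL is streamed to the user; a tiny illustration (the view names below are made up):

# Illustration only: the temporary view names are invented for this example.
sql_statement = "SELECT * FROM kg_rel_view_ab12 JOIN kg_entity_view_cd34 USING (id)"
sql_statement_display = sql_statement.replace(
    "kg_rel_view_ab12", "<your_relationship_view_name>"
).replace("kg_entity_view_cd34", "<your_entity_view_name>")
print(sql_statement_display)
# SELECT * FROM <your_relationship_view_name> JOIN <your_entity_view_name> USING (id)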

View File

@@ -51,7 +51,6 @@ def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
else:
continue
history_segments.append(f"{role}:\n {msg.content}\n\n")
return "\n".join(history_segments)
@@ -128,6 +127,7 @@ def choose_tool(
override_kwargs: SearchToolOverrideKwargs = (
force_use_tool.override_kwargs or SearchToolOverrideKwargs()
)
override_kwargs.original_query = agent_config.inputs.prompt_builder.raw_user_query
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder

View File

@@ -174,7 +174,6 @@ def get_test_config(
# The docs retrieved by this flow are already relevance-filtered
all_docs_useful=True
),
document_pruning_config=document_pruning_config,
structured_response_format=None,
)
@@ -198,7 +197,7 @@ def get_test_config(
prompt_config=prompt_config,
llm=primary_llm,
fast_llm=fast_llm,
pruning_config=search_tool_config.document_pruning_config,
document_pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
selected_sections=search_tool_config.selected_sections,
chunks_above=search_tool_config.chunks_above,

View File

@@ -24,13 +24,14 @@ from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatt
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -39,6 +40,7 @@ from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import PlainFormatter
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -92,7 +94,13 @@ def on_task_prerun(
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
) -> None:
pass
# Reset any per-task logging context so that prefixes (e.g. pruning_ctx)
# from a previous task executed in the same worker process do not leak
# into the next task's log messages. This fixes incorrect [CC Pair:/Index Attempt]
# prefixes observed when a pruning task finishes and an indexing task
# runs in the same process.
LoggerContextVars.reset()
def on_task_postrun(
@@ -145,8 +153,11 @@ def on_task_postrun(
r = get_redis_client(tenant_id=tenant_id)
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
# NOTE: we want to remove the `Redis*` classes, prefer to just have functions to
# do these things going forward. In short, things should generally be like the doc
# sync task rather than the others below
if task_id.startswith(DOCUMENT_SYNC_PREFIX):
r.srem(DOCUMENT_SYNC_TASKSET_KEY, task_id)
return
if task_id.startswith(RedisDocumentSet.PREFIX):
@@ -470,7 +481,8 @@ class TenantContextFilter(logging.Filter):
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id:
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:5]
# Match the 8 character tenant abbreviation used in OnyxLoggingAdapter
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:8]
record.name = f"[t:{tenant_id}]"
else:
record.name = ""
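
The prerun reset exists because context variables set by one task would otherwise survive into the next task run in the same worker process. A minimal, self-contained illustration of that reset (this is not the repo's LoggerContextVars implementation):

import contextvars

# Stand-in for a single logging context variable; the real class manages several.
log_prefix: contextvars.ContextVar[str] = contextvars.ContextVar("log_prefix", default="")

def reset_logging_context() -> None:
    log_prefix.set("")  # what a reset amounts to for one variable

log_prefix.set("[CC Pair: 42/Index Attempt: 7]")  # left over from a finished pruning task
reset_logging_context()                            # the task_prerun hook clears it
assert log_prefix.get() == ""                      # the next task starts with a clean prefix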

View File

@@ -0,0 +1,102 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.docfetching")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.docfetching",
]
)

View File

@@ -12,7 +12,7 @@ from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -21,7 +21,7 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.indexing")
celery_app.config_from_object("onyx.background.celery.configs.docprocessing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@@ -60,7 +60,7 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME)
# rkuo: Transient errors keep happening in the indexing watchdog threads.
# "SSL connection has been closed unexpectedly"
@@ -108,6 +108,6 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.docprocessing",
]
)

View File

@@ -116,6 +116,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.docprocessing",
]
)

View File

@@ -9,6 +9,7 @@ from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.exceptions import WorkerShutdown
from celery.result import AsyncResult
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@@ -18,9 +19,7 @@ from redis.lock import Lock as RedisLock
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.tasks.indexing.utils import (
get_unfenced_index_attempt_ids,
)
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
@@ -29,9 +28,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.redis.redis_connector_credential_pair import (
RedisGlobalConnectorCredentialPair,
)
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -156,7 +153,10 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
r.delete(OnyxRedisConstants.ACTIVE_FENCES)
RedisGlobalConnectorCredentialPair.reset_all(r)
# NOTE: we want to remove the `Redis*` classes, prefer to just have functions
# This is the preferred way to do this going forward
reset_document_sync(r)
RedisDocumentSet.reset_all(r)
RedisUserGroup.reset_all(r)
RedisConnectorDelete.reset_all(r)
@@ -167,24 +167,50 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
RedisConnectorExternalGroupSync.reset_all(r)
# mark orphaned index attempts as failed
# This uses database coordination instead of Redis fencing
with get_session_with_current_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
# Get potentially orphaned attempts (those with active status and task IDs)
potentially_orphaned_ids = IndexingCoordination.get_orphaned_index_attempt_ids(
db_session
)
for attempt_id in potentially_orphaned_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
# handle case where not started or docfetching is done but indexing is not
if (
not attempt
or not attempt.celery_task_id
or attempt.total_batches is not None
):
continue
failure_reason = (
f"Canceling leftover index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
logger.exception(
f"Marking attempt {attempt.id} as canceled due to validation error 2"
)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
# Check if the Celery task actually exists
try:
result: AsyncResult = AsyncResult(attempt.celery_task_id)
# If the task is not in PENDING state, it exists in Celery
if result.state != "PENDING":
continue
# Task is orphaned - mark as failed
failure_reason = (
f"Orphaned index attempt found on startup - Celery task not found: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id} "
f"celery_task_id={attempt.celery_task_id}"
)
logger.warning(failure_reason)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
except Exception:
# If we can't check the task status, be conservative and continue
logger.warning(
f"Could not verify Celery task status on startup for attempt {attempt.id}, "
f"task_id={attempt.celery_task_id}"
)
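
The orphan check relies on a Celery behavior: asking for the state of a task id the result backend has never seen reports "PENDING". A small sketch of that probe (the broker/backend URLs and task id are assumptions):

from celery import Celery
from celery.result import AsyncResult

# Assumes a reachable Redis serving as broker and result backend.
app = Celery("probe", broker="redis://localhost:6379/0", backend="redis://localhost:6379/1")

result = AsyncResult("00000000-0000-0000-0000-000000000000", app=app)  # made-up task id
print(result.state)  # "PENDING" for unknown ids, which is what flags an orphaned attempt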
@worker_ready.connect
@@ -291,7 +317,7 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",

View File

@@ -26,7 +26,7 @@ def celery_get_unacked_length(r: Redis) -> int:
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
"""Gets the set of task id's matching the given queue in the unacked hash.
Unacked entries belonging to the indexing queue are "prefetched", so this gives
Unacked entries belonging to the indexing queues are "prefetched", so this gives
us crucial visibility as to what tasks are in that state.
"""
tasks: set[str] = set()

View File

@@ -0,0 +1,22 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_DOCFETCHING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Docfetching worker configuration
worker_concurrency = CELERY_WORKER_DOCFETCHING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -1,5 +1,5 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_INDEXING_CONCURRENCY
from onyx.configs.app_configs import CELERY_WORKER_DOCPROCESSING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
@@ -24,6 +24,6 @@ task_acks_late = shared_config.task_acks_late
# which means a duplicate run might change the task state unexpectedly
# task_track_started = True
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
worker_concurrency = CELERY_WORKER_DOCPROCESSING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -100,24 +100,6 @@ beat_task_templates: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,

View File

@@ -40,9 +40,11 @@ from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
@@ -69,13 +71,21 @@ def revoke_tasks_blocking_deletion(
) -> None:
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
try:
index_payload = redis_connector_index.payload
if index_payload and index_payload.celery_task_id:
app.control.revoke(index_payload.celery_task_id)
recent_index_attempts = get_recent_attempts_for_cc_pair(
cc_pair_id=redis_connector.cc_pair_id,
search_settings_id=search_settings.id,
limit=1,
db_session=db_session,
)
if (
recent_index_attempts
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
and recent_index_attempts[0].celery_task_id
):
app.control.revoke(recent_index_attempts[0].celery_task_id)
task_logger.info(
f"Revoked indexing task {index_payload.celery_task_id}."
f"Revoked indexing task {recent_index_attempts[0].celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking indexing task")
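
Revocation itself is a single Celery control call keyed by the stored celery_task_id; a brief sketch (the app name, broker URL, and task id are illustrative):

from celery import Celery

# Assumes a reachable Redis broker; the task id would come from the recent index attempt.
app = Celery("onyx_sketch", broker="redis://localhost:6379/0")
app.control.revoke("11111111-2222-3333-4444-555555555555")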
@@ -281,8 +291,16 @@ def try_generate_document_cc_pair_cleanup_tasks(
# do not proceed if connector indexing or connector pruning are running
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
recent_index_attempts = get_recent_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=search_settings.id,
limit=1,
db_session=db_session,
)
if (
recent_index_attempts
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
):
raise TaskDependencyError(
"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "

View File

@@ -0,0 +1,637 @@
import multiprocessing
import os
import time
import traceback
from http import HTTPStatus
from time import sleep
import sentry_sdk
from celery import Celery
from celery import shared_task
from celery import Task
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat
from onyx.background.celery.tasks.docprocessing.tasks import ConnectorIndexingLogBuilder
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback
from onyx.background.celery.tasks.models import DocProcessingContext
from onyx.background.celery.tasks.models import IndexingWatchdogTerminalStatus
from onyx.background.celery.tasks.models import SimpleJobResult
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.job_client import SimpleJobException
from onyx.background.indexing.run_docfetching import run_indexing_entrypoint
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
def _verify_indexing_attempt(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
) -> None:
"""
Verify that the indexing attempt exists and is in the correct state.
"""
with get_session_with_current_tenant() as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise SimpleJobException(
f"docfetching_task - IndexAttempt not found: attempt_id={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
)
if attempt.connector_credential_pair_id != cc_pair_id:
raise SimpleJobException(
f"docfetching_task - CC pair mismatch: "
f"expected={cc_pair_id} actual={attempt.connector_credential_pair_id}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
if attempt.search_settings_id != search_settings_id:
raise SimpleJobException(
f"docfetching_task - Search settings mismatch: "
f"expected={search_settings_id} actual={attempt.search_settings_id}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
if attempt.status not in [
IndexingStatus.NOT_STARTED,
IndexingStatus.IN_PROGRESS,
]:
raise SimpleJobException(
f"docfetching_task - Invalid attempt status: "
f"attempt_id={index_attempt_id} status={attempt.status}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
# Check for cancellation
if IndexingCoordination.check_cancellation_requested(
db_session, index_attempt_id
):
raise SimpleJobException(
f"docfetching_task - Cancellation requested: attempt_id={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
)
logger.info(
f"docfetching_task - IndexAttempt verified: "
f"attempt_id={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
def docfetching_task(
app: Celery,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
is_ee: bool,
tenant_id: str,
) -> None:
"""
This function is run in a SimpleJob as a new process. It is responsible for validating
some stuff, but basically it just calls run_indexing_entrypoint.
NOTE: if an exception is raised out of this task, the primary worker will detect
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
This will cause the primary worker to abort the indexing attempt and clean up.
"""
# Start heartbeat for this indexing attempt
heartbeat_thread, stop_event = start_heartbeat(index_attempt_id)
try:
_docfetching_task(
app, index_attempt_id, cc_pair_id, search_settings_id, is_ee, tenant_id
)
finally:
stop_heartbeat(heartbeat_thread, stop_event) # Stop heartbeat before exiting
def _docfetching_task(
app: Celery,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
is_ee: bool,
tenant_id: str,
) -> None:
# Since connector_indexing_proxy_task spawns a new process using this function as
# the entrypoint, we init Sentry here.
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
logger.info(
f"Indexing spawned task starting: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector.new_index(search_settings_id)
# TODO: remove all fences, cause all signals to be set in postgres
if redis_connector.delete.fenced:
raise SimpleJobException(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
)
if redis_connector.stop.fenced:
raise SimpleJobException(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
)
# Verify the indexing attempt exists and is valid
# This replaces the Redis fence payload waiting
_verify_indexing_attempt(index_attempt_id, cc_pair_id, search_settings_id)
try:
with get_session_with_current_tenant() as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise SimpleJobException(
f"Index attempt not found: index_attempt={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
raise SimpleJobException(
f"cc_pair not found: cc_pair={cc_pair_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
# define a callback class
callback = IndexingCallback(
redis_connector,
)
logger.info(
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# This is where the heavy/real work happens
run_indexing_entrypoint(
app,
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
)
except ConnectorValidationError:
raise SimpleJobException(
f"Indexing task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}",
code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
)
except Exception as e:
logger.exception(
f"Indexing spawned task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# special bulletproofing ... truncate long exception messages
# for exception types that require more args, this will fail
# thus the try/except
try:
sanitized_e = type(e)(str(e)[:1024])
sanitized_e.__traceback__ = e.__traceback__
raise sanitized_e
except Exception:
raise e
logger.info(
f"Indexing spawned task finished: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
os._exit(0) # ensure process exits cleanly
def process_job_result(
job: SimpleJob,
connector_source: str | None,
redis_connector_index: RedisConnectorIndex,
log_builder: ConnectorIndexingLogBuilder,
) -> SimpleJobResult:
result = SimpleJobResult()
result.connector_source = connector_source
if job.process:
result.exit_code = job.process.exitcode
if job.status != "error":
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
return result
ignore_exitcode = False
# In EKS, there is an edge case where successful tasks return exit
# code 1 in the cloud due to the set_spawn_method not sticking.
# We've since worked around this, but the following is a safe way to
# work around this issue. Basically, we ignore the job error state
# if the completion signal is OK.
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if ignore_exitcode:
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
task_logger.warning(
log_builder.build(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...",
exit_code=str(result.exit_code),
)
)
else:
if result.exit_code is not None:
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
result.exception_str = job.exception()
return result
@shared_task(
name=OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
bind=True,
acks_late=False,
track_started=True,
)
def docfetching_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str,
) -> None:
"""
This task is the entrypoint for the full indexing pipeline, which is composed of two tasks:
docfetching and docprocessing.
This task is spawned by "try_creating_indexing_task" which is called in the "check_for_indexing" task.
This task spawns a new process for a new scheduled index attempt. That
new process (which runs the docfetching_task function) does the following:
1) determines parameters of the indexing attempt (which connector indexing function to run,
start and end time, from prev checkpoint or not), then run that connector. Specifically,
connectors are responsible for reading data from an outside source and converting it to Onyx documents.
At the moment these two steps (reading external data and converting to an Onyx document)
are not parallelized in most connectors; that's a subject for future work.
Each document batch produced by step 1 is stored in the file store, and a docprocessing task is spawned
to process it. docprocessing involves the steps listed below.
2) upserts documents to postgres (index_doc_batch_prepare)
3) chunks each document (optionally adds context for contextual rag)
4) embeds chunks (embed_chunks_with_failure_handling) via a call to the model server
5) write chunks to vespa (write_chunks_to_vector_db_with_backoff)
6) update document and indexing metadata in postgres
7) pulls all document IDs from the source and compares those IDs to locally stored documents and deletes
all locally stored IDs missing from the most recently pulled document ID list
Some important notes:
Invariants:
- docfetching proxy tasks are spawned by check_for_indexing. The proxy then runs the docfetching_task wrapped in a watchdog.
The watchdog is responsible for monitoring the docfetching_task and marking the index attempt as failed
if it is not making progress.
- All docprocessing tasks are spawned by a docfetching task.
- all docfetching tasks, docprocessing tasks, and document batches in the file store are
associated with a specific index attempt.
- the index attempt status is the source of truth for what is currently happening with the index attempt.
It is coupled with the creation/running of docfetching and docprocessing tasks as much as possible.
How we deal with failures/ partial indexing:
- non-checkpointed connectors/ new runs in general => delete the old document batches from the file store and do the new run
- checkpointed connectors + resuming from checkpoint => reissue the old document batches and do a new run
Misc:
- most inter-process communication is handled in postgres, some is still in redis and we're trying to remove it
- Heartbeat spawned in docfetching and docprocessing is how check_for_indexing monitors liveliness
- progress based liveliness check: if nothing is done in 3-6 hours, mark the attempt as failed
- TODO: task level timeouts (i.e. a connector stuck in an infinite loop)
Comments below are from the old version and some may no longer be valid.
TODO(rkuo): refactor this so that there is a single return path where we canonically
log the result of running this function.
Some more Richard notes:
celery out of process task execution strategy is pool=prefork, but it uses fork,
and forking is inherently unstable.
To work around this, we use pool=threads and proxy our work to a spawned task.
acks_late must be set to False. Otherwise, celery's visibility timeout will
cause any task that runs longer than the timeout to be redispatched by the broker.
There appears to be no good workaround for this, so we need to handle redispatching
manually.
NOTE: we try/except all db access in this function because as a watchdog, this function
needs to be extremely stable.
"""
# TODO: remove dependence on Redis
start = time.monotonic()
result = SimpleJobResult()
ctx = DocProcessingContext(
tenant_id=tenant_id,
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
index_attempt_id=index_attempt_id,
)
log_builder = ConnectorIndexingLogBuilder(ctx)
task_logger.info(
log_builder.build(
"Indexing watchdog - starting",
mp_start_method=str(multiprocessing.get_start_method()),
)
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
task_logger.info(f"submitting docfetching_task with tenant_id={tenant_id}")
job = client.submit(
docfetching_task,
self.app,
index_attempt_id,
cc_pair_id,
search_settings_id,
global_version.is_ee_version(),
tenant_id,
)
if not job or not job.process:
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
return
# Ensure the process has moved out of the starting state
num_waits = 0
while True:
if num_waits > 15:
result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
job.release()
return
if job.process.is_alive() or job.process.exitcode is not None:
break
sleep(1)
num_waits += 1
task_logger.info(
log_builder.build(
"Indexing watchdog - spawn succeeded",
pid=str(job.process.pid),
)
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session,
index_attempt_id=index_attempt_id,
eager_load_cc_pair=True,
)
if not index_attempt:
raise RuntimeError("Index attempt not found")
result.connector_source = (
index_attempt.connector_credential_pair.connector.source.value
)
while True:
sleep(5)
time.monotonic()
# if the job is done, clean up and break
if job.done():
try:
result = process_job_result(
job, result.connector_source, redis_connector_index, log_builder
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - spawned task exceptioned"
)
)
finally:
job.release()
break
# log the memory usage for tracking down memory leaks / connector-specific memory issues
pid = job.process.pid
if pid is not None:
# Only emit memory info once per minute (60 seconds)
current_time = time.monotonic()
if current_time - last_memory_emit_time >= 60.0:
emit_process_memory(
pid,
"indexing_worker",
{
"cc_pair_id": cc_pair_id,
"search_settings_id": search_settings_id,
"index_attempt_id": index_attempt_id,
},
)
last_memory_emit_time = current_time
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception looking up index attempt"
)
)
continue
except Exception as e:
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
if isinstance(e, ConnectorValidationError):
# No need to expose full stack trace for validation errors
result.exception_str = str(e)
else:
result.exception_str = traceback.format_exc()
# handle exit and reporting
elapsed = time.monotonic() - start
if result.exception_str is not None:
# print with exception
try:
with get_session_with_current_tenant() as db_session:
failure_reason = (
f"Spawned task exceptioned: exit_code={result.exit_code}"
)
mark_attempt_failed(
ctx.index_attempt_id,
db_session,
failure_reason=failure_reason,
full_exception_trace=result.exception_str,
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
normalized_exception_str = "None"
if result.exception_str:
normalized_exception_str = result.exception_str.replace(
"\n", "\\n"
).replace('"', '\\"')
task_logger.warning(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=result.status.value,
exit_code=str(result.exit_code),
exception=f'"{normalized_exception_str}"',
elapsed=f"{elapsed:.2f}s",
)
)
raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")
# print without exception
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
try:
with get_session_with_current_tenant() as db_session:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to termination signal"
)
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as canceled"
)
)
job.cancel()
elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
except Exception:
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
else:
pass
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=str(result.status.value),
exit_code=str(result.exit_code),
elapsed=f"{elapsed:.2f}s",
)
)
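A minimal sketch of the worker/task setup the notes at the top of this function describe; the app name, task name, and do_real_work are illustrative assumptions, not the repository's actual registration:
from multiprocessing import get_context

from celery import Celery

app = Celery("docfetching_sketch", broker="redis://localhost:6379/0")


def do_real_work(payload: dict) -> None:
    ...  # the actual connector work runs inside the spawned child process


# acks_late=False so the broker's visibility timeout cannot redispatch a long-running task.
# The worker itself would be started with --pool=threads rather than the default prefork pool.
@app.task(bind=True, acks_late=False)
def watchdog_proxy_task(self, payload: dict) -> None:
    ctx = get_context("spawn")  # spawn a fresh interpreter instead of forking
    job = ctx.Process(target=do_real_work, args=(payload,))
    job.start()
    while job.is_alive():
        job.join(timeout=5)  # poll the child; exit-code handling and redispatch checks go here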

View File

@@ -0,0 +1,36 @@
import threading
from sqlalchemy import update
from onyx.configs.constants import INDEXING_WORKER_HEARTBEAT_INTERVAL
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import IndexAttempt
def start_heartbeat(index_attempt_id: int) -> tuple[threading.Thread, threading.Event]:
"""Start a heartbeat thread for the given index attempt"""
stop_event = threading.Event()
def heartbeat_loop() -> None:
while not stop_event.wait(INDEXING_WORKER_HEARTBEAT_INTERVAL):
try:
with get_session_with_current_tenant() as db_session:
db_session.execute(
update(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.values(heartbeat_counter=IndexAttempt.heartbeat_counter + 1)
)
db_session.commit()
except Exception:
# Silently continue if heartbeat fails
pass
thread = threading.Thread(target=heartbeat_loop, daemon=True)
thread.start()
return thread, stop_event
def stop_heartbeat(thread: threading.Thread, stop_event: threading.Event) -> None:
"""Stop the heartbeat thread"""
stop_event.set()
thread.join(timeout=5) # Wait up to 5 seconds for clean shutdown
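A brief usage sketch for these helpers (do_indexing_work is a hypothetical stand-in for the monitored work):
def run_with_heartbeat(index_attempt_id: int) -> None:
    thread, stop_event = start_heartbeat(index_attempt_id)
    try:
        do_indexing_work(index_attempt_id)  # hypothetical long-running work
    finally:
        stop_heartbeat(thread, stop_event)  # always tear down the heartbeat thread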

File diff suppressed because it is too large

View File

@@ -1,10 +1,8 @@
import time
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.exceptions import LockError
@@ -12,8 +10,6 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
@@ -21,27 +17,19 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import create_index_attempt
from onyx.db.index_attempt import delete_index_attempt
from onyx.db.index_attempt import get_all_index_attempts_by_status
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import setup_logger
@@ -50,54 +38,6 @@ logger = setup_logger()
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE = 5
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
class IndexingCallbackBase(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
@@ -123,10 +63,9 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
self.last_parent_check = time.monotonic()
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
return True
return False
# Check if the associated indexing attempt has been cancelled
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
return bool(self.redis_connector.stop.fenced)
def progress(self, tag: str, amount: int) -> None:
"""Amount isn't used yet."""
@@ -171,186 +110,28 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
raise
class IndexingCallback(IndexingCallbackBase):
# NOTE: we're in the process of removing all fences from indexing; this will
# eventually no longer be used. For now, it is used only for connector pausing.
class IndexingCallback(IndexingHeartbeatInterface):
def __init__(
self,
parent_pid: int,
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
redis_connector_index: RedisConnectorIndex,
):
super().__init__(parent_pid, redis_connector, redis_lock, redis_client)
self.redis_connector = redis_connector
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
def should_stop(self) -> bool:
# Check if the associated indexing attempt has been cancelled
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
return bool(self.redis_connector.stop.fenced)
# included to satisfy old interface
def progress(self, tag: str, amount: int) -> None:
self.redis_connector_index.set_active()
self.redis_connector_index.set_connector_active()
super().progress(tag, amount)
self.redis_client.incrby(
self.redis_connector_index.generator_progress_key, amount
)
pass
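# Hedged usage sketch (illustrative): a docfetching loop would poll this callback
# between batches, roughly:
#   if callback and callback.should_stop():
#       raise ConnectorStopSignal("connector paused")
#   if callback:
#       callback.progress("docfetching", len(doc_batch))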
def validate_indexing_fence(
tenant_id: str,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# check to see if the fence/payload exists
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
if redis_connector_index.active():
return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
if found:
# the celery task exists in the redis queue
redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector_index.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
f"index_attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"fence={fence_key}"
)
if payload.index_attempt_id:
try:
mark_attempt_failed(
payload.index_attempt_id,
db_session,
"validate_indexing_fence - Canceling index attempt due to missing celery tasks: "
f"index_attempt={payload.index_attempt_id}",
)
except Exception:
logger.exception(
"validate_indexing_fence - Exception while marking index attempt as failed: "
f"index_attempt={payload.index_attempt_id}",
)
redis_connector_index.reset()
return
def validate_indexing_fences(
tenant_id: str,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
"""Validates all indexing fences for this tenant ... aka makes sure
indexing tasks sent to celery are still in flight.
"""
reserved_indexing_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
# Use replica for this because the worst thing that happens
# is that we don't run the validation on this pass
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
continue
with get_session_with_current_tenant() as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
reserved_indexing_tasks,
r_celery,
db_session,
)
lock_beat.reacquire()
return
# NOTE: The validate_indexing_fence and validate_indexing_fences functions have been removed
# as they are no longer needed with database-based coordination. The new validation is
# handled by validate_active_indexing_attempts in the main indexing tasks module.
def is_in_repeated_error_state(
@@ -414,10 +195,12 @@ def should_index(
)
# uncomment for debugging
# task_logger.info(f"_should_index: "
# f"cc_pair={cc_pair.id} "
# f"connector={cc_pair.connector_id} "
# f"refresh_freq={connector.refresh_freq}")
task_logger.info(
f"_should_index: "
f"cc_pair={cc_pair.id} "
f"connector={cc_pair.connector_id} "
f"refresh_freq={connector.refresh_freq}"
)
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
@@ -517,7 +300,7 @@ def should_index(
return True
def try_creating_indexing_task(
def try_creating_docfetching_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
@@ -531,10 +314,11 @@ def try_creating_indexing_task(
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
Now uses database-based coordination instead of Redis fencing.
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
@@ -547,61 +331,42 @@ def try_creating_indexing_task(
if not acquired:
return None
redis_connector_index: RedisConnectorIndex
index_attempt_id = None
try:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
# skip if already indexing
if redis_connector_index.fenced:
return None
# skip indexing if the cc_pair is deleting
if redis_connector.delete.fenced:
return None
# Basic status checks
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
redis_connector_index.generator_clear()
# Generate custom task ID for tracking
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
# set a basic fence to start
payload = RedisConnectorIndexPayload(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
redis_connector_index.set_active()
redis_connector_index.set_fence(payload)
# create the index attempt for tracking purposes
# code elsewhere checks for index attempts without an associated redis key
# and cleans them up
# therefore we must create the attempt and the task after the fence goes up
index_attempt_id = create_index_attempt(
cc_pair.id,
search_settings.id,
from_beginning=reindex,
# Try to create a new index attempt using database coordination
# This replaces the Redis fencing mechanism
index_attempt_id = IndexingCoordination.try_create_index_attempt(
db_session=db_session,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
celery_task_id=custom_task_id,
from_beginning=reindex,
)
custom_task_id = redis_connector_index.generate_generator_task_id()
if index_attempt_id is None:
# Another indexing attempt is already running
return None
# Determine which queue to use based on whether this is a user file
# TODO: at the moment the indexing pipeline is
# shared between user files and connectors
queue = (
OnyxCeleryQueues.USER_FILES_INDEXING
if cc_pair.is_user_file
else OnyxCeleryQueues.CONNECTOR_INDEXING
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
)
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
# Send the task to Celery
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
@@ -613,14 +378,18 @@ def try_creating_indexing_task(
priority=OnyxCeleryPriority.MEDIUM,
)
if not result:
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
# now fill out the fence with the rest of the data
redis_connector_index.set_active()
task_logger.info(
f"Created docfetching task: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id} "
f"attempt_id={index_attempt_id} "
f"celery_task_id={custom_task_id}"
)
return index_attempt_id
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
@@ -628,9 +397,10 @@ def try_creating_indexing_task(
f"search_settings={search_settings.id}"
)
# Clean up on failure
if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
mark_attempt_failed(index_attempt_id, db_session)
return None
finally:
if lock.owned():

File diff suppressed because it is too large

View File

@@ -0,0 +1,110 @@
from enum import Enum
from pydantic import BaseModel
class DocProcessingContext(BaseModel):
tenant_id: str
cc_pair_id: int
search_settings_id: int
index_attempt_id: int
class IndexingWatchdogTerminalStatus(str, Enum):
"""The different statuses the watchdog can finish with.
TODO: create broader success/failure/abort categories
"""
UNDEFINED = "undefined"
SUCCEEDED = "succeeded"
SPAWN_FAILED = "spawn_failed" # connector spawn failed
SPAWN_NOT_ALIVE = (
"spawn_not_alive" # spawn succeeded but process did not come alive
)
BLOCKED_BY_DELETION = "blocked_by_deletion"
BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
FENCE_NOT_FOUND = "fence_not_found" # fence does not exist
FENCE_READINESS_TIMEOUT = (
"fence_readiness_timeout" # fence exists but wasn't ready within the timeout
)
FENCE_MISMATCH = "fence_mismatch" # task and fence metadata mismatch
TASK_ALREADY_RUNNING = "task_already_running" # task appears to be running already
INDEX_ATTEMPT_MISMATCH = (
"index_attempt_mismatch" # expected index attempt metadata not found in db
)
CONNECTOR_VALIDATION_ERROR = (
"connector_validation_error" # the connector validation failed
)
CONNECTOR_EXCEPTIONED = "connector_exceptioned" # the connector itself exceptioned
WATCHDOG_EXCEPTIONED = "watchdog_exceptioned" # the watchdog exceptioned
# the watchdog received a termination signal
TERMINATED_BY_SIGNAL = "terminated_by_signal"
# the watchdog terminated the task due to no activity
TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"
# NOTE: this may actually be the same as SIGKILL, but parsed differently by python
# consolidate once we know more
OUT_OF_MEMORY = "out_of_memory"
PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
@property
def code(self) -> int:
_ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
}
return _ENUM_TO_CODE[self]
@classmethod
def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
_CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
-9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
}
if code in _CODE_TO_ENUM:
return _CODE_TO_ENUM[code]
return IndexingWatchdogTerminalStatus.UNDEFINED
class SimpleJobResult:
"""The data we want to have when the watchdog finishes"""
status: IndexingWatchdogTerminalStatus
connector_source: str | None
exit_code: int | None
exception_str: str | None
def __init__(self) -> None:
self.status = IndexingWatchdogTerminalStatus.UNDEFINED
self.connector_source = None
self.exit_code = None
self.exception_str = None
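A short illustration of how the pieces above combine; the wrapper function is a hypothetical example, only the enum and SimpleJobResult come from this module:
def summarize_exit(exit_code: int | None) -> SimpleJobResult:
    result = SimpleJobResult()
    result.exit_code = exit_code
    if exit_code == 0:
        result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
    elif exit_code is not None:
        # e.g. 137 -> OUT_OF_MEMORY, 255 -> CONNECTOR_EXCEPTIONED, anything unknown -> UNDEFINED
        result.status = IndexingWatchdogTerminalStatus.from_code(exit_code)
    return result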

View File

@@ -147,7 +147,7 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
metrics = []
queue_mappings = {
"celery_queue_length": "celery",
"indexing_queue_length": "indexing",
"docprocessing_queue_length": "docprocessing",
"sync_queue_length": "sync",
"deletion_queue_length": "deletion",
"pruning_queue_length": "pruning",
@@ -882,7 +882,13 @@ def monitor_celery_queues_helper(
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery)
n_user_files_indexing = celery_get_queue_length(
OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
)
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
@@ -896,14 +902,20 @@ def monitor_celery_queues_helper(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
n_indexing_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
n_docfetching_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
n_docprocessing_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.DOCPROCESSING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"indexing_prefetched={len(n_indexing_prefetched)} "
f"docfetching={n_docfetching} "
f"docfetching_prefetched={len(n_docfetching_prefetched)} "
f"docprocessing={n_docprocessing} "
f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
f"user_files_indexing={n_user_files_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "

View File

@@ -22,7 +22,7 @@ from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.indexing.utils import IndexingCallbackBase
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallbackBase
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -138,8 +138,11 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
# if never pruned, use the connector creation time. We could also
# compute the completion time of the first successful index attempt, but
# that is a reasonably heavy operation. This is a reasonable approximation —
# in the worst case, we'll prune a little bit earlier than we should.
last_pruned = cc_pair.connector.time_created
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
@@ -173,6 +176,9 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
# but pruning only kicks off once per hour
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
task_logger.info("Checking for pruning due")
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
@@ -187,15 +193,18 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
cc_pair_id=cc_pair_id,
)
if not cc_pair:
logger.error(f"CC pair not found: {cc_pair_id}")
continue
if not _is_pruning_due(cc_pair):
logger.info(f"CC pair not due for pruning: {cc_pair_id}")
continue
payload_id = try_creating_prune_generator_task(
self.app, cc_pair, db_session, r, tenant_id
)
if not payload_id:
logger.info(f"Pruning not created: {cc_pair_id}")
continue
task_logger.info(
@@ -264,6 +273,8 @@ def try_creating_prune_generator_task(
is used to trigger prunes immediately, e.g. via the web ui.
"""
logger.info(f"try_creating_prune_generator_task: cc_pair={cc_pair.id}")
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not ALLOW_SIMULTANEOUS_PRUNING:
@@ -287,18 +298,30 @@ def try_creating_prune_generator_task(
try:
# skip pruning if already pruning
if redis_connector.prune.fenced:
logger.info(
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} already pruning"
)
return None
# skip pruning if the cc_pair is deleting
if redis_connector.delete.fenced:
logger.info(
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} deleting"
)
return None
# skip pruning if doc permissions sync is running
if redis_connector.permissions.fenced:
logger.info(
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} permissions sync running"
)
return None
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
logger.info(
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} deleting"
)
return None
# add a long running generator task to the queue
@@ -441,7 +464,7 @@ def connector_pruning_generator_task(
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.cc_pair_id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
thread_local=False,
)

View File

@@ -0,0 +1,178 @@
import time
from typing import cast
from uuid import uuid4
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.document import construct_document_id_select_by_needs_sync
from onyx.db.document import count_documents_by_needs_sync
from onyx.utils.logger import setup_logger
# Redis keys for document sync tracking
DOCUMENT_SYNC_PREFIX = "documentsync"
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
logger = setup_logger()
def is_document_sync_fenced(r: Redis) -> bool:
"""Check if document sync tasks are currently in progress."""
return bool(r.exists(DOCUMENT_SYNC_FENCE_KEY))
def get_document_sync_payload(r: Redis) -> int | None:
"""Get the initial number of tasks that were created."""
bytes_result = r.get(DOCUMENT_SYNC_FENCE_KEY)
if bytes_result is None:
return None
return int(cast(int, bytes_result))
def get_document_sync_remaining(r: Redis) -> int:
"""Get the number of tasks still pending completion."""
return cast(int, r.scard(DOCUMENT_SYNC_TASKSET_KEY))
def set_document_sync_fence(r: Redis, payload: int | None) -> None:
"""Set up the fence and register with active fences."""
if payload is None:
r.srem(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
r.delete(DOCUMENT_SYNC_FENCE_KEY)
return
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
def delete_document_sync_taskset(r: Redis) -> None:
"""Clear the document sync taskset."""
r.delete(DOCUMENT_SYNC_TASKSET_KEY)
def reset_document_sync(r: Redis) -> None:
"""Reset all document sync tracking data."""
r.srem(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
r.delete(DOCUMENT_SYNC_TASKSET_KEY)
r.delete(DOCUMENT_SYNC_FENCE_KEY)
def generate_document_sync_tasks(
r: Redis,
max_tasks: int,
celery_app: Celery,
db_session: Session,
lock: RedisLock,
tenant_id: str,
) -> tuple[int, int]:
"""Generate sync tasks for all documents that need syncing.
Args:
r: Redis client
max_tasks: Maximum number of tasks to generate
celery_app: Celery application instance
db_session: Database session
lock: Redis lock for coordination
tenant_id: Tenant identifier
Returns:
tuple[int, int]: (tasks_generated, total_docs_found)
"""
last_lock_time = time.monotonic()
num_tasks_sent = 0
num_docs = 0
# Get all documents that need syncing
stmt = construct_document_id_select_by_needs_sync()
for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
doc_id = cast(str, doc_id)
current_time = time.monotonic()
# Reacquire lock periodically to prevent timeout
if current_time - last_lock_time >= (CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4):
lock.reacquire()
last_lock_time = current_time
num_docs += 1
# Create a unique task ID
custom_task_id = f"{DOCUMENT_SYNC_PREFIX}_{uuid4()}"
# Add to the tracking taskset in Redis BEFORE creating the celery task
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
# Create the Celery task
celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
ignore_result=True,
)
num_tasks_sent += 1
if num_tasks_sent >= max_tasks:
break
return num_tasks_sent, num_docs
def try_generate_stale_document_sync_tasks(
celery_app: Celery,
max_tasks: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
) -> int | None:
# the fence is up, do nothing
if is_document_sync_fenced(r):
return None
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
logger.info("No stale documents found. Skipping sync tasks generation.")
return None
logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks in one batch."
)
logger.info("generate_document_sync_tasks starting for all documents.")
# Generate all tasks in one pass
result = generate_document_sync_tasks(
r, max_tasks, celery_app, db_session, lock_beat, tenant_id
)
if result is None:
return None
tasks_generated, total_docs = result
if tasks_generated >= max_tasks:
logger.info(
f"generate_document_sync_tasks reached the task generation limit: "
f"tasks_generated={tasks_generated} max_tasks={max_tasks}"
)
else:
logger.info(
f"generate_document_sync_tasks finished for all documents. "
f"tasks_generated={tasks_generated} total_docs_found={total_docs}"
)
set_document_sync_fence(r, tasks_generated)
return tasks_generated
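A condensed sketch of how a beat pass might drive these helpers (the function name and max_tasks value are illustrative):
def document_sync_beat_pass(celery_app, db_session, r, lock_beat, tenant_id: str) -> None:
    # Kick off a batch of sync tasks if nothing is currently fenced.
    try_generate_stale_document_sync_tasks(
        celery_app, 1000, db_session, r, lock_beat, tenant_id
    )
    # Elsewhere, a monitor clears the fence once the taskset drains.
    if is_document_sync_fenced(r) and get_document_sync_remaining(r) == 0:
        reset_document_sync(r)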

View File

@@ -20,14 +20,19 @@ from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocument
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_FENCE_KEY
from onyx.background.celery.tasks.vespa.document_sync import get_document_sync_payload
from onyx.background.celery.tasks.vespa.document_sync import get_document_sync_remaining
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
from onyx.background.celery.tasks.vespa.document_sync import (
try_generate_stale_document_sync_tasks,
)
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import count_documents_by_needs_sync
from onyx.db.document import get_document
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import delete_document_set
@@ -47,10 +52,6 @@ from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_credential_pair import (
RedisGlobalConnectorCredentialPair,
)
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
@@ -166,8 +167,11 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
continue
key_str = key_bytes.decode("utf-8")
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
monitor_connector_taskset(r)
# NOTE: we are removing the "Redis*" classes; prefer plain functions for this
# going forward. In short, new code should generally follow the doc sync task
# pattern rather than the others.
if key_str == DOCUMENT_SYNC_FENCE_KEY:
monitor_document_sync_taskset(r)
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
with get_session_with_current_tenant() as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
@@ -203,82 +207,6 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
return True
def try_generate_stale_document_sync_tasks(
celery_app: Celery,
max_tasks: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
) -> int | None:
# the fence is up, do nothing
redis_global_ccpair = RedisGlobalConnectorCredentialPair(r)
if redis_global_ccpair.fenced:
return None
redis_global_ccpair.delete_taskset()
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
return None
task_logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info(
"RedisConnector.generate_tasks starting by cc_pair. "
"Documents spanning multiple cc_pairs will only be synced once."
)
docs_to_skip: set[str] = set()
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
tasks_remaining = max_tasks
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
lock_beat.reacquire()
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
rc.set_skip_docs(docs_to_skip)
result = rc.generate_tasks(
tasks_remaining, celery_app, db_session, r, lock_beat, tenant_id
)
if result is None:
continue
if result[1] == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
)
total_tasks_generated += result[0]
tasks_remaining -= result[0]
if tasks_remaining <= 0:
break
if tasks_remaining <= 0:
task_logger.info(
f"RedisConnector.generate_tasks reached the task generation limit: "
f"total_tasks_generated={total_tasks_generated} max_tasks={max_tasks}"
)
else:
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
)
redis_global_ccpair.set_fence(total_tasks_generated)
return total_tasks_generated
def try_generate_document_set_sync_tasks(
celery_app: Celery,
document_set_id: int,
@@ -433,19 +361,18 @@ def try_generate_user_group_sync_tasks(
return tasks_generated
def monitor_connector_taskset(r: Redis) -> None:
redis_global_ccpair = RedisGlobalConnectorCredentialPair(r)
initial_count = redis_global_ccpair.payload
def monitor_document_sync_taskset(r: Redis) -> None:
initial_count = get_document_sync_payload(r)
if initial_count is None:
return
remaining = redis_global_ccpair.get_remaining()
remaining = get_document_sync_remaining(r)
task_logger.info(
f"Stale document sync progress: remaining={remaining} initial={initial_count}"
f"Document sync progress: remaining={remaining} initial={initial_count}"
)
if remaining == 0:
redis_global_ccpair.reset()
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
reset_document_sync(r)
task_logger.info(f"Successfully synced all documents. count={initial_count}")
def monitor_document_set_taskset(

View File

@@ -10,7 +10,7 @@ set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.indexing import celery_app
from onyx.background.celery.apps.docfetching import celery_app
return celery_app

View File

@@ -0,0 +1,18 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.docprocessing import celery_app
return celery_app
app = get_app()

View File

@@ -33,7 +33,7 @@ def save_checkpoint(
"""Save a checkpoint for a given index attempt to the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_store.save_file(
content=BytesIO(checkpoint.model_dump_json().encode()),
display_name=checkpoint_pointer,
@@ -52,11 +52,11 @@ def save_checkpoint(
def load_checkpoint(
db_session: Session, index_attempt_id: int, connector: BaseConnector
index_attempt_id: int, connector: BaseConnector
) -> ConnectorCheckpoint:
"""Load a checkpoint for a given index attempt from the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
if isinstance(connector, CheckpointedConnector):
@@ -71,7 +71,7 @@ def get_latest_valid_checkpoint(
window_start: datetime,
window_end: datetime,
connector: BaseConnector,
) -> ConnectorCheckpoint:
) -> tuple[ConnectorCheckpoint, bool]:
"""Get the latest valid checkpoint for a given connector credential pair"""
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
@@ -83,7 +83,7 @@ def get_latest_valid_checkpoint(
# don't keep using checkpoints if we've had a bunch of failed attempts in a row
# where we make no progress. Only do this if we have had at least
# _NUM_RECENT_ATTEMPTS_TO_CONSIDER completed attempts.
if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
if len(checkpoint_candidates) >= _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
had_any_progress = False
for candidate in checkpoint_candidates:
if (
@@ -99,7 +99,7 @@ def get_latest_valid_checkpoint(
f"found for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
"from scratch."
)
return connector.build_dummy_checkpoint()
return connector.build_dummy_checkpoint(), False
# filter out any candidates that don't meet the criteria
checkpoint_candidates = [
@@ -140,11 +140,10 @@ def get_latest_valid_checkpoint(
logger.info(
f"No valid checkpoint found for cc_pair={cc_pair_id}. Starting from scratch."
)
return checkpoint
return checkpoint, False
try:
previous_checkpoint = load_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
connector=connector,
)
@@ -153,14 +152,14 @@ def get_latest_valid_checkpoint(
f"Failed to load checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}. Falling back to default checkpoint."
)
return checkpoint
return checkpoint, False
logger.info(
f"Using checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
f"{previous_checkpoint}"
)
return previous_checkpoint
return previous_checkpoint, True
def get_index_attempts_with_old_checkpoints(
@@ -201,7 +200,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
if not index_attempt.checkpoint_pointer:
return None
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_store.delete_file(index_attempt.checkpoint_pointer)
index_attempt.checkpoint_pointer = None
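Since get_latest_valid_checkpoint now returns a (checkpoint, resuming_from_checkpoint) tuple, call sites unpack both values; a minimal sketch of the pattern (parameter names follow the surrounding diff, the wrapper itself is illustrative):
def load_run_checkpoint(
    db_session, cc_pair_id, search_settings_id, window_start, window_end, connector
):
    checkpoint, resuming_from_checkpoint = get_latest_valid_checkpoint(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
        search_settings_id=search_settings_id,
        window_start=window_start,
        window_end=window_end,
        connector=connector,
    )
    if resuming_from_checkpoint:
        ...  # e.g. re-issue document batches left over from the prior failed attempt
    return checkpoint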

View File

@@ -1,3 +1,4 @@
import sys
import time
import traceback
from collections import defaultdict
@@ -5,7 +6,7 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from pydantic import BaseModel
from celery import Celery
from sqlalchemy.orm import Session
from onyx.access.access import source_should_fetch_permissions_during_indexing
@@ -18,18 +19,25 @@ from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import DocExtractionContext
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
from onyx.db.connector_credential_pair import update_connector_credential_pair
@@ -49,13 +57,16 @@ from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
@@ -68,7 +79,7 @@ from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
logger = setup_logger(propagate=False)
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
@@ -146,6 +157,10 @@ def _get_connector_runner(
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
cleaned_batch = []
for doc in doc_batch:
if sys.getsizeof(doc) > MAX_FILE_SIZE_BYTES:
logger.warning(
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
)
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
@@ -180,25 +195,11 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
return cleaned_batch
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
class RunIndexingContext(BaseModel):
index_name: str
cc_pair_id: int
connector_id: int
credential_id: int
source: DocumentSource
earliest_index_time: float
from_beginning: bool
is_primary: bool
should_fetch_permissions_during_indexing: bool
search_settings_status: IndexModelStatus
def _check_connector_and_attempt_status(
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
db_session_temp: Session,
cc_pair_id: int,
search_settings_status: IndexModelStatus,
index_attempt_id: int,
) -> None:
"""
Checks the status of the connector credential pair and index attempt.
@@ -206,27 +207,34 @@ def _check_connector_and_attempt_status(
"""
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
ctx.cc_pair_id,
cc_pair_id,
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
raise RuntimeError(f"CC pair {cc_pair_id} not found in DB.")
if (
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and ctx.search_settings_status != IndexModelStatus.FUTURE
and search_settings_status != IndexModelStatus.FUTURE
) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
raise RuntimeError("Connector was disabled mid run")
raise ConnectorStopSignal(f"Connector {cc_pair_loop.status.value.lower()}")
index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_loop:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
if index_attempt_loop.status == IndexingStatus.CANCELED:
raise ConnectorStopSignal(f"Index attempt {index_attempt_id} was canceled")
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
f"Index Attempt is not running, status is {index_attempt_loop.status}"
)
if index_attempt_loop.celery_task_id is None:
raise RuntimeError(f"Index attempt {index_attempt_id} has no celery task id")
# TODO: delete from here if ends up unused
def _check_failure_threshold(
total_failures: int,
document_count: int,
@@ -257,6 +265,9 @@ def _check_failure_threshold(
)
# NOTE: this is the old run_indexing function that the new decoupled approach
# is based on. Leaving this for comparison purposes, but if you see this comment
# has been here for >1 month, please delete this function.
def _run_indexing(
db_session: Session,
index_attempt_id: int,
@@ -271,7 +282,12 @@ def _run_indexing(
start_time = time.monotonic()  # just used for logging
with get_session_with_current_tenant() as db_session_temp:
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
index_attempt_start = get_index_attempt(
db_session_temp,
index_attempt_id,
eager_load_cc_pair=True,
eager_load_search_settings=True,
)
if not index_attempt_start:
raise ValueError(
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
@@ -292,7 +308,7 @@ def _run_indexing(
index_attempt_start.connector_credential_pair.last_successful_index_time
is not None
)
ctx = RunIndexingContext(
ctx = DocExtractionContext(
index_name=index_attempt_start.search_settings.index_name,
cc_pair_id=index_attempt_start.connector_credential_pair.id,
connector_id=db_connector.id,
@@ -317,6 +333,7 @@ def _run_indexing(
and (from_beginning or not has_successful_attempt)
),
search_settings_status=index_attempt_start.search_settings.status,
doc_extraction_complete_batch_num=None,
)
last_successful_index_poll_range_end = (
@@ -384,19 +401,6 @@ def _run_indexing(
httpx_client=HttpxPool.get("vespa"),
)
indexing_pipeline = build_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
@@ -416,7 +420,9 @@ def _run_indexing(
index_attempt: IndexAttempt | None = None
try:
with get_session_with_current_tenant() as db_session_temp:
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
index_attempt = get_index_attempt(
db_session_temp, index_attempt_id, eager_load_cc_pair=True
)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
@@ -439,7 +445,7 @@ def _run_indexing(
):
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
checkpoint = get_latest_valid_checkpoint(
checkpoint, _ = get_latest_valid_checkpoint(
db_session=db_session_temp,
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
@@ -496,7 +502,10 @@ def _run_indexing(
with get_session_with_current_tenant() as db_session_temp:
# will exception if the connector/index attempt is marked as paused/failed
_check_connector_and_attempt_status(
db_session_temp, ctx, index_attempt_id
db_session_temp,
ctx.cc_pair_id,
ctx.search_settings_status,
index_attempt_id,
)
# save record of any failures at the connector level
@@ -554,7 +563,16 @@ def _run_indexing(
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
index_pipeline_result = indexing_pipeline(
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
)
@@ -815,6 +833,7 @@ def _run_indexing(
def run_indexing_entrypoint(
app: Celery,
index_attempt_id: int,
tenant_id: str,
connector_credential_pair_id: int,
@@ -832,7 +851,6 @@ def run_indexing_entrypoint(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_current_tenant() as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
@@ -846,18 +864,514 @@ def run_indexing_entrypoint(
credential_id = attempt.connector_credential_pair.credential_id
logger.info(
f"Indexing starting{tenant_str}: "
f"Docfetching starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
with get_session_with_current_tenant() as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
connector_document_extraction(
app,
index_attempt_id,
attempt.connector_credential_pair_id,
attempt.search_settings_id,
tenant_id,
callback,
)
logger.info(
f"Indexing finished{tenant_str}: "
f"Docfetching finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
def connector_document_extraction(
app: Celery,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""Extract documents from connector and queue them for indexing pipeline processing.
This is the first part of the split indexing process that runs the connector
and extracts documents, storing them in the filestore for later processing.
"""
start_time = time.monotonic()
logger.info(
f"Document extraction starting: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"tenant={tenant_id}"
)
# Get batch storage (transition to IN_PROGRESS is handled by run_indexing_entrypoint)
batch_storage = get_document_batch_storage(cc_pair_id, index_attempt_id)
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
memory_tracer.start()
index_attempt = None
last_batch_num = 0 # used to continue from checkpointing
# comes from _run_indexing
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session,
index_attempt_id,
eager_load_cc_pair=True,
eager_load_search_settings=True,
)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found")
if index_attempt.search_settings is None:
raise ValueError("Search settings must be set for indexing")
# Clear the indexing trigger if it was set, to prevent duplicate indexing attempts
if index_attempt.connector_credential_pair.indexing_trigger is not None:
logger.info(
"Clearing indexing trigger: "
f"cc_pair={index_attempt.connector_credential_pair.id} "
f"trigger={index_attempt.connector_credential_pair.indexing_trigger}"
)
mark_ccpair_with_indexing_trigger(
index_attempt.connector_credential_pair.id, None, db_session
)
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
from_beginning = index_attempt.from_beginning
has_successful_attempt = (
index_attempt.connector_credential_pair.last_successful_index_time
is not None
)
earliest_index_time = (
db_connector.indexing_start.timestamp()
if db_connector.indexing_start
else 0
)
should_fetch_permissions_during_indexing = (
index_attempt.connector_credential_pair.access_type == AccessType.SYNC
and source_should_fetch_permissions_during_indexing(db_connector.source)
and is_primary
# if we've already successfully indexed, let the doc_sync job
# take care of doc-level permissions
and (from_beginning or not has_successful_attempt)
)
# Set up time windows for polling
last_successful_index_poll_range_end = (
earliest_index_time
if from_beginning
else get_last_successful_attempt_poll_range_end(
cc_pair_id=cc_pair_id,
earliest_index=earliest_index_time,
search_settings=index_attempt.search_settings,
db_session=db_session,
)
)
if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
window_start = datetime.fromtimestamp(
last_successful_index_poll_range_end, tz=timezone.utc
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
else:
# don't go into "negative" time if we've never indexed before
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
most_recent_attempt = next(
iter(
get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
db_session=db_session,
limit=1,
)
),
None,
)
# if the last attempt failed, try to use the same window. This is necessary
# to ensure correctness with checkpointing. If we don't do this, things like
# new slack channels could be missed (since existing slack channels are
# cached as part of the checkpoint).
if (
most_recent_attempt
and most_recent_attempt.poll_range_end
and (
most_recent_attempt.status == IndexingStatus.FAILED
or most_recent_attempt.status == IndexingStatus.CANCELED
)
):
window_end = most_recent_attempt.poll_range_end
else:
window_end = datetime.now(tz=timezone.utc)
# set time range in db
index_attempt.poll_range_start = window_start
index_attempt.poll_range_end = window_end
db_session.commit()
# TODO: maybe memory tracer here
# Set up connector runner
connector_runner = _get_connector_runner(
db_session=db_session,
attempt=index_attempt,
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
include_permissions=should_fetch_permissions_during_indexing,
)
# don't use a checkpoint if we're explicitly indexing from
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling
# OR
# if the last attempt was successful
if index_attempt.from_beginning or (
most_recent_attempt and most_recent_attempt.status.is_successful()
):
logger.info(
f"Cleaning up all old batches for index attempt {index_attempt_id} before starting new run"
)
batch_storage.cleanup_all_batches()
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
logger.info(
f"Getting latest valid checkpoint for index attempt {index_attempt_id}"
)
checkpoint, resuming_from_checkpoint = get_latest_valid_checkpoint(
db_session=db_session,
cc_pair_id=cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
connector=connector_runner.connector,
)
# checkpoint resumption OR the connector already finished.
if (
isinstance(connector_runner.connector, CheckpointedConnector)
and resuming_from_checkpoint
) or (
most_recent_attempt
and most_recent_attempt.total_batches is not None
and not checkpoint.has_more
):
reissued_batch_count, completed_batches = reissue_old_batches(
batch_storage,
index_attempt_id,
cc_pair_id,
tenant_id,
app,
most_recent_attempt,
)
last_batch_num = reissued_batch_count + completed_batches
index_attempt.completed_batches = completed_batches
db_session.commit()
else:
logger.info(
f"Cleaning up all batches for index attempt {index_attempt_id} before starting new run"
)
# for non-checkpointed connectors, throw out batches from previous unsuccessful attempts
# because we'll be getting those documents again anyway.
batch_storage.cleanup_all_batches()
# Save initial checkpoint
save_checkpoint(
db_session=db_session,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
try:
batch_num = last_batch_num # starts at 0 if no last batch
total_doc_batches_queued = 0
total_failures = 0
document_count = 0
# Main extraction loop
while checkpoint.has_more:
logger.info(
f"Running '{db_connector.source.value}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
):
# Check if the connector was disabled mid-run and stop if so, unless it's the secondary
# index being built. We want to populate it even for paused connectors:
# often paused connectors are sources that aren't updated frequently, but their
# contents still need to be initially pulled.
if callback and callback.should_stop():
raise ConnectorStopSignal("Connector stop signal detected")
# will exception if the connector/index attempt is marked as paused/failed
with get_session_with_current_tenant() as db_session_tmp:
_check_connector_and_attempt_status(
db_session_tmp,
cc_pair_id,
index_attempt.search_settings.status,
index_attempt_id,
)
# save record of any failures at the connector level
if failure is not None:
total_failures += 1
with get_session_with_current_tenant() as db_session:
create_index_attempt_error(
index_attempt_id,
cc_pair_id,
failure,
db_session,
)
_check_failure_threshold(
total_failures, document_count, batch_num, failure
)
# Save checkpoint if provided
if next_checkpoint:
checkpoint = next_checkpoint
# everything below is document-processing work, so if there's no batch we can just continue
if not document_batch:
continue
# Clean documents and create batch
doc_batch_cleaned = strip_null_characters(document_batch)
batch_description = []
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
doc_size = 0
for section in doc.sections:
if (
isinstance(section, TextSection)
and section.text is not None
):
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(
f"Document size: doc='{doc.to_short_descriptor()}' "
f"size={doc_size} "
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
)
logger.debug(f"Indexing batch of documents: {batch_description}")
memory_tracer.increment_and_maybe_trace()
# Store documents in storage
batch_storage.store_batch(batch_num, doc_batch_cleaned)
# Create processing task data
processing_batch_data = {
"index_attempt_id": index_attempt_id,
"cc_pair_id": cc_pair_id,
"tenant_id": tenant_id,
"batch_num": batch_num, # 0-indexed
}
# Queue document processing task
app.send_task(
OnyxCeleryTask.DOCPROCESSING_TASK,
kwargs=processing_batch_data,
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
)
batch_num += 1
total_doc_batches_queued += 1
logger.info(
f"Queued document processing batch: "
f"batch_num={batch_num} "
f"docs={len(doc_batch_cleaned)} "
f"attempt={index_attempt_id}"
)
# Check checkpoint size periodically
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
check_checkpoint_size(checkpoint)
# Save latest checkpoint
# NOTE: checkpointing is used to track which batches have
# been sent to the filestore, NOT which batches have been fully indexed
# as it used to be.
with get_session_with_current_tenant() as db_session:
save_checkpoint(
db_session=db_session,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
elapsed_time = time.monotonic() - start_time
logger.info(
f"Document extraction completed: "
f"attempt={index_attempt_id} "
f"batches_queued={total_doc_batches_queued} "
f"elapsed={elapsed_time:.2f}s"
)
# Set total batches in database to signal extraction completion.
# Used by check_for_indexing to determine if the index attempt is complete.
with get_session_with_current_tenant() as db_session:
IndexingCoordination.set_total_batches(
db_session=db_session,
index_attempt_id=index_attempt_id,
total_batches=batch_num,
)
except Exception as e:
logger.exception(
f"Document extraction failed: "
f"attempt={index_attempt_id} "
f"error={str(e)}"
)
# Do NOT clean up batches on failure; future runs will use those batches
# while docfetching will continue from the saved checkpoint if one exists
if isinstance(e, ConnectorValidationError):
# On validation errors during indexing, we want to cancel the indexing attempt
# and mark the CCPair as invalid. This prevents the connector from being
# used in the future until the credentials are updated.
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to validation error."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
)
if is_primary:
if not index_attempt:
# should always be set by now
raise RuntimeError("Should never happen.")
VALIDATION_ERROR_THRESHOLD = 5
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
limit=VALIDATION_ERROR_THRESHOLD,
db_session=db_session_temp,
)
num_validation_errors = len(
[
index_attempt
for index_attempt in recent_index_attempts
if index_attempt.error_msg
and index_attempt.error_msg.startswith(
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
)
]
)
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
logger.warning(
f"Connector {db_connector.id} has {num_validation_errors} consecutive validation"
f" errors. Marking the CC Pair as invalid."
)
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=db_connector.id,
credential_id=db_credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise e
elif isinstance(e, ConnectorStopSignal):
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
else:
with get_session_with_current_tenant() as db_session_temp:
# don't overwrite attempts that are already failed/canceled for another reason
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
if index_attempt and index_attempt.status in [
IndexingStatus.CANCELED,
IndexingStatus.FAILED,
]:
logger.info(
f"Attempt {index_attempt_id} is already failed/canceled, skipping marking as failed."
)
raise e
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
raise e
finally:
memory_tracer.stop()
def reissue_old_batches(
batch_storage: DocumentBatchStorage,
index_attempt_id: int,
cc_pair_id: int,
tenant_id: str,
app: Celery,
most_recent_attempt: IndexAttempt | None,
) -> tuple[int, int]:
# When loading from a checkpoint, we need to start new docprocessing tasks
# tied to the new index attempt for any batches left over in the file store
old_batches = batch_storage.get_all_batches_for_cc_pair()
batch_storage.update_old_batches_to_new_index_attempt(old_batches)
for batch_id in old_batches:
logger.info(
f"Re-issuing docprocessing task for batch {batch_id} for index attempt {index_attempt_id}"
)
path_info = batch_storage.extract_path_info(batch_id)
if path_info is None:
continue
if path_info.cc_pair_id != cc_pair_id:
raise RuntimeError(f"Batch {batch_id} is not for cc pair {cc_pair_id}")
app.send_task(
OnyxCeleryTask.DOCPROCESSING_TASK,
kwargs={
"index_attempt_id": index_attempt_id,
"cc_pair_id": cc_pair_id,
"tenant_id": tenant_id,
"batch_num": path_info.batch_num, # use same batch num as previously
},
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
)
recent_batches = most_recent_attempt.completed_batches if most_recent_attempt else 0
# resume from the batch num of the last attempt. This should be one more
# than the last batch created by docfetching regardless of whether the batch
# is still in the filestore waiting for processing or not.
last_batch_num = len(old_batches) + recent_batches
logger.info(
f"Starting from batch {last_batch_num} due to "
f"re-issued batches: {old_batches}, completed batches: {recent_batches}"
)
return len(old_batches), recent_batches
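For readers skimming the diff, a minimal sketch of the producer/consumer handoff that connector_document_extraction implements above: store each cleaned batch, then enqueue a docprocessing task that carries only the coordinates needed to find that batch again. SimpleBatchStorage, STORAGE, and the in-memory broker are invented for the example; only the handoff pattern mirrors the real DocumentBatchStorage + DOCPROCESSING_TASK flow.

from celery import Celery

app = Celery("sketch", broker="memory://")


class SimpleBatchStorage:
    """Hypothetical stand-in for DocumentBatchStorage: keeps batches in memory."""

    def __init__(self) -> None:
        self._batches: dict[int, list[str]] = {}

    def store_batch(self, batch_num: int, docs: list[str]) -> None:
        self._batches[batch_num] = docs

    def load_batch(self, batch_num: int) -> list[str]:
        return self._batches[batch_num]


STORAGE = SimpleBatchStorage()


@app.task(name="docprocessing_task")
def docprocessing_task(batch_num: int) -> None:
    # Consumer side: read the batch back out of storage and index it.
    docs = STORAGE.load_batch(batch_num)
    print(f"processing batch {batch_num} with {len(docs)} docs")


def docfetching(doc_batches: list[list[str]]) -> int:
    # Producer side: persist each batch, then enqueue a task that only carries
    # the batch number needed to find it again (mirroring DOCPROCESSING_TASK above).
    for batch_num, docs in enumerate(doc_batches):
        STORAGE.store_batch(batch_num, docs)
        app.send_task(
            "docprocessing_task",
            kwargs={"batch_num": batch_num},
            queue="docprocessing",
        )
    return len(doc_batches)  # total_batches, used to signal extraction completion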

View File

@@ -1,12 +1,7 @@
import csv
import json
import os
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path
from uuid import UUID
from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import GraphConfig
@@ -16,9 +11,6 @@ from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_agent_search_graph
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import (
run_basic_graph as run_hackathon_graph,
) # You can create your own graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.chat.models import AgentAnswerPiece
@@ -30,11 +22,9 @@ from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.models import ToolCallFinalResult
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import RerankingDetails
@@ -54,190 +44,6 @@ logger = setup_logger()
BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
def _calc_score_for_pos(pos: int, max_acceptable_pos: int = 15) -> float:
"""
Calculate the score for a given position.
"""
if pos > max_acceptable_pos:
return 0
elif pos == 1:
return 1
elif pos == 2:
return 0.8
else:
return 4 / (pos + 5)
def _clean_doc_id_link(doc_link: str) -> str:
"""
Clean the google doc link.
"""
if "google.com" in doc_link:
if "/edit" in doc_link:
return "/edit".join(doc_link.split("/edit")[:-1])
elif "/view" in doc_link:
return "/view".join(doc_link.split("/view")[:-1])
else:
return doc_link
if "app.fireflies.ai" in doc_link:
return "?".join(doc_link.split("?")[:-1])
return doc_link
def _get_doc_score(doc_id: str, doc_results: list[str]) -> float:
"""
Get the score of a document from the document results.
"""
match_pos = None
for pos, comp_doc in enumerate(doc_results, start=1):
clear_doc_id = _clean_doc_id_link(doc_id)
clear_comp_doc = _clean_doc_id_link(comp_doc)
if clear_doc_id == clear_comp_doc:
match_pos = pos
if match_pos is None:
return 0.0
return _calc_score_for_pos(match_pos)
def _append_empty_line(csv_path: str = HACKATHON_OUTPUT_CSV_PATH):
"""
Append an empty line to the CSV file.
"""
_append_answer_to_csv("", "", csv_path)
def _append_ground_truth_to_csv(
query: str,
ground_truth_docs: list[str],
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append the score to the CSV file.
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
for doc_id in ground_truth_docs:
writer.writerow([query, "-1", _clean_doc_id_link(doc_id), "", ""])
logger.debug("Appended score to csv file")
def _append_score_to_csv(
query: str,
score: float,
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append the score to the CSV file.
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
writer.writerow([query, "", "", "", score])
logger.debug("Appended score to csv file")
def _append_search_results_to_csv(
query: str,
doc_results: list[str],
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append the search results to the CSV file.
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
for pos, doc in enumerate(doc_results, start=1):
writer.writerow([query, pos, _clean_doc_id_link(doc), "", ""])
logger.debug("Appended search results to csv file")
def _append_answer_to_csv(
query: str,
answer: str,
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append ranking statistics to a CSV file.
Args:
ranking_stats: List of tuples containing (query, hit_position, document_id)
csv_path: Path to the CSV file to append to
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
writer.writerow([query, "", "", answer, ""])
logger.debug("Appended answer to csv file")
class Answer:
def __init__(
self,
@@ -328,9 +134,6 @@ class Answer:
@property
def processed_streamed_output(self) -> AnswerStream:
_HACKATHON_TEST_EXECUTION = False
if self._processed_stream is not None:
yield from self._processed_stream
return
@@ -351,117 +154,20 @@ class Answer:
)
):
run_langgraph = run_dc_graph
elif (
self.graph_config.inputs.persona
and self.graph_config.inputs.persona.description.startswith(
"Hackathon Test"
)
):
_HACKATHON_TEST_EXECUTION = True
run_langgraph = run_hackathon_graph
else:
run_langgraph = run_basic_graph
if _HACKATHON_TEST_EXECUTION:
stream = run_langgraph(self.graph_config)
input_data = str(self.graph_config.inputs.prompt_builder.raw_user_query)
if input_data.startswith("["):
input_type = "json"
input_list = json.loads(input_data)
else:
input_type = "list"
input_list = input_data.split(";")
num_examples_with_ground_truth = 0
total_score = 0.0
for question_num, question_data in enumerate(input_list):
ground_truth_docs = None
if input_type == "json":
question = question_data["question"]
ground_truth = question_data.get("ground_truth")
if ground_truth:
ground_truth_docs = [x.get("doc_link") for x in ground_truth]
logger.info(f"Question {question_num}: {question}")
_append_ground_truth_to_csv(question, ground_truth_docs)
else:
continue
else:
question = question_data
self.graph_config.inputs.prompt_builder.raw_user_query = question
self.graph_config.inputs.prompt_builder.user_message_and_token_cnt = (
HumanMessage(
content=question, additional_kwargs={}, response_metadata={}
),
2,
)
self.graph_config.tooling.force_use_tool.force_use = True
stream = run_langgraph(
self.graph_config,
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
llm_answer_segments: list[str] = []
doc_results: list[str] | None = None
for answer_piece in processed_stream:
if isinstance(answer_piece, OnyxAnswerPiece):
llm_answer_segments.append(answer_piece.answer_piece or "")
elif isinstance(answer_piece, ToolCallFinalResult):
doc_results = [x.get("link") for x in answer_piece.tool_result]
if doc_results:
_append_search_results_to_csv(question, doc_results)
_append_answer_to_csv(question, "".join(llm_answer_segments))
if ground_truth_docs and doc_results:
num_examples_with_ground_truth += 1
doc_score = 0.0
for doc_id in ground_truth_docs:
doc_score += _get_doc_score(doc_id, doc_results)
_append_score_to_csv(question, doc_score)
total_score += doc_score
self._processed_stream = processed_stream
if num_examples_with_ground_truth > 0:
comprehensive_score = total_score / num_examples_with_ground_truth
else:
comprehensive_score = 0
_append_empty_line()
_append_score_to_csv(question, comprehensive_score)
else:
stream = run_langgraph(
self.graph_config,
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
self._processed_stream = processed_stream
break
processed_stream.append(packet)
yield packet
self._processed_stream = processed_stream
@property
def llm_answer(self) -> str:

View File

@@ -309,7 +309,10 @@ class ContextualPruningConfig(DocumentPruningConfig):
def from_doc_pruning_config(
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
) -> "ContextualPruningConfig":
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
return cls(
num_chunk_multiple=num_chunk_multiple,
**doc_pruning_config.model_dump(),
)
class CitationConfig(BaseModel):
@@ -318,9 +321,6 @@ class CitationConfig(BaseModel):
class AnswerStyleConfig(BaseModel):
citation_config: CitationConfig
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
# right now, only used by the simple chat API

View File

@@ -128,17 +128,17 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
internet_search_response_to_search_docs,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchResponse,
INTERNET_SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.internet_search.models import (
InternetSearchResponseSummary,
)
from onyx.tools.tool_implementations.internet_search.utils import (
internet_search_response_to_search_docs,
)
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -281,7 +281,7 @@ def _handle_internet_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
internet_search_response = cast(InternetSearchResponse, packet.response)
internet_search_response = cast(InternetSearchResponseSummary, packet.response)
server_search_docs = internet_search_response_to_search_docs(
internet_search_response
)
@@ -296,10 +296,10 @@ def _handle_internet_search_tool_response_summary(
]
return (
QADocsResponse(
rephrased_query=internet_search_response.revised_query,
rephrased_query=internet_search_response.query,
top_documents=response_docs,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=SearchType.SEMANTIC,
predicted_search=SearchType.INTERNET,
applied_source_filters=[],
applied_time_cutoff=None,
recency_bias_multiplier=1.0,
@@ -491,7 +491,7 @@ def _process_tool_response(
]
)
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
elif packet.id == INTERNET_SEARCH_RESPONSE_SUMMARY_ID:
(
info.qa_docs_response,
info.reference_db_search_docs,
@@ -725,9 +725,7 @@ def stream_chat_message_objects(
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
)
files = load_all_chat_files(history_msgs, new_msg_req.file_descriptors)
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
user_file_ids = new_msg_req.user_file_ids or []
@@ -906,7 +904,6 @@ def stream_chat_message_objects(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
)
@@ -936,6 +933,7 @@ def stream_chat_message_objects(
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
),
image_generation_tool_config=ImageGenerationToolConfig(
additional_headers=litellm_additional_headers,
@@ -1012,6 +1010,7 @@ def stream_chat_message_objects(
tools=tools,
db_session=db_session,
use_agentic_search=new_msg_req.use_agentic_search,
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
)
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(

View File

@@ -187,12 +187,8 @@ class AnswerPromptBuilder:
final_messages_with_tokens.append(self.user_message_and_token_cnt)
if (
self.new_messages_and_token_cnts
and isinstance(self.user_message_and_token_cnt[0].content, str)
and self.user_message_and_token_cnt[0].content.startswith("Refer")
):
final_messages_with_tokens.extend(self.new_messages_and_token_cnts[-2:])
if self.new_messages_and_token_cnts:
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens

View File

@@ -11,6 +11,7 @@ from onyx.chat.models import (
)
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import compute_max_document_tokens
from onyx.configs.app_configs import MAX_FEDERATED_SECTIONS
from onyx.configs.constants import IGNORE_FOR_QA
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.context.search.models import InferenceChunk
@@ -67,6 +68,38 @@ def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
return combined_ranges
def _separate_federated_sections(
sections: list[InferenceSection],
section_relevance_list: list[bool] | None,
) -> tuple[list[InferenceSection], list[InferenceSection], list[bool] | None]:
"""
Separates out the first MAX_FEDERATED_SECTIONS federated sections so they are spared from pruning.
Any remaining federated sections are treated as normal sections and will be added if they
fit within the allocated context window. This is done because federated sections do not have
a score and would otherwise always get pruned.
"""
federated_sections: list[InferenceSection] = []
normal_sections: list[InferenceSection] = []
normal_section_relevance_list: list[bool] = []
for i, section in enumerate(sections):
if (
len(federated_sections) < MAX_FEDERATED_SECTIONS
and section.center_chunk.is_federated
):
federated_sections.append(section)
continue
normal_sections.append(section)
if section_relevance_list is not None:
normal_section_relevance_list.append(section_relevance_list[i])
return (
federated_sections[:MAX_FEDERATED_SECTIONS],
normal_sections,
normal_section_relevance_list if section_relevance_list is not None else None,
)
def _compute_limit(
prompt_config: PromptConfig,
llm_config: LLMConfig,
@@ -105,7 +138,7 @@ def _compute_limit(
return int(min(limit_options))
def reorder_sections(
def _reorder_sections(
sections: list[InferenceSection],
section_relevance_list: list[bool] | None,
) -> list[InferenceSection]:
@@ -113,11 +146,10 @@ def reorder_sections(
return sections
reordered_sections: list[InferenceSection] = []
if section_relevance_list is not None:
for selection_target in [True, False]:
for section, is_relevant in zip(sections, section_relevance_list):
if is_relevant == selection_target:
reordered_sections.append(section)
for selection_target in [True, False]:
for section, is_relevant in zip(sections, section_relevance_list):
if is_relevant == selection_target:
reordered_sections.append(section)
return reordered_sections
@@ -134,6 +166,7 @@ def _remove_sections_to_ignore(
def _apply_pruning(
sections: list[InferenceSection],
section_relevance_list: list[bool] | None,
keep_sections: list[InferenceSection],
token_limit: int,
is_manually_selected_docs: bool,
use_sections: bool,
@@ -144,10 +177,22 @@ def _apply_pruning(
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
)
sections = deepcopy(sections) # don't modify in place
# combine the section lists, making sure to add the keep_sections first
sections = deepcopy(keep_sections) + deepcopy(sections)
# build combined relevance list, treating the keep_sections as relevant
if section_relevance_list is not None:
section_relevance_list = [True] * len(keep_sections) + section_relevance_list
# map unique_id: relevance for final ordering step
section_id_to_relevance: dict[str, bool] = {}
if section_relevance_list is not None:
for sec, rel in zip(sections, section_relevance_list):
section_id_to_relevance[sec.center_chunk.unique_id] = rel
# re-order docs with all the "relevant" docs at the front
sections = reorder_sections(
sections = _reorder_sections(
sections=sections, section_relevance_list=section_relevance_list
)
# remove docs that are explicitly marked as not for QA
@@ -274,6 +319,14 @@ def _apply_pruning(
)
sections = [sections[0]]
# sort by relevance, then by score (as we added the keep_sections first)
sections.sort(
key=lambda s: (
not section_id_to_relevance.get(s.center_chunk.unique_id, True),
-(s.center_chunk.score or 0.0),
),
)
return sections
@@ -289,9 +342,16 @@ def prune_sections(
if section_relevance_list is not None:
assert len(sections) == len(section_relevance_list)
# get federated sections (up to MAX_FEDERATED_SECTIONS)
# TODO: if we can somehow score the federated sections well, we don't need this
federated_sections, normal_sections, normal_section_relevance_list = (
_separate_federated_sections(sections, section_relevance_list)
)
actual_num_chunks = (
contextual_pruning_config.max_chunks
* contextual_pruning_config.num_chunk_multiple
+ len(federated_sections)
if contextual_pruning_config.max_chunks
else None
)
@@ -307,8 +367,9 @@ def prune_sections(
)
return _apply_pruning(
sections=sections,
section_relevance_list=section_relevance_list,
sections=normal_sections,
section_relevance_list=normal_section_relevance_list,
keep_sections=federated_sections,
token_limit=token_limit,
is_manually_selected_docs=contextual_pruning_config.is_manually_selected_docs,
use_sections=contextual_pruning_config.use_sections, # Now default True
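A small, self-contained illustration of the final ordering step added above: sorting by the tuple (not relevant, -score) puts relevant sections first and, within each relevance group, orders by descending score. The dicts here are toy stand-ins for InferenceSection objects.

# Toy illustration of the (relevance, score) sort key used in _apply_pruning above.
sections = [
    {"id": "a", "relevant": False, "score": 0.9},
    {"id": "b", "relevant": True, "score": 0.2},
    {"id": "c", "relevant": True, "score": 0.7},
    {"id": "d", "relevant": False, "score": 0.5},
]

# False sorts before True, so negate relevance; negate score for descending order.
sections.sort(key=lambda s: (not s["relevant"], -(s["score"] or 0.0)))

print([s["id"] for s in sections])  # ['c', 'b', 'a', 'd']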

View File

@@ -35,6 +35,9 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
) # 1 day
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
# Controls whether users can use User Knowledge (personal documents) in assistants
DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() == "true"
# Controls whether to allow admin query history reports with:
# 1. associated user emails
# 2. anonymized user emails
@@ -308,25 +311,40 @@ except ValueError:
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
)
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 3
CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT = 6
try:
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
env_value = os.environ.get("CELERY_WORKER_DOCPROCESSING_CONCURRENCY")
if not env_value:
env_value = os.environ.get("NUM_INDEXING_WORKERS")
if not env_value:
env_value = str(CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT)
CELERY_WORKER_INDEXING_CONCURRENCY = int(env_value)
env_value = str(CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT)
CELERY_WORKER_DOCPROCESSING_CONCURRENCY = int(env_value)
except ValueError:
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
CELERY_WORKER_DOCPROCESSING_CONCURRENCY = (
CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT
)
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT = 1
try:
env_value = os.environ.get("CELERY_WORKER_DOCFETCHING_CONCURRENCY")
if not env_value:
env_value = os.environ.get("NUM_DOCFETCHING_WORKERS")
if not env_value:
env_value = str(CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT)
CELERY_WORKER_DOCFETCHING_CONCURRENCY = int(env_value)
except ValueError:
CELERY_WORKER_DOCFETCHING_CONCURRENCY = (
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
)
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
)
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
VESPA_SYNC_MAX_TASKS = 1024
VESPA_SYNC_MAX_TASKS = 8192
DB_YIELD_PER_DEFAULT = 64
@@ -450,6 +468,11 @@ GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
# Default size threshold for SharePoint files (20MB)
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -481,6 +504,7 @@ LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
# Slack specific configs
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
MAX_SLACK_QUERY_EXPANSIONS = int(os.environ.get("MAX_SLACK_QUERY_EXPANSIONS", "5"))
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
@@ -654,6 +678,14 @@ except json.JSONDecodeError:
# LLM Model Update API endpoint
LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
# Federated Search Configs
MAX_FEDERATED_SECTIONS = int(
os.environ.get("MAX_FEDERATED_SECTIONS", "5")
) # max no. of federated sections to always keep
MAX_FEDERATED_CHUNKS = int(
os.environ.get("MAX_FEDERATED_CHUNKS", "5")
) # max no. of chunks to retrieve per federated connector
#####
# Enterprise Edition Configs
#####
@@ -787,7 +819,3 @@ S3_AWS_SECRET_ACCESS_KEY = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
# Forcing Vespa Language
# English: en, German:de, etc. See: https://docs.vespa.ai/en/linguistics.html
VESPA_LANGUAGE_OVERRIDE = os.environ.get("VESPA_LANGUAGE_OVERRIDE")
HACKATHON_OUTPUT_CSV_PATH = os.environ.get(
"HACKATHON_OUTPUT_CSV_PATH", "/tmp/hackathon_output.csv"
)
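The concurrency settings above repeat the same fallback chain: the new env var name, then a legacy name, then a hard-coded default, with unparsable values also falling back to the default. The helper below is not part of the codebase; it is just a compact restatement of that pattern.

import os


def int_from_env(names: list[str], default: int) -> int:
    # Return the first set env var among `names` parsed as an int; on a missing
    # or invalid value, fall back to `default` (matching the try/except blocks above).
    for name in names:
        raw = os.environ.get(name)
        if raw:
            try:
                return int(raw)
            except ValueError:
                return default
    return default


# e.g. mirrors CELERY_WORKER_DOCPROCESSING_CONCURRENCY above:
docprocessing_concurrency = int_from_env(
    ["CELERY_WORKER_DOCPROCESSING_CONCURRENCY", "NUM_INDEXING_WORKERS"], 6
)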

View File

@@ -91,6 +91,10 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None
EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)
# Enable in-house model for detecting connector-based filtering in queries
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)

View File

@@ -65,7 +65,8 @@ POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
@@ -121,6 +122,8 @@ CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60 # 3 hours (in seconds)
# hard termination should always fire first if the connector is hung
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900
# Heartbeat interval for indexing worker liveness detection
INDEXING_WORKER_HEARTBEAT_INTERVAL = 30 # seconds
# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
@@ -186,10 +189,21 @@ class DocumentSource(str, Enum):
AIRTABLE = "airtable"
HIGHSPOT = "highspot"
IMAP = "imap"
# Special case just for integration tests
MOCK_CONNECTOR = "mock_connector"
class FederatedConnectorSource(str, Enum):
FEDERATED_SLACK = "federated_slack"
def to_non_federated_source(self) -> DocumentSource | None:
if self == FederatedConnectorSource.FEDERATED_SLACK:
return DocumentSource.SLACK
return None
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -320,9 +334,12 @@ class OnyxCeleryQueues:
CSV_GENERATION = "csv_generation"
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
USER_FILES_INDEXING = "user_files_indexing"
# Document processing pipeline queue
DOCPROCESSING = "docprocessing"
CONNECTOR_DOC_FETCHING = "connector_doc_fetching"
# Monitoring queue
MONITORING = "monitoring"
@@ -453,7 +470,11 @@ class OnyxCeleryTask:
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
"connector_external_group_sync_generator_task"
)
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
# New split indexing tasks
CONNECTOR_DOC_FETCHING_TASK = "connector_doc_fetching_task"
DOCPROCESSING_TASK = "docprocessing_task"
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"

View File

@@ -34,7 +34,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
@@ -281,30 +280,28 @@ class BlobStorageConnector(LoadConnector, PollConnector):
# TODO: Refactor to avoid direct DB access in connector
# This will require broader refactoring across the codebase
with get_session_with_current_tenant() as db_session:
image_section, _ = store_image_and_create_section(
db_session=db_session,
image_data=downloaded_file,
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
display_name=file_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
image_section, _ = store_image_and_create_section(
image_data=downloaded_file,
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
display_name=file_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
sections=[image_section],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
doc_updated_at=last_modified,
metadata={},
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
sections=[image_section],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
doc_updated_at=last_modified,
metadata={},
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) == self.batch_size:
yield batch
batch = []
except Exception:
logger.exception(f"Error processing image {key}")
continue

View File

@@ -1,3 +1,17 @@
"""
# README (notes on Confluence pagination):
We've noticed that the `search/users` and `users/memberof` endpoints for Confluence Cloud use offset-based pagination as
opposed to cursor-based. We also know that page-retrieval uses cursor-based pagination.
Our default pagination strategy right now for cloud is to assume cursor-based.
However, if you notice that a cloud API is not being properly paginated (i.e., if the `_links.next` is not appearing in the
returned payload), then you can force offset-based pagination.
# TODO (@raunakab)
We haven't explored all of the cloud APIs' pagination strategies. @raunakab take time to go through this and figure them out.
"""
import json
import time
from collections.abc import Callable
@@ -46,16 +60,13 @@ _REPLACEMENT_EXPANSIONS = "body.view.value"
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
_DEFAULT_PAGINATION_LIMIT = 1000
class ConfluenceRateLimitError(Exception):
pass
_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 50
class OnyxConfluence:
"""
This is a custom Confluence class that:
@@ -311,8 +322,8 @@ class OnyxConfluence:
return confluence
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
# This uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling.
def _make_rate_limited_confluence_method(
self, name: str, credential_provider: CredentialsProviderInterface | None
) -> Callable[..., Any]:
@@ -378,25 +389,6 @@ class OnyxConfluence:
return wrapped_call
# def _wrap_methods(self) -> None:
# """
# For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
# wrap it with handle_confluence_rate_limit.
# """
# for attr_name in dir(self):
# if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
# setattr(
# self,
# attr_name,
# handle_confluence_rate_limit(getattr(self, attr_name)),
# )
# def _ensure_token_valid(self) -> None:
# if self._token_is_expired():
# self._refresh_token()
# # Re-init the Confluence client with the originally stored args
# self._confluence = Confluence(self._url, *self._args, **self._kwargs)
def __getattr__(self, name: str) -> Any:
"""Dynamically intercept attribute/method access."""
attr = getattr(self._confluence, name, None)
@@ -483,6 +475,7 @@ class OnyxConfluence:
limit: int | None = None,
# Called with the next url to use to get the next page
next_page_callback: Callable[[str], None] | None = None,
force_offset_pagination: bool = False,
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
@@ -498,6 +491,10 @@ class OnyxConfluence:
raw_response = self.get(
path=url_suffix,
advanced_mode=True,
params={
"body-format": "atlas_doc_format",
"expand": "body.atlas_doc_format",
},
)
except Exception as e:
logger.exception(f"Error in confluence call to {url_suffix}")
@@ -564,14 +561,32 @@ class OnyxConfluence:
)
raise e
# yield the results individually
# Yield the results individually.
results = cast(list[dict[str, Any]], next_response.get("results", []))
# make sure we don't update the start by more than the amount
# Note 1:
# Make sure we don't update the start by more than the amount
# of results we were able to retrieve. The Confluence API has a
# weird behavior where if you pass in a limit that is too large for
# the configured server, it will artificially limit the amount of
# results returned BUT will not apply this to the start parameter.
# This will cause us to miss results.
#
# Note 2:
# We specifically perform manual yielding (i.e., `for x in xs: yield x`) as opposed to using a `yield from xs`
# because we *have to call the `next_page_callback`* prior to yielding the last element!
#
# If we did:
#
# ```py
# yield from results
# if next_page_callback:
# next_page_callback(url_suffix)
# ```
#
# then the logic would fail since the iterator would finish (and the calling scope would exit out of its driving
# loop) prior to the callback being called.
old_url_suffix = url_suffix
updated_start = get_start_param_from_url(old_url_suffix)
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
@@ -587,6 +602,12 @@ class OnyxConfluence:
)
# notify the caller of the new url
next_page_callback(url_suffix)
elif force_offset_pagination and i == len(results) - 1:
url_suffix = update_param_in_path(
old_url_suffix, "start", str(updated_start)
)
yield result
# we've observed that Confluence sometimes returns a next link despite giving
@@ -684,7 +705,9 @@ class OnyxConfluence:
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
for user_result in self._paginate_url(url, limit):
for user_result in self._paginate_url(
url, limit, force_offset_pagination=True
):
# Example response:
# {
# 'user': {
@@ -774,7 +797,7 @@ class OnyxConfluence:
user_query = f"{user_field}={quote(user_value)}"
url = f"rest/api/user/memberof?{user_query}"
yield from self._paginate_url(url, limit)
yield from self._paginate_url(url, limit, force_offset_pagination=True)
def paginated_groups_retrieval(
self,
@@ -926,6 +949,9 @@ def extract_text_from_confluence_html(
object_html = body.get("storage", body.get("view", {})).get("value")
soup = bs4.BeautifulSoup(object_html, "html.parser")
_remove_macro_stylings(soup=soup)
for user in soup.findAll("ri:user"):
user_id = (
user.attrs["ri:account-id"]
@@ -1007,3 +1033,15 @@ def extract_text_from_confluence_html(
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
def _remove_macro_stylings(soup: bs4.BeautifulSoup) -> None:
for macro_root in soup.findAll("ac:structured-macro"):
if not isinstance(macro_root, bs4.Tag):
continue
macro_styling = macro_root.find(name="ac:parameter", attrs={"ac:name": "page"})
if not macro_styling or not isinstance(macro_styling, bs4.Tag):
continue
macro_styling.extract()
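To make the two Confluence pagination modes concrete, here is a stripped-down sketch: cursor-based pagination follows the _links.next URL the API returns, while offset-based pagination bumps the start query parameter by the number of results actually received (per Note 1 above). get_page is a hypothetical fetcher standing in for the rate-limited Confluence client; only the control flow mirrors _paginate_url.

from collections.abc import Callable, Iterator
from typing import Any
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit


def _set_start(url: str, start: int) -> str:
    # Rewrite (or add) the `start` query parameter, similar in spirit to
    # update_param_in_path in the connector code above.
    scheme, netloc, path, query, frag = urlsplit(url)
    params = parse_qs(query)
    params["start"] = [str(start)]
    return urlunsplit((scheme, netloc, path, urlencode(params, doseq=True), frag))


def paginate(
    get_page: Callable[[str], dict[str, Any]],  # hypothetical fetcher: url -> parsed JSON
    url: str,
    force_offset_pagination: bool = False,
) -> Iterator[dict[str, Any]]:
    start = 0
    while url:
        page = get_page(url)
        results = page.get("results", [])
        yield from results
        start += len(results)  # advance by what we actually got, not the requested limit
        next_url = page.get("_links", {}).get("next", "")
        if next_url:
            url = next_url  # cursor-based: the API hands us the next page
        elif force_offset_pagination and results:
            url = _set_start(url, start)  # offset-based: bump `start` ourselves
        else:
            url = ""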

View File

@@ -23,7 +23,6 @@ from onyx.configs.app_configs import (
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
@@ -224,19 +223,17 @@ def _process_image_attachment(
"""Process an image attachment by saving it without generating a summary."""
try:
# Use the standardized image storage and section creation
with get_session_with_current_tenant() as db_session:
section, file_name = store_image_and_create_section(
db_session=db_session,
image_data=raw_bytes,
file_id=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
section, file_name = store_image_and_create_section(
image_data=raw_bytes,
file_id=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
except Exception as e:
msg = f"Image storage failed for {attachment['title']}: {e}"
logger.error(msg, exc_info=e)

View File

@@ -109,8 +109,10 @@ def process_onyx_metadata(
return (
OnyxMetadata(
source_type=metadata.get("connector_type"),
link=metadata.get("link"),
file_display_name=metadata.get("file_display_name"),
title=metadata.get("title"),
primary_owners=p_owners,
secondary_owners=s_owners,
doc_updated_at=doc_updated_at,

View File

@@ -33,6 +33,7 @@ from onyx.connectors.google_site.connector import GoogleSitesConnector
from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.highspot.connector import HighspotConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.imap.connector import ImapConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CredentialsConnector
@@ -121,6 +122,7 @@ def identify_connector_class(
DocumentSource.EGNYTE: EgnyteConnector,
DocumentSource.AIRTABLE: AirtableConnector,
DocumentSource.HIGHSPOT: HighspotConnector,
DocumentSource.IMAP: ImapConnector,
# just for integration tests
DocumentSource.MOCK_CONNECTOR: MockConnector,
}

View File

@@ -5,8 +5,6 @@ from pathlib import Path
from typing import Any
from typing import IO
from sqlalchemy.orm import Session
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
@@ -18,7 +16,6 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
@@ -32,7 +29,6 @@ logger = setup_logger()
def _create_image_section(
image_data: bytes,
db_session: Session,
parent_file_name: str,
display_name: str,
link: str | None = None,
@@ -58,7 +54,6 @@ def _create_image_section(
# Store the image and create a section
try:
section, stored_file_name = store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_id=file_id,
display_name=display_name,
@@ -77,7 +72,6 @@ def _process_file(
file: IO[Any],
metadata: dict[str, Any] | None,
pdf_pass: str | None,
db_session: Session,
) -> list[Document]:
"""
Process a file and return a list of Documents.
@@ -125,7 +119,6 @@ def _process_file(
try:
section, _ = _create_image_section(
image_data=image_data,
db_session=db_session,
parent_file_name=file_id,
display_name=title,
)
@@ -171,10 +164,12 @@ def _process_file(
custom_tags.update(more_custom_tags)
# File-specific metadata overrides metadata processed so far
source_type = onyx_metadata.source_type or source_type
primary_owners = onyx_metadata.primary_owners or primary_owners
secondary_owners = onyx_metadata.secondary_owners or secondary_owners
time_updated = onyx_metadata.doc_updated_at or time_updated
file_display_name = onyx_metadata.file_display_name or file_display_name
title = onyx_metadata.title or onyx_metadata.file_display_name or title
link = onyx_metadata.link or link
# Build sections: first the text as a single Section
@@ -194,7 +189,6 @@ def _process_file(
try:
image_section, stored_file_name = _create_image_section(
image_data=img_data,
db_session=db_session,
parent_file_name=file_id,
display_name=f"{title} - image {idx}",
idx=idx,
@@ -258,37 +252,33 @@ class LocalFileConnector(LoadConnector):
"""
documents: list[Document] = []
with get_session_with_current_tenant() as db_session:
for file_id in self.file_locations:
file_store = get_default_file_store(db_session)
file_record = file_store.read_file_record(file_id=file_id)
if not file_record:
# typically an unsupported extension
logger.warning(
f"No file record found for '{file_id}' in PG; skipping."
)
continue
for file_id in self.file_locations:
file_store = get_default_file_store()
file_record = file_store.read_file_record(file_id=file_id)
if not file_record:
# typically an unsupported extension
logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
continue
metadata = self._get_file_metadata(file_id)
file_io = file_store.read_file(file_id=file_id, mode="b")
new_docs = _process_file(
file_id=file_id,
file_name=file_record.display_name,
file=file_io,
metadata=metadata,
pdf_pass=self.pdf_pass,
db_session=db_session,
)
documents.extend(new_docs)
metadata = self._get_file_metadata(file_id)
file_io = file_store.read_file(file_id=file_id, mode="b")
new_docs = _process_file(
file_id=file_id,
file_name=file_record.display_name,
file=file_io,
metadata=metadata,
pdf_pass=self.pdf_pass,
)
documents.extend(new_docs)
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
yield documents
if __name__ == "__main__":
connector = LocalFileConnector(

View File

@@ -35,6 +35,7 @@ _FIREFLIES_API_QUERY = """
organizer_email
participants
date
duration
transcript_url
sentences {
text
@@ -101,7 +102,14 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
sections=cast(list[TextSection | ImageSection], sections),
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={},
metadata={
k: str(v)
for k, v in {
"meeting_date": meeting_date,
"duration_min": transcript.get("duration"),
}.items()
if v is not None
},
doc_updated_at=meeting_date,
primary_owners=organizer_email_user_info,
secondary_owners=meeting_participants_email_list,
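The metadata change above uses a small idiom worth noting: collect candidate values in a dict, drop the None entries, and stringify the rest so every metadata value is a string. A toy restatement with invented values:

candidates = {
    "meeting_date": "2025-07-24T10:00:00Z",
    "duration_min": 42,
    "location": None,  # dropped by the filter below
}
metadata = {k: str(v) for k, v in candidates.items() if v is not None}
print(metadata)  # {'meeting_date': '2025-07-24T10:00:00Z', 'duration_min': '42'}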

View File

@@ -1,6 +1,10 @@
import copy
import json
import os
import sys
import threading
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
@@ -203,7 +207,9 @@ class GoogleDriveConnector(
specific_requests_made = False
if bool(shared_drive_urls) or bool(my_drive_emails) or bool(shared_folder_urls):
specific_requests_made = True
self.specific_requests_made = specific_requests_made
# NOTE: potentially modified in load_credentials if using service account
self.include_files_shared_with_me = (
False if specific_requests_made else include_files_shared_with_me
)
@@ -284,6 +290,16 @@ class GoogleDriveConnector(
source=DocumentSource.GOOGLE_DRIVE,
)
# Service account connectors don't have a specific setting determining whether
# to include "shared with me" for each user, so we default to true unless the connector
# is in specific folders/drives mode. Note that shared files are only picked up during
# the My Drive stage, so this does nothing if the connector is set to only index shared drives.
if (
isinstance(self._creds, ServiceAccountCredentials)
and not self.specific_requests_made
):
self.include_files_shared_with_me = True
self._creds_dict = new_creds_dict
return new_creds_dict
@@ -1362,3 +1378,139 @@ class GoogleDriveConnector(
@override
def validate_checkpoint_json(self, checkpoint_json: str) -> GoogleDriveCheckpoint:
return GoogleDriveCheckpoint.model_validate_json(checkpoint_json)
def get_credentials_from_env(email: str, oauth: bool) -> dict:
if oauth:
raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"]
else:
raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
refried_credential_string = json.dumps(json.loads(raw_credential_string))
# This is the Oauth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
cred_key = (
DB_CREDENTIALS_DICT_TOKEN_KEY
if oauth
else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
)
return {
cred_key: refried_credential_string,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
}
class CheckpointOutputWrapper:
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format.
The connector format is easier for the connector implementor (e.g. it enforces exactly
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
formats.
"""
def __init__(self) -> None:
self.next_checkpoint: GoogleDriveCheckpoint | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
) -> Generator[
tuple[Document | None, ConnectorFailure | None, GoogleDriveCheckpoint | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
) -> CheckpointOutput[GoogleDriveCheckpoint]:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
if isinstance(document_or_failure, Document):
yield document_or_failure, None, None
elif isinstance(document_or_failure, ConnectorFailure):
yield None, document_or_failure, None
else:
raise ValueError(
f"Invalid document_or_failure type: {type(document_or_failure)}"
)
if self.next_checkpoint is None:
raise RuntimeError(
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
)
yield None, None, self.next_checkpoint
def yield_all_docs_from_checkpoint_connector(
connector: GoogleDriveConnector,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> Iterator[Document | ConnectorFailure]:
num_iterations = 0
checkpoint = connector.build_dummy_checkpoint()
while checkpoint.has_more:
doc_batch_generator = CheckpointOutputWrapper()(
connector.load_from_checkpoint(start, end, checkpoint)
)
for document, failure, next_checkpoint in doc_batch_generator:
if failure is not None:
yield failure
if document is not None:
yield document
if next_checkpoint is not None:
checkpoint = next_checkpoint
num_iterations += 1
if num_iterations > 100_000:
raise RuntimeError("Too many iterations. Infinite loop?")
if __name__ == "__main__":
import time
creds = get_credentials_from_env(
os.environ["GOOGLE_DRIVE_PRIMARY_ADMIN_EMAIL"], False
)
connector = GoogleDriveConnector(
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=True,
specific_user_emails=None,
)
connector.load_credentials(creds)
max_fsize = 0
biggest_fsize = 0
num_errors = 0
start_time = time.time()
with open("stats.txt", "w") as f:
for num, doc_or_failure in enumerate(
yield_all_docs_from_checkpoint_connector(connector, 0, time.time())
):
if num % 200 == 0:
f.write(f"Processed {num} files\n")
f.write(f"Max file size: {max_fsize/1000_000:.2f} MB\n")
f.write(f"Time so far: {time.time() - start_time:.2f} seconds\n")
f.write(f"Docs per minute: {num/(time.time() - start_time)*60:.2f}\n")
biggest_fsize = max(biggest_fsize, max_fsize)
max_fsize = 0
if isinstance(doc_or_failure, Document):
max_fsize = max(max_fsize, sys.getsizeof(doc_or_failure))
elif isinstance(doc_or_failure, ConnectorFailure):
num_errors += 1
print(f"Num errors: {num_errors}")
print(f"Biggest file size: {biggest_fsize/1000_000:.2f} MB")
print(f"Time taken: {time.time() - start_time:.2f} seconds")

View File

@@ -3,6 +3,8 @@ from collections.abc import Callable
from datetime import datetime
from typing import Any
from typing import cast
from urllib.parse import urlparse
from urllib.parse import urlunparse
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import MediaIoBaseDownload # type: ignore
@@ -27,7 +29,6 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text
@@ -77,7 +78,15 @@ class PermissionSyncContext(BaseModel):
def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
return file[WEB_VIEW_LINK_KEY]
link = file[WEB_VIEW_LINK_KEY]
parsed_url = urlparse(link)
parsed_url = parsed_url._replace(query="") # remove query parameters
spl_path = parsed_url.path.split("/")
if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
spl_path.pop()
parsed_url = parsed_url._replace(path="/".join(spl_path))
# Remove query parameters and reconstruct URL
return urlunparse(parsed_url)
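To illustrate the link normalization above, a minimal standalone sketch (the file id in the URL is made up); links that differ only in query parameters or a trailing /edit vs /view segment resolve to the same document id:
```py
from urllib.parse import urlparse, urlunparse

# Hypothetical Drive link; real links follow the same shape.
link = "https://docs.google.com/document/d/FAKE_FILE_ID/edit?usp=sharing"
parsed_url = urlparse(link)._replace(query="")  # drop query parameters
spl_path = parsed_url.path.split("/")
if spl_path and spl_path[-1] in ["edit", "view", "preview"]:
    spl_path.pop()  # drop the trailing action segment
print(urlunparse(parsed_url._replace(path="/".join(spl_path))))
# -> https://docs.google.com/document/d/FAKE_FILE_ID
```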
def is_gdrive_image_mime_type(mime_type: str) -> bool:
@@ -119,9 +128,32 @@ def _download_and_extract_sections_basic(
mime_type = file["mimeType"]
link = file.get(WEB_VIEW_LINK_KEY, "")
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For non-Google files, download the file
# Use the correct API call for downloading files
# lazy evaluation to only download the file if necessary
def response_call() -> bytes:
return download_request(service, file_id)
if is_gdrive_image_mime_type(mime_type):
# Skip images if not explicitly enabled
if not allow_images:
return []
# Store images for later processing
sections: list[TextSection | ImageSection] = []
try:
section, embedded_id = store_image_and_create_section(
image_data=response_call(),
file_id=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
@@ -144,12 +176,6 @@ def _download_and_extract_sections_basic(
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
# lazy evaluation to only download the file if necessary
def response_call() -> bytes:
return download_request(service, file_id)
# Process based on mime type
if mime_type == "text/plain":
try:
@@ -179,25 +205,6 @@ def _download_and_extract_sections_basic(
text = pptx_to_text(io.BytesIO(response_call()), file_name=file_name)
return [TextSection(link=link, text=text)] if text else []
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response_call(),
file_id=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
pdf_sections: list[TextSection | ImageSection] = [
@@ -206,41 +213,30 @@ def _download_and_extract_sections_basic(
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_id=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
image_data=img_data,
file_id=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
else:
# For unsupported file types, try to extract text
if mime_type in [
"application/vnd.google-apps.video",
"application/vnd.google-apps.audio",
"application/zip",
]:
return []
# Final attempt at extracting text
file_ext = get_file_ext(file.get("name", ""))
if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
logger.warning(f"Skipping file {file.get('name')} due to extension.")
return []
# don't download the file at all if it's an unhandled extension
file_ext = get_file_ext(file.get("name", ""))
if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
logger.warning(f"Skipping file {file.get('name')} due to extension.")
return []
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response_call()), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
try:
text = extract_file_text(io.BytesIO(response_call()), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int:
@@ -315,13 +311,17 @@ def align_basic_advanced(
def _get_external_access_for_raw_gdrive_file(
file: GoogleDriveFileType,
company_domain: str,
drive_service: GoogleDriveService,
retriever_drive_service: GoogleDriveService | None,
admin_drive_service: GoogleDriveService,
) -> ExternalAccess:
"""
Get the external access for a raw Google Drive file.
"""
external_access_fn = cast(
Callable[[GoogleDriveFileType, str, GoogleDriveService], ExternalAccess],
Callable[
[GoogleDriveFileType, str, GoogleDriveService | None, GoogleDriveService],
ExternalAccess,
],
fetch_versioned_implementation_with_fallback(
"onyx.external_permissions.google_drive.doc_sync",
"get_external_access_for_raw_gdrive_file",
@@ -331,7 +331,8 @@ def _get_external_access_for_raw_gdrive_file(
return external_access_fn(
file,
company_domain,
drive_service,
retriever_drive_service,
admin_drive_service,
)
@@ -436,6 +437,19 @@ def _convert_drive_item_to_document(
logger.info("Skipping shortcut/folder.")
return None
size_str = file.get("size")
if size_str:
try:
size_int = int(size_str)
except ValueError:
logger.warning(f"Parsing string to int failed: size_str={size_str}")
else:
if size_int > size_threshold:
logger.warning(
f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping."
)
return None
# If it's a Google Doc, we might do advanced parsing
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
@@ -461,22 +475,8 @@ def _convert_drive_item_to_document(
logger.warning(
f"Error in advanced parsing: {e}. Falling back to basic extraction."
)
size_str = file.get("size")
if size_str:
try:
size_int = int(size_str)
except ValueError:
logger.warning(f"Parsing string to int failed: size_str={size_str}")
else:
if size_int > size_threshold:
logger.warning(
f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping."
)
return None
# If we don't have sections yet, use the basic extraction method
if not sections:
# Not Google Doc, attempt basic extraction
else:
sections = _download_and_extract_sections_basic(
file, _get_drive_service(), allow_images
)
@@ -491,7 +491,9 @@ def _convert_drive_item_to_document(
_get_external_access_for_raw_gdrive_file(
file=file,
company_domain=permission_sync_context.google_domain,
drive_service=get_drive_service(
# try both retriever_email and primary_admin_email if necessary
retriever_drive_service=_get_drive_service(),
admin_drive_service=get_drive_service(
creds, user_email=permission_sync_context.primary_admin_email
),
)
@@ -557,14 +559,22 @@ def build_slim_document(
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
return None
owner_email = file.get("owners", [{}])[0].get("emailAddress")
owner_email = cast(str | None, file.get("owners", [{}])[0].get("emailAddress"))
external_access = (
_get_external_access_for_raw_gdrive_file(
file=file,
company_domain=permission_sync_context.google_domain,
drive_service=get_drive_service(
retriever_drive_service=(
get_drive_service(
creds,
user_email=owner_email,
)
if owner_email
else None
),
admin_drive_service=get_drive_service(
creds,
user_email=owner_email or permission_sync_context.primary_admin_email,
user_email=permission_sync_context.primary_admin_email,
),
)
if permission_sync_context

View File

@@ -12,7 +12,6 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_processing.html_utils import web_html_cleanup
@@ -68,10 +67,7 @@ class GoogleSitesConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
with get_session_with_current_tenant() as db_session:
file_content_io = get_default_file_store(db_session).read_file(
self.zip_path, mode="b"
)
file_content_io = get_default_file_store().read_file(self.zip_path, mode="b")
# load the HTML files
files = load_files_from_zip(file_content_io)

View File

@@ -0,0 +1,484 @@
import copy
import email
import imaplib
import os
import re
from datetime import datetime
from datetime import timezone
from email.message import Message
from email.utils import parseaddr
from enum import Enum
from typing import Any
from typing import cast
import bs4
from pydantic import BaseModel
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.imap.models import EmailHeaders
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
_DEFAULT_IMAP_PORT_NUMBER = int(os.environ.get("IMAP_PORT", 993))
_IMAP_OKAY_STATUS = "OK"
_PAGE_SIZE = 100
_USERNAME_KEY = "imap_username"
_PASSWORD_KEY = "imap_password"
class CurrentMailbox(BaseModel):
mailbox: str
todo_email_ids: list[str]
# An email account has a list of mailboxes.
# Each mailbox has a list of email-ids inside of it.
#
# Usage:
# To use this checkpointer, first fetch all the mailboxes.
# Then, pop a mailbox and fetch all of its email-ids.
# Then, pop each email-id and fetch its content (and parse it, etc.).
# Once every email-id for the current mailbox has been popped, pop the next mailbox and repeat until done.
#
# For initial checkpointing, set both fields to `None`.
class ImapCheckpoint(ConnectorCheckpoint):
todo_mailboxes: list[str] | None = None
current_mailbox: CurrentMailbox | None = None
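A minimal sketch of how a caller could drain this checkpoint, assuming `connector` is an `ImapConnector` with a credentials provider already set (the actual loop lives in the indexing runner):
```py
import time

checkpoint = connector.build_dummy_checkpoint()
while checkpoint.has_more:
    gen = connector.load_from_checkpoint(start=0, end=time.time(), checkpoint=checkpoint)
    while True:
        try:
            doc_or_failure = next(gen)  # Documents / ConnectorFailures are yielded one by one
        except StopIteration as stop:
            checkpoint = stop.value  # the generator *returns* the next checkpoint
            break
```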
class LoginState(str, Enum):
LoggedIn = "logged_in"
LoggedOut = "logged_out"
class ImapConnector(
CredentialsConnector,
CheckpointedConnectorWithPermSync[ImapCheckpoint],
):
def __init__(
self,
host: str,
port: int = _DEFAULT_IMAP_PORT_NUMBER,
mailboxes: list[str] | None = None,
) -> None:
self._host = host
self._port = port
self._mailboxes = mailboxes
self._credentials: dict[str, Any] | None = None
@property
def credentials(self) -> dict[str, Any]:
if not self._credentials:
raise RuntimeError(
"Credentials have not been initialized; call `set_credentials_provider` first"
)
return self._credentials
def _get_mail_client(self) -> imaplib.IMAP4_SSL:
"""
Returns a new `imaplib.IMAP4_SSL` instance.
The `imaplib.IMAP4_SSL` object is meant to be "ephemeral"; it is not something you can log into, log out of,
and then log back into. I.e., the following will fail:
```py
mail_client.login(..)
mail_client.logout();
mail_client.login(..)
```
Therefore, you need a fresh, new instance in order to operate with IMAP. This function gives one to you.
# Notes
This function will throw an error if the credentials have not yet been set.
"""
def get_or_raise(name: str) -> str:
value = self.credentials.get(name)
if not value:
raise RuntimeError(f"Credential item {name=} was not found")
if not isinstance(value, str):
raise RuntimeError(
f"Credential item {name=} must be of type str, instead received {type(name)=}"
)
return value
username = get_or_raise(_USERNAME_KEY)
password = get_or_raise(_PASSWORD_KEY)
mail_client = imaplib.IMAP4_SSL(host=self._host, port=self._port)
status, _data = mail_client.login(user=username, password=password)
if status != _IMAP_OKAY_STATUS:
raise RuntimeError(f"Failed to log into imap server; {status=}")
return mail_client
def _load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ImapCheckpoint,
include_perm_sync: bool,
) -> CheckpointOutput[ImapCheckpoint]:
checkpoint = cast(ImapCheckpoint, copy.deepcopy(checkpoint))
checkpoint.has_more = True
mail_client = self._get_mail_client()
if checkpoint.todo_mailboxes is None:
# This is the dummy checkpoint.
# Fill it with mailboxes first.
if self._mailboxes:
checkpoint.todo_mailboxes = _sanitize_mailbox_names(self._mailboxes)
else:
fetched_mailboxes = _fetch_all_mailboxes_for_email_account(
mail_client=mail_client
)
if not fetched_mailboxes:
raise RuntimeError(
"Failed to find any mailboxes for this email account"
)
checkpoint.todo_mailboxes = _sanitize_mailbox_names(fetched_mailboxes)
return checkpoint
if (
not checkpoint.current_mailbox
or not checkpoint.current_mailbox.todo_email_ids
):
if not checkpoint.todo_mailboxes:
checkpoint.has_more = False
return checkpoint
mailbox = checkpoint.todo_mailboxes.pop()
email_ids = _fetch_email_ids_in_mailbox(
mail_client=mail_client,
mailbox=mailbox,
start=start,
end=end,
)
checkpoint.current_mailbox = CurrentMailbox(
mailbox=mailbox,
todo_email_ids=email_ids,
)
_select_mailbox(
mail_client=mail_client, mailbox=checkpoint.current_mailbox.mailbox
)
current_todos = cast(
list, copy.deepcopy(checkpoint.current_mailbox.todo_email_ids[:_PAGE_SIZE])
)
checkpoint.current_mailbox.todo_email_ids = (
checkpoint.current_mailbox.todo_email_ids[_PAGE_SIZE:]
)
for email_id in current_todos:
email_msg = _fetch_email(mail_client=mail_client, email_id=email_id)
if not email_msg:
logger.warn(f"Failed to fetch message {email_id=}; skipping")
continue
email_headers = EmailHeaders.from_email_msg(email_msg=email_msg)
yield _convert_email_headers_and_body_into_document(
email_msg=email_msg,
email_headers=email_headers,
include_perm_sync=include_perm_sync,
)
return checkpoint
# impls for BaseConnector
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError("Use `set_credentials_provider` instead")
def validate_connector_settings(self) -> None:
self._get_mail_client()
# impls for CredentialsConnector
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
self._credentials = credentials_provider.get_credentials()
# impls for CheckpointedConnector
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ImapCheckpoint,
) -> CheckpointOutput[ImapCheckpoint]:
return self._load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint, include_perm_sync=False
)
def build_dummy_checkpoint(self) -> ImapCheckpoint:
return ImapCheckpoint(has_more=True)
def validate_checkpoint_json(self, checkpoint_json: str) -> ImapCheckpoint:
return ImapCheckpoint.model_validate_json(json_data=checkpoint_json)
# impls for CheckpointedConnectorWithPermSync
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ImapCheckpoint,
) -> CheckpointOutput[ImapCheckpoint]:
return self._load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint, include_perm_sync=True
)
def _fetch_all_mailboxes_for_email_account(mail_client: imaplib.IMAP4_SSL) -> list[str]:
status, mailboxes_data = mail_client.list(directory="*", pattern="*")
if status != _IMAP_OKAY_STATUS:
raise RuntimeError(f"Failed to fetch mailboxes; {status=}")
mailboxes = []
for mailboxes_raw in mailboxes_data:
if isinstance(mailboxes_raw, bytes):
mailboxes_str = mailboxes_raw.decode()
elif isinstance(mailboxes_raw, str):
mailboxes_str = mailboxes_raw
else:
logger.warn(
f"Expected the mailbox data to be of type str, instead got {type(mailboxes_raw)=} {mailboxes_raw}; skipping"
)
continue
# The mailbox LIST response output can be found here:
# https://www.rfc-editor.org/rfc/rfc3501.html#section-7.2.2
#
# The general format is:
# `(<name-attributes>) <hierarchy-delimiter> <mailbox-name>`
#
# The regex below matches that pattern; we then take the second capture group, `group(2)`, which is the mailbox-name.
match = re.match(r'\([^)]*\)\s+"([^"]+)"\s+"?(.+?)"?$', mailboxes_str)
if not match:
logger.warn(
f"Invalid mailbox-data formatting structure: {mailboxes_str=}; skipping"
)
continue
mailbox = match.group(2)
mailboxes.append(mailbox)
return mailboxes
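A quick worked example of the LIST-line parsing above, using made-up mailbox data:
```py
import re

pattern = re.compile(r'\([^)]*\)\s+"([^"]+)"\s+"?(.+?)"?$')
for line in [r'(\HasNoChildren) "/" "INBOX"', r'(\HasChildren) "/" "Projects/Onyx"']:
    match = pattern.match(line)
    if match:
        print(match.group(2))  # group(1) is the hierarchy delimiter
# -> INBOX
# -> Projects/Onyx
```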
def _select_mailbox(mail_client: imaplib.IMAP4_SSL, mailbox: str) -> None:
status, _ids = mail_client.select(mailbox=mailbox, readonly=True)
if status != _IMAP_OKAY_STATUS:
raise RuntimeError(f"Failed to select {mailbox=}")
def _fetch_email_ids_in_mailbox(
mail_client: imaplib.IMAP4_SSL,
mailbox: str,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> list[str]:
_select_mailbox(mail_client=mail_client, mailbox=mailbox)
start_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime("%d-%b-%Y")
end_str = datetime.fromtimestamp(end, tz=timezone.utc).strftime("%d-%b-%Y")
search_criteria = f'(SINCE "{start_str}" BEFORE "{end_str}")'
status, email_ids_byte_array = mail_client.search(None, search_criteria)
if status != _IMAP_OKAY_STATUS or not email_ids_byte_array:
raise RuntimeError(f"Failed to fetch email ids; {status=}")
email_ids: bytes = email_ids_byte_array[0]
return [email_id.decode() for email_id in email_ids.split()]
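For reference, the date window built above renders like this (timestamps chosen arbitrarily); SINCE/BEFORE operate at day granularity, which is why only the date portion is formatted:
```py
from datetime import datetime, timezone

start, end = 1717200000, 1719878400  # 2024-06-01 and 2024-07-02, UTC
start_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime("%d-%b-%Y")
end_str = datetime.fromtimestamp(end, tz=timezone.utc).strftime("%d-%b-%Y")
print(f'(SINCE "{start_str}" BEFORE "{end_str}")')
# -> (SINCE "01-Jun-2024" BEFORE "02-Jul-2024")
```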
def _fetch_email(mail_client: imaplib.IMAP4_SSL, email_id: str) -> Message | None:
status, msg_data = mail_client.fetch(message_set=email_id, message_parts="(RFC822)")
if status != _IMAP_OKAY_STATUS or not msg_data:
return None
data = msg_data[0]
if not isinstance(data, tuple):
raise RuntimeError(
f"Message data should be a tuple; instead got a {type(data)=} {data=}"
)
_metadata, raw_email = data
return email.message_from_bytes(raw_email)
def _convert_email_headers_and_body_into_document(
email_msg: Message,
email_headers: EmailHeaders,
include_perm_sync: bool,
) -> Document:
sender_name, sender_addr = _parse_singular_addr(raw_header=email_headers.sender)
parsed_recipients = (
_parse_addrs(raw_header=email_headers.recipients)
if email_headers.recipients
else []
)
expert_info_map = {
recipient_addr: BasicExpertInfo(
display_name=recipient_name, email=recipient_addr
)
for recipient_name, recipient_addr in parsed_recipients
}
if sender_addr not in expert_info_map:
expert_info_map[sender_addr] = BasicExpertInfo(
display_name=sender_name, email=sender_addr
)
email_body = _parse_email_body(email_msg=email_msg, email_headers=email_headers)
primary_owners = list(expert_info_map.values())
external_access = (
ExternalAccess(
external_user_emails=set(expert_info_map.keys()),
external_user_group_ids=set(),
is_public=False,
)
if include_perm_sync
else None
)
return Document(
id=email_headers.id,
title=email_headers.subject,
semantic_identifier=email_headers.subject,
metadata={},
source=DocumentSource.IMAP,
sections=[TextSection(text=email_body)],
primary_owners=primary_owners,
external_access=external_access,
)
def _parse_email_body(
email_msg: Message,
email_headers: EmailHeaders,
) -> str:
body = None
for part in email_msg.walk():
if part.is_multipart():
# Multipart parts are *containers* for other parts, not the actual content itself.
# Therefore, we skip until we find the individual parts instead.
continue
charset = part.get_content_charset() or "utf-8"
try:
raw_payload = part.get_payload(decode=True)
if not isinstance(raw_payload, bytes):
logger.warn(
"Payload section from email was expected to be an array of bytes, instead got "
f"{type(raw_payload)=}, {raw_payload=}"
)
continue
body = raw_payload.decode(charset)
break
except (UnicodeDecodeError, LookupError) as e:
print(f"Warning: Could not decode part with charset {charset}. Error: {e}")
continue
if not body:
logger.warn(
f"Email with {email_headers.id=} has an empty body; returning an empty string"
)
return ""
soup = bs4.BeautifulSoup(markup=body, features="html.parser")
return " ".join(str_section for str_section in soup.stripped_strings)
def _sanitize_mailbox_names(mailboxes: list[str]) -> list[str]:
"""
Mailboxes with special characters in them must be enclosed by double-quotes, as per the IMAP protocol.
Just to be safe, we wrap *all* mailboxes with double-quotes.
"""
return [f'"{mailbox}"' for mailbox in mailboxes if mailbox]
def _parse_addrs(raw_header: str) -> list[tuple[str, str]]:
addrs = raw_header.split(",")
name_addr_pairs = [parseaddr(addr=addr, strict=True) for addr in addrs if addr]
return [(name, addr) for name, addr in name_addr_pairs if addr]
def _parse_singular_addr(raw_header: str) -> tuple[str, str]:
addrs = _parse_addrs(raw_header=raw_header)
if not addrs:
raise RuntimeError(
f"Parsing email header resulted in no addresses being found; {raw_header=}"
)
elif len(addrs) >= 2:
raise RuntimeError(
f"Expected a singular address, but instead got multiple; {raw_header=} {addrs=}"
)
return addrs[0]
if __name__ == "__main__":
import time
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
host = os.environ.get("IMAP_HOST")
mailboxes_str = os.environ.get("IMAP_MAILBOXES")
username = os.environ.get("IMAP_USERNAME")
password = os.environ.get("IMAP_PASSWORD")
mailboxes = (
[mailbox.strip() for mailbox in mailboxes_str.split(",")]
if mailboxes_str
else []
)
if not host:
raise RuntimeError("`IMAP_HOST` must be set")
imap_connector = ImapConnector(
host=host,
mailboxes=mailboxes,
)
imap_connector.set_credentials_provider(
OnyxStaticCredentialsProvider(
tenant_id=None,
connector_name=DocumentSource.IMAP,
credential_json={
_USERNAME_KEY: username,
_PASSWORD_KEY: password,
},
)
)
for doc in load_all_docs_from_checkpoint_connector(
connector=imap_connector,
start=0,
end=time.time(),
):
print(doc)

View File

@@ -0,0 +1,75 @@
import email
from datetime import datetime
from email.message import Message
from enum import Enum
from pydantic import BaseModel
class Header(str, Enum):
SUBJECT_HEADER = "subject"
FROM_HEADER = "from"
TO_HEADER = "to"
DELIVERED_TO_HEADER = (
"Delivered-To" # Used in mailing lists instead of the "to" header.
)
DATE_HEADER = "date"
MESSAGE_ID_HEADER = "Message-ID"
class EmailHeaders(BaseModel):
"""
Model for email headers extracted from IMAP messages.
"""
id: str
subject: str
sender: str
recipients: str | None
date: datetime
@classmethod
def from_email_msg(cls, email_msg: Message) -> "EmailHeaders":
def _decode(header: str, default: str | None = None) -> str | None:
value = email_msg.get(header, default)
if not value:
return None
decoded_value, _encoding = email.header.decode_header(value)[0]
if isinstance(decoded_value, bytes):
return decoded_value.decode()
elif isinstance(decoded_value, str):
return decoded_value
else:
return None
def _parse_date(date_str: str | None) -> datetime | None:
if not date_str:
return None
try:
return email.utils.parsedate_to_datetime(date_str)
except (TypeError, ValueError):
return None
message_id = _decode(header=Header.MESSAGE_ID_HEADER)
# It's possible for the subject line to not exist or be an empty string.
subject = _decode(header=Header.SUBJECT_HEADER) or "Unknown Subject"
from_ = _decode(header=Header.FROM_HEADER)
to = _decode(header=Header.TO_HEADER)
if not to:
to = _decode(header=Header.DELIVERED_TO_HEADER)
date_str = _decode(header=Header.DATE_HEADER)
date = _parse_date(date_str=date_str)
# If any of the above are `None`, model validation will fail.
# Therefore, no guards (i.e.: `if <header> is None: raise RuntimeError(..)`) were written.
return cls.model_validate(
{
"id": message_id,
"subject": subject,
"sender": from_,
"recipients": to,
"date": date,
}
)

View File

@@ -129,7 +129,14 @@ class MediaWikiConnector(LoadConnector, PollConnector):
self.family = family_class_dispatch(hostname, "WikipediaConnector")()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [
pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
pywikibot.Category(
self.site,
(
f"{category.replace(' ', '_')}"
if category.startswith("Category:")
else f"Category:{category.replace(' ', '_')}"
),
)
for category in categories
]
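The guard above makes the category title idempotent with respect to the "Category:" prefix; with made-up inputs:
```py
for category in ("Physics", "Category:Physics"):
    title = (
        f"{category.replace(' ', '_')}"
        if category.startswith("Category:")
        else f"Category:{category.replace(' ', '_')}"
    )
    print(title)
# Both inputs print "Category:Physics"
```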

View File

@@ -11,6 +11,7 @@ from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.db.enums import IndexModelStatus
from onyx.utils.text_processing import make_url_compatible
@@ -363,9 +364,38 @@ class ConnectorFailure(BaseModel):
return values
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
class OnyxMetadata(BaseModel):
# Note that doc_id cannot be overridden here as it may cause issues
# with the display functionality in the UI. Ask @chris if clarification is needed.
source_type: DocumentSource | None = None
link: str | None = None
file_display_name: str | None = None
primary_owners: list[BasicExpertInfo] | None = None
secondary_owners: list[BasicExpertInfo] | None = None
doc_updated_at: datetime | None = None
title: str | None = None
class DocExtractionContext(BaseModel):
index_name: str
cc_pair_id: int
connector_id: int
credential_id: int
source: DocumentSource
earliest_index_time: float
from_beginning: bool
is_primary: bool
should_fetch_permissions_during_indexing: bool
search_settings_status: IndexModelStatus
doc_extraction_complete_batch_num: int | None
class DocIndexingContext(BaseModel):
batches_done: int
total_failures: int
net_doc_change: int
total_chunks: int

View File

@@ -267,7 +267,7 @@ class NotionConnector(LoadConnector, PollConnector):
result = ""
for prop_name, prop in properties.items():
if not prop:
if not prop or not isinstance(prop, dict):
continue
try:

View File

@@ -992,7 +992,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
doc_metadata_list: list[SlimDocument] = []
for parent_object_type in self.parent_object_list:
query = f"SELECT Id FROM {parent_object_type}"
query_result = self.sf_client.query_all(query)
query_result = self.sf_client.safe_query_all(query)
doc_metadata_list.extend(
SlimDocument(
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",

View File

@@ -1,18 +1,31 @@
import time
from typing import Any
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from simple_salesforce.exceptions import SalesforceRefusedRequest
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_OBJECTS
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_PREFIXES
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_SUFFIXES
from onyx.connectors.salesforce.salesforce_calls import get_object_by_id_query
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
def is_salesforce_rate_limit_error(exception: Exception) -> bool:
"""Check if an exception is a Salesforce rate limit error."""
return isinstance(
exception, SalesforceRefusedRequest
) and "REQUEST_LIMIT_EXCEEDED" in str(exception)
class OnyxSalesforce(Salesforce):
SOQL_MAX_SUBQUERIES = 20
@@ -52,6 +65,48 @@ class OnyxSalesforce(Salesforce):
return False
@retry_builder(
tries=5,
delay=20,
backoff=1.5,
max_delay=60,
exceptions=(SalesforceRefusedRequest,),
)
@rate_limit_builder(max_calls=50, period=60)
def safe_query(self, query: str, **kwargs: Any) -> dict[str, Any]:
"""Wrapper around the original query method with retry logic and rate limiting."""
try:
return super().query(query, **kwargs)
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for query: {query[:100]}..."
)
# Add additional delay for rate limit errors
time.sleep(5)
raise
@retry_builder(
tries=5,
delay=20,
backoff=1.5,
max_delay=60,
exceptions=(SalesforceRefusedRequest,),
)
@rate_limit_builder(max_calls=50, period=60)
def safe_query_all(self, query: str, **kwargs: Any) -> dict[str, Any]:
"""Wrapper around the original query_all method with retry logic and rate limiting."""
try:
return super().query_all(query, **kwargs)
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for query_all: {query[:100]}..."
)
# Add additional delay for rate limit errors
time.sleep(5)
raise
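Assuming `retry_builder` follows the usual tries/delay/backoff/max_delay semantics (each wait is the previous wait times the backoff factor, capped at max_delay), the wait schedule for these wrappers would be roughly:
```py
delay, backoff, max_delay, tries = 20, 1.5, 60, 5
waits = []
for _ in range(tries - 1):
    waits.append(delay)
    delay = min(delay * backoff, max_delay)
print(waits)  # [20, 30.0, 45.0, 60.0] -> roughly 2.5 minutes of backoff before giving up
```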
@staticmethod
def _make_child_objects_by_id_query(
object_id: str,
@@ -99,7 +154,7 @@ class OnyxSalesforce(Salesforce):
queryable_fields = type_to_queryable_fields[object_type]
query = get_object_by_id_query(object_id, object_type, queryable_fields)
result = self.query(query)
result = self.safe_query(query)
if not result:
return None
@@ -151,7 +206,7 @@ class OnyxSalesforce(Salesforce):
)
try:
result = self.query(query)
result = self.safe_query(query)
except Exception:
logger.exception(f"Query failed: {query=}")
else:
@@ -189,10 +244,25 @@ class OnyxSalesforce(Salesforce):
return child_records
@retry_builder(
tries=3,
delay=1,
backoff=2,
exceptions=(SalesforceRefusedRequest,),
)
def describe_type(self, name: str) -> Any:
sf_object = SFType(name, self.session_id, self.sf_instance)
result = sf_object.describe()
return result
try:
result = sf_object.describe()
return result
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for describe_type: {name}"
)
# Add additional delay for rate limit errors
time.sleep(3)
raise
def get_queryable_fields_by_type(self, name: str) -> list[str]:
object_description = self.describe_type(name)

View File

@@ -1,5 +1,6 @@
import gc
import os
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
@@ -7,13 +8,25 @@ from pytz import UTC
from simple_salesforce import Salesforce
from simple_salesforce.bulk2 import SFBulk2Handler
from simple_salesforce.bulk2 import SFBulk2Type
from simple_salesforce.exceptions import SalesforceRefusedRequest
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
def is_salesforce_rate_limit_error(exception: Exception) -> bool:
"""Check if an exception is a Salesforce rate limit error."""
return isinstance(
exception, SalesforceRefusedRequest
) and "REQUEST_LIMIT_EXCEEDED" in str(exception)
def _build_last_modified_time_filter_for_salesforce(
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
@@ -71,6 +84,14 @@ def get_object_by_id_query(
return query
@retry_builder(
tries=5,
delay=2,
backoff=2,
max_delay=60,
exceptions=(SalesforceRefusedRequest,),
)
@rate_limit_builder(max_calls=50, period=60)
def _object_type_has_api_data(
sf_client: Salesforce, sf_type: str, time_filter: str
) -> bool:
@@ -82,6 +103,15 @@ def _object_type_has_api_data(
result = sf_client.query(query)
if result["totalSize"] == 0:
return False
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for object type check: {sf_type}"
)
# Add additional delay for rate limit errors
time.sleep(3)
raise
except Exception as e:
if "OPERATION_TOO_LARGE" not in str(e):
logger.warning(f"Object type {sf_type} doesn't support query: {e}")

View File

@@ -1,5 +1,6 @@
import io
import os
import time
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
@@ -11,9 +12,11 @@ from office365.graph_client import GraphClient # type: ignore
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
from office365.onedrive.sites.site import Site # type: ignore
from office365.onedrive.sites.sites_with_root import SitesWithRoot # type: ignore
from office365.runtime.client_request import ClientRequestException # type: ignore
from pydantic import BaseModel
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import SHAREPOINT_CONNECTOR_SIZE_THRESHOLD
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
@@ -46,12 +49,72 @@ class SiteDescriptor(BaseModel):
folder_path: str | None
def _sleep_and_retry(query_obj: Any, method_name: str, max_retries: int = 3) -> Any:
"""
Execute a SharePoint query with retry logic for rate limiting.
"""
for attempt in range(max_retries + 1):
try:
return query_obj.execute_query()
except ClientRequestException as e:
if (
e.response
and e.response.status_code in [429, 503]
and attempt < max_retries
):
logger.warning(
f"Rate limit exceeded on {method_name}, attempt {attempt + 1}/{max_retries + 1}, sleeping and retrying"
)
retry_after = e.response.headers.get("Retry-After")
if retry_after:
sleep_time = int(retry_after)
else:
# Exponential backoff: 2^attempt * 5 seconds
sleep_time = min(30, (2**attempt) * 5)
logger.info(f"Sleeping for {sleep_time} seconds before retry")
time.sleep(sleep_time)
else:
# Either not a rate limit error, or we've exhausted retries
if e.response and e.response.status_code == 429:
logger.error(
f"Rate limit retry exhausted for {method_name} after {max_retries} attempts"
)
raise e
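When SharePoint does not return a Retry-After header, the fallback backoff above works out to:
```py
max_retries = 3
for attempt in range(max_retries):  # sleeps only happen while attempt < max_retries
    print(attempt, min(30, (2 ** attempt) * 5))
# attempt 0 -> 5s, attempt 1 -> 10s, attempt 2 -> 20s (the 30s cap is never reached with 3 retries)
```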
def _convert_driveitem_to_document(
driveitem: DriveItem,
drive_name: str,
) -> Document:
) -> Document | None:
# Check file size before downloading
try:
size_value = getattr(driveitem, "size", None)
if size_value is not None:
file_size = int(size_value)
if file_size > SHAREPOINT_CONNECTOR_SIZE_THRESHOLD:
logger.warning(
f"File '{driveitem.name}' exceeds size threshold of {SHAREPOINT_CONNECTOR_SIZE_THRESHOLD} bytes. "
f"File size: {file_size} bytes. Skipping."
)
return None
else:
logger.warning(
f"Could not access file size for '{driveitem.name}' Proceeding with download."
)
except (ValueError, TypeError, AttributeError) as e:
logger.info(
f"Could not access file size for '{driveitem.name}': {e}. Proceeding with download."
)
# Proceed with download if size is acceptable or not available
content = _sleep_and_retry(driveitem.get_content(), "get_content")
if content is None:
logger.warning(f"Could not access content for '{driveitem.name}'")
return None
file_text = extract_file_text(
file=io.BytesIO(driveitem.get_content().execute_query().value),
file=io.BytesIO(content.value),
file_name=driveitem.name,
break_on_unprocessable=False,
)
@@ -275,7 +338,11 @@ class SharepointConnector(LoadConnector, PollConnector):
driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
for driveitem, drive_name in driveitems:
logger.debug(f"Processing: {driveitem.web_url}")
doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name))
# Convert driveitem to document with size checking
doc = _convert_driveitem_to_document(driveitem, drive_name)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch

View File

@@ -286,9 +286,14 @@ class TeamsConnector(
def _construct_semantic_identifier(channel: Channel, top_message: Message) -> str:
top_message_user_name = (
top_message.from_.user.display_name if top_message.from_ else "Unknown User"
)
top_message_user_name: str
if top_message.from_ and top_message.from_.user:
top_message_user_name = top_message.from_.user.display_name
else:
logger.warn(f"Message {top_message=} has no `from.user` field")
top_message_user_name = "Unknown User"
top_message_content = top_message.body.content or ""
top_message_subject = top_message.subject or "Unknown Subject"
channel_name = channel.properties.get("displayName", "Unknown")

View File

@@ -27,7 +27,7 @@ class User(BaseModel):
class From(BaseModel):
user: User
user: User | None
model_config = ConfigDict(
alias_generator=to_camel,

View File

@@ -23,6 +23,7 @@ class OptionalSearchSetting(str, Enum):
class SearchType(str, Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
INTERNET = "internet"
class LLMEvaluationType(str, Enum):

View File

@@ -0,0 +1,18 @@
from datetime import datetime
from pydantic import BaseModel
class SlackMessage(BaseModel):
document_id: str
channel_id: str
message_id: str
thread_id: str | None
link: str
metadata: dict[str, str | list[str]]
timestamp: datetime
recency_bias: float
semantic_identifier: str
text: str
highlighted_texts: set[str]
slack_score: float

View File

@@ -0,0 +1,416 @@
import re
from datetime import datetime
from datetime import timedelta
from typing import Any
from langchain_core.messages import HumanMessage
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import TextSection
from onyx.context.search.federated.models import SlackMessage
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import SearchQuery
from onyx.db.document import DocumentSource
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.document_index_utils import (
get_multipass_config,
)
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.llm.factory import get_default_llms
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.prompts.federated_search import SLACK_QUERY_EXPANSION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
logger = setup_logger()
HIGHLIGHT_START_CHAR = "\ue000"
HIGHLIGHT_END_CHAR = "\ue001"
def build_slack_queries(query: SearchQuery, llm: LLM) -> list[str]:
# get time filter
time_filter = ""
time_cutoff = query.filters.time_cutoff
if time_cutoff is not None:
# slack after: is exclusive, so we need to subtract one day
time_cutoff = time_cutoff - timedelta(days=1)
time_filter = f" after:{time_cutoff.strftime('%Y-%m-%d')}"
# use llm to generate slack queries (use original query to use same keywords as the user)
prompt = SLACK_QUERY_EXPANSION_PROMPT.format(query=query.original_query)
try:
msg = HumanMessage(content=prompt)
response = llm.invoke([msg])
rephrased_queries = message_to_string(response).split("\n")
except Exception as e:
logger.error(f"Error expanding query: {e}")
rephrased_queries = [query.query]
return [
rephrased_query.strip() + time_filter
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
]
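For example, with a cutoff of 2024-06-15 (an illustrative date), the filter appended to every expanded query is:
```py
from datetime import datetime, timedelta

time_cutoff = datetime(2024, 6, 15) - timedelta(days=1)  # Slack's after: is exclusive
print(f" after:{time_cutoff.strftime('%Y-%m-%d')}")
# -> " after:2024-06-14"
```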
def query_slack(
query_string: str,
original_query: SearchQuery,
access_token: str,
limit: int | None = None,
) -> list[SlackMessage]:
# query slack
slack_client = WebClient(token=access_token)
try:
response = slack_client.search_messages(
query=query_string, count=limit, highlight=True
)
response.validate()
messages: dict[str, Any] = response.get("messages", {})
matches: list[dict[str, Any]] = messages.get("matches", [])
except SlackApiError as e:
logger.error(f"Slack API error in query_slack: {e}")
return []
# convert matches to slack messages
slack_messages: list[SlackMessage] = []
for match in matches:
text: str | None = match.get("text")
permalink: str | None = match.get("permalink")
message_id: str | None = match.get("ts")
channel_id: str | None = match.get("channel", {}).get("id")
channel_name: str | None = match.get("channel", {}).get("name")
username: str | None = match.get("username")
score: float = match.get("score", 0.0)
if ( # can't use any() because of type checking :(
not text
or not permalink
or not message_id
or not channel_id
or not channel_name
or not username
):
continue
# generate thread id and document id
thread_id = (
permalink.split("?thread_ts=", 1)[1] if "?thread_ts=" in permalink else None
)
document_id = f"{channel_id}_{message_id}"
# compute recency bias (parallels vespa calculation) and metadata
decay_factor = DOC_TIME_DECAY * original_query.recency_bias_multiplier
doc_time = datetime.fromtimestamp(float(message_id))
doc_age_years = (datetime.now() - doc_time).total_seconds() / (
365 * 24 * 60 * 60
)
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
metadata: dict[str, str | list[str]] = {
"channel": channel_name,
"time": doc_time.isoformat(),
}
# extract out the highlighted texts
highlighted_texts = set(
re.findall(
rf"{re.escape(HIGHLIGHT_START_CHAR)}(.*?){re.escape(HIGHLIGHT_END_CHAR)}",
text,
)
)
cleaned_text = text.replace(HIGHLIGHT_START_CHAR, "").replace(
HIGHLIGHT_END_CHAR, ""
)
# get the semantic identifier
snippet = (
cleaned_text[:50].rstrip() + "..." if len(cleaned_text) > 50 else cleaned_text
).replace("\n", " ")
doc_sem_id = f"{username} in #{channel_name}: {snippet}"
slack_messages.append(
SlackMessage(
document_id=document_id,
channel_id=channel_id,
message_id=message_id,
thread_id=thread_id,
link=permalink,
metadata=metadata,
timestamp=doc_time,
recency_bias=recency_bias,
semantic_identifier=doc_sem_id,
text=f"{username}: {cleaned_text}",
highlighted_texts=highlighted_texts,
slack_score=score,
)
)
return slack_messages
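To make the recency-bias formula above concrete, a small example with an illustrative decay factor (the real value is DOC_TIME_DECAY * recency_bias_multiplier):
```py
decay_factor = 0.5  # illustrative only
for doc_age_years in (0.0, 1.0, 3.0):
    print(doc_age_years, max(1 / (1 + decay_factor * doc_age_years), 0.75))
# 0.0 -> 1.0, 1.0 -> 0.75 (floored from ~0.67), 3.0 -> 0.75 (floored from 0.4)
```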
def merge_slack_messages(
slack_messages: list[list[SlackMessage]],
) -> tuple[list[SlackMessage], dict[str, SlackMessage]]:
merged_messages: list[SlackMessage] = []
docid_to_message: dict[str, SlackMessage] = {}
for messages in slack_messages:
for message in messages:
if message.document_id in docid_to_message:
# update the score and highlighted texts, rest should be identical
docid_to_message[message.document_id].slack_score = max(
docid_to_message[message.document_id].slack_score,
message.slack_score,
)
docid_to_message[message.document_id].highlighted_texts.update(
message.highlighted_texts
)
continue
# add the message to the list
docid_to_message[message.document_id] = message
merged_messages.append(message)
# re-sort by score
merged_messages.sort(key=lambda x: x.slack_score, reverse=True)
return merged_messages, docid_to_message
def get_contextualized_thread_text(message: SlackMessage, access_token: str) -> str:
"""
Retrieves the initial thread message as well as the text following the message
and combines them into a single string. If the slack query fails, returns the
original message text.
The idea is that the message (the one that actually matched the search), the
initial thread message, and the replies to the message are important in answering
the user's query.
"""
channel_id = message.channel_id
thread_id = message.thread_id
message_id = message.message_id
# if it's not a thread, return the message text
if thread_id is None:
return message.text
# get the thread messages
slack_client = WebClient(token=access_token)
try:
response = slack_client.conversations_replies(
channel=channel_id,
ts=thread_id,
)
response.validate()
messages: list[dict[str, Any]] = response.get("messages", [])
except SlackApiError as e:
logger.error(f"Slack API error in get_contextualized_thread_text: {e}")
return message.text
# make sure we didn't get an empty response or a single message (not a thread)
if len(messages) <= 1:
return message.text
# add the initial thread message
msg_text = messages[0].get("text", "")
msg_sender = messages[0].get("user", "")
thread_text = f"<@{msg_sender}>: {msg_text}"
# add the message (unless it's the initial message)
thread_text += "\n\nReplies:"
if thread_id == message_id:
message_id_idx = 0
else:
message_id_idx = next(
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
)
if not message_id_idx:
return thread_text
# add the message
thread_text += "\n..." if message_id_idx > 1 else ""
msg_text = messages[message_id_idx].get("text", "")
msg_sender = messages[message_id_idx].get("user", "")
thread_text += f"\n<@{msg_sender}>: {msg_text}"
# add the following replies to the thread text
len_replies = 0
for msg in messages[message_id_idx + 1 :]:
msg_text = msg.get("text", "")
msg_sender = msg.get("user", "")
reply = f"\n\n<@{msg_sender}>: {msg_text}"
thread_text += reply
# stop if len_replies exceeds chunk_size * 4 chars as the rest likely won't fit
len_replies += len(reply)
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
thread_text += "\n..."
break
# replace user ids with names in the thread text
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
for userid in userids:
try:
response = slack_client.users_profile_get(user=userid)
response.validate()
profile: dict[str, Any] = response.get("profile", {})
name: str | None = profile.get("real_name") or profile.get("email")
except SlackApiError as e:
logger.error(f"Slack API error in get_contextualized_thread_text: {e}")
continue
if not name:
continue
thread_text = thread_text.replace(f"<@{userid}>", name)
return thread_text
def convert_slack_score(slack_score: float) -> float:
"""
Convert slack score to a score between 0 and 1.
Will affect UI ordering and LLM ordering, but not the pruning.
I.e., should have very little effect on the search/answer quality.
"""
return max(0.0, min(1.0, slack_score / 90_000))
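A few sample conversions for the clamp above (scores are made up):
```py
for slack_score in (-5.0, 45_000.0, 200_000.0):
    print(slack_score, max(0.0, min(1.0, slack_score / 90_000)))
# -5.0 -> 0.0, 45000.0 -> 0.5, 200000.0 -> 1.0
```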
@log_function_time(print_only=True)
def slack_retrieval(
query: SearchQuery,
access_token: str,
db_session: Session,
limit: int | None = None,
) -> list[InferenceChunk]:
# query slack
_, fast_llm = get_default_llms()
query_strings = build_slack_queries(query, fast_llm)
results: list[list[SlackMessage]] = run_functions_tuples_in_parallel(
[
(query_slack, (query_string, query, access_token, limit))
for query_string in query_strings
]
)
slack_messages, docid_to_message = merge_slack_messages(results)
slack_messages = slack_messages[: limit or len(slack_messages)]
if not slack_messages:
return []
# contextualize the slack messages
thread_texts: list[str] = run_functions_tuples_in_parallel(
[
(get_contextualized_thread_text, (slack_message, access_token))
for slack_message in slack_messages
]
)
for slack_message, thread_text in zip(slack_messages, thread_texts):
slack_message.text = thread_text
# get the highlighted texts from shortest to longest
highlighted_texts: set[str] = set()
for slack_message in slack_messages:
highlighted_texts.update(slack_message.highlighted_texts)
sorted_highlighted_texts = sorted(highlighted_texts, key=len)
# convert slack messages to index documents
index_docs: list[IndexingDocument] = []
for slack_message in slack_messages:
section: TextSection = TextSection(
text=slack_message.text, link=slack_message.link
)
index_docs.append(
IndexingDocument(
id=slack_message.document_id,
sections=[section],
processed_sections=[section],
source=DocumentSource.SLACK,
title=slack_message.semantic_identifier,
semantic_identifier=slack_message.semantic_identifier,
metadata=slack_message.metadata,
doc_updated_at=slack_message.timestamp,
)
)
# chunk index docs into doc aware chunks
# a single index doc can get split into multiple chunks
search_settings = get_current_search_settings(db_session)
embedder = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)
multipass_config = get_multipass_config(search_settings)
enable_contextual_rag = (
search_settings.enable_contextual_rag or ENABLE_CONTEXTUAL_RAG
)
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass_config.multipass_indexing,
enable_large_chunks=multipass_config.enable_large_chunks,
enable_contextual_rag=enable_contextual_rag,
)
chunks = chunker.chunk(index_docs)
# prune chunks without any highlighted texts
relevant_chunks: list[DocAwareChunk] = []
chunkid_to_match_highlight: dict[str, str] = {}
for chunk in chunks:
match_highlight = chunk.content
for highlight in sorted_highlighted_texts: # faster than re sub
match_highlight = match_highlight.replace(
highlight, f"<hi>{highlight}</hi>"
)
# if nothing got replaced, the chunk is irrelevant
if len(match_highlight) == len(chunk.content):
continue
chunk_id = f"{chunk.source_document.id}__{chunk.chunk_id}"
relevant_chunks.append(chunk)
chunkid_to_match_highlight[chunk_id] = match_highlight
if limit and len(relevant_chunks) >= limit:
break
# convert to inference chunks
top_chunks: list[InferenceChunk] = []
for chunk in relevant_chunks:
document_id = chunk.source_document.id
chunk_id = f"{document_id}__{chunk.chunk_id}"
top_chunks.append(
InferenceChunk(
chunk_id=chunk.chunk_id,
blurb=chunk.blurb,
content=chunk.content,
source_links=chunk.source_links,
image_file_id=chunk.image_file_id,
section_continuation=chunk.section_continuation,
semantic_identifier=docid_to_message[document_id].semantic_identifier,
document_id=document_id,
source_type=DocumentSource.SLACK,
title=chunk.title_prefix,
boost=0,
recency_bias=docid_to_message[document_id].recency_bias,
score=convert_slack_score(docid_to_message[document_id].slack_score),
hidden=False,
is_relevant=None,
relevance_explanation="",
metadata=docid_to_message[document_id].metadata,
match_highlights=[chunkid_to_match_highlight[chunk_id]],
doc_summary="",
chunk_context="",
updated_at=docid_to_message[document_id].timestamp,
is_federated=True,
)
)
return top_chunks

View File

@@ -154,6 +154,7 @@ class SearchRequest(ChunkContext):
query: str
expanded_queries: QueryExpansions | None = None
original_query: str | None = None
search_type: SearchType = SearchType.SEMANTIC
@@ -205,6 +206,7 @@ class SearchQuery(ChunkContext):
precomputed_query_embedding: Embedding | None = None
expanded_queries: QueryExpansions | None = None
original_query: str | None
class RetrievalDetails(ChunkContext):
@@ -252,6 +254,8 @@ class InferenceChunk(BaseChunk):
secondary_owners: list[str] | None = None
large_chunk_reference_ids: list[int] = Field(default_factory=list)
is_federated: bool = False
@property
def unique_id(self) -> str:
return f"{self.document_id}__{self.chunk_id}"

View File

@@ -158,6 +158,7 @@ class SearchPipeline:
# These chunks do not include large chunks and have been deduped
self._retrieved_chunks = retrieve_chunks(
query=self.search_query,
user_id=self.user.id if self.user else None,
document_index=self.document_index,
db_session=self.db_session,
retrieval_metrics_callback=self.retrieval_metrics_callback,

View File

@@ -25,7 +25,6 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RerankMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
@@ -70,18 +69,16 @@ def update_image_sections_with_query(
logger.debug(
f"Processing image chunk with ID: {chunk.unique_id}, image: {chunk.image_file_id}"
)
with get_session_with_current_tenant() as db_session:
file_record = get_default_file_store(db_session).read_file(
cast(str, chunk.image_file_id), mode="b"
)
if not file_record:
logger.error(f"Image file not found: {chunk.image_file_id}")
raise Exception("File not found")
file_content = file_record.read()
image_base64 = base64.b64encode(file_content).decode()
logger.debug(
f"Successfully loaded image data for {chunk.image_file_id}"
)
file_record = get_default_file_store().read_file(
cast(str, chunk.image_file_id), mode="b"
)
if not file_record:
logger.error(f"Image file not found: {chunk.image_file_id}")
raise Exception("File not found")
file_content = file_record.read()
image_base64 = base64.b64encode(file_content).decode()
logger.debug(f"Successfully loaded image data for {chunk.image_file_id}")
messages: list[BaseMessage] = [
SystemMessage(content=IMAGE_ANALYSIS_SYSTEM_PROMPT),

View File

@@ -250,6 +250,7 @@ def retrieval_preprocessing(
return SearchQuery(
query=query,
original_query=search_request.original_query,
processed_keywords=processed_keywords,
search_type=SearchType.KEYWORD if is_keyword else SearchType.SEMANTIC,
evaluation_type=llm_evaluation_type,

View File

@@ -1,5 +1,6 @@
import string
from collections.abc import Callable
from uuid import UUID
import nltk # type:ignore
from sqlalchemy.orm import Session
@@ -17,15 +18,18 @@ from onyx.context.search.models import SearchQuery
from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA
from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA_KEYWORD
from onyx.context.search.utils import get_query_embedding
from onyx.context.search.utils import get_query_embeddings
from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_multilingual_expansion
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.federated_connectors.federated_retrieval import (
get_federated_retrieval_functions,
)
from onyx.secondary_llm_flows.query_expansion import multilingual_query_expansion
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -33,9 +37,6 @@ from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -115,34 +116,6 @@ def combine_retrieval_results(
return sorted_chunks
def get_query_embedding(query: str, db_session: Session) -> Embedding:
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode([query], text_type=EmbedTextType.QUERY)[0]
return query_embedding
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
return query_embedding
@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
@@ -350,6 +323,7 @@ def _simplify_text(text: str) -> str:
def retrieve_chunks(
query: SearchQuery,
user_id: UUID | None,
document_index: DocumentIndex,
db_session: Session,
retrieval_metrics_callback: (
@@ -359,14 +333,34 @@ def retrieve_chunks(
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
multilingual_expansion = get_multilingual_expansion(db_session)
# Don't do query expansion on complex queries, rephrasings likely would not work well
if not multilingual_expansion or "\n" in query.query or "\r" in query.query:
top_chunks = doc_index_retrieval(
query=query, document_index=document_index, db_session=db_session
)
else:
run_queries: list[tuple[Callable, tuple]] = []
source_filters = (
set(query.filters.source_type) if query.filters.source_type else None
)
# Federated retrieval
federated_retrieval_infos = get_federated_retrieval_functions(
db_session, user_id, query.filters.source_type, query.filters.document_set
)
federated_sources = set(
federated_retrieval_info.source.to_non_federated_source()
for federated_retrieval_info in federated_retrieval_infos
)
for federated_retrieval_info in federated_retrieval_infos:
run_queries.append((federated_retrieval_info.retrieval_function, (query,)))
# Normal retrieval
normal_search_enabled = (source_filters is None) or (
len(set(source_filters) - federated_sources) > 0
)
if normal_search_enabled and (
not multilingual_expansion or "\n" in query.query or "\r" in query.query
):
# Don't do query expansion on complex queries, rephrasings likely would not work well
run_queries.append((doc_index_retrieval, (query, document_index, db_session)))
elif normal_search_enabled:
simplified_queries = set()
run_queries: list[tuple[Callable, tuple]] = []
# Currently only uses query expansion on multilingual use cases
query_rephrases = multilingual_query_expansion(
@@ -393,13 +387,11 @@ def retrieve_chunks(
deep=True,
)
run_queries.append(
(
doc_index_retrieval,
(q_copy, document_index, db_session),
)
(doc_index_retrieval, (q_copy, document_index, db_session))
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)
if not top_chunks:
logger.warning(

View File
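Condensing the retrieval fan-out introduced above into a hedged sketch: federated and indexed retrievals are queued as (callable, args) tuples, run in parallel, and merged. Variable names mirror the diff, but this is a simplification of the full function, not a drop-in replacement:

from collections.abc import Callable

# Simplified restatement of the diff's fan-out (not the complete logic).
run_queries: list[tuple[Callable, tuple]] = []

# Each federated connector contributes one retrieval call.
for federated_retrieval_info in federated_retrieval_infos:
    run_queries.append((federated_retrieval_info.retrieval_function, (query,)))

# The document index contributes one call when its sources are not fully
# covered by federated connectors.
if normal_search_enabled:
    run_queries.append((doc_index_retrieval, (query, document_index, db_session)))

# All tuples run in parallel; each returns its own chunk list, which are
# then merged into a single ranked list.
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)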

@@ -4,6 +4,7 @@ from typing import TypeVar
from nltk.corpus import stopwords # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceChunk
@@ -12,7 +13,13 @@ from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SavedSearchDocWithContent
from onyx.context.search.models import SearchDoc
from onyx.db.models import SearchDoc as DBSearchDoc
from onyx.db.search_settings import get_current_search_settings
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -160,3 +167,21 @@ def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
except Exception as e:
logger.warning(f"Error removing stop words and punctuation: {e}")
return keywords
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
return query_embedding
def get_query_embedding(query: str, db_session: Session) -> Embedding:
return get_query_embeddings([query], db_session)[0]
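A brief, hedged usage sketch of the relocated embedding helpers, assuming an open SQLAlchemy `db_session`; the query strings are illustrative. `get_query_embedding` now simply delegates to the batched `get_query_embeddings`:

from onyx.context.search.utils import get_query_embedding
from onyx.context.search.utils import get_query_embeddings

# Batch call: one Embedding per query, in input order.
batch = get_query_embeddings(["reset my password", "update billing info"], db_session)

# Single-query convenience wrapper: equivalent to taking element 0 of a
# one-element batch.
single = get_query_embedding("reset my password", db_session)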

Some files were not shown because too many files have changed in this diff.